Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-966339 Support Arguments in TableFunctions #62

Merged
merged 21 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/DataFrame.java
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,81 @@ public DataFrame join(
JavaUtils.columnArrayToSeq(Column.toScalaColumnArray(orderBy))));
}

/**
 * Joins the current DataFrame with the output of the specified table function `func`.
 *
 * <p>Pre-defined table functions can be found in the `TableFunctions` class.
 *
 * <p>For example:
 *
 * <pre>{@code
 * df.join(TableFunctions.flatten(
 *     Functions.parse_json(df.col("col")),
 *     "path", true, true, "both"
 * ));
 * }</pre>
 *
 * <p>Any Snowflake builtin table function can also be loaded via the TableFunction class.
 *
 * <pre>{@code
 * Map<String, Column> args = new HashMap<>();
 * args.put("input", Functions.parse_json(df.col("a")));
 * df.join(new TableFunction("flatten").call(args));
 * }</pre>
 *
 * @since 1.10.0
 * @param func Column object, which can be one of the values in the TableFunctions class or an
 *     object that you create from the `new TableFunction("name").call()`.
 * @return The result DataFrame
 */
public DataFrame join(Column func) {
  // Delegate to the underlying Scala DataFrame and wrap the result back into the Java API.
  com.snowflake.snowpark.Column scalaFunc = func.toScalaColumn();
  return new DataFrame(this.df.join(scalaFunc));
}

/**
 * Joins the current DataFrame with the output of the specified table function `func`.
 *
 * <p>To specify a PARTITION BY or ORDER BY clause, use the `partitionBy` and `orderBy` arguments.
 *
 * <p>Pre-defined table functions can be found in the `TableFunctions` class.
 *
 * <p>For example:
 *
 * <pre>{@code
 * df.join(TableFunctions.flatten(
 *     Functions.parse_json(df.col("col1")),
 *     "path", true, true, "both"
 *   ),
 *   new Column[] {df.col("col2")},
 *   new Column[] {df.col("col1")}
 * );
 * }</pre>
 *
 * <p>Any Snowflake builtin table function can also be loaded via the TableFunction class.
 *
 * <pre>{@code
 * Map<String, Column> args = new HashMap<>();
 * args.put("input", Functions.parse_json(df.col("col1")));
 * df.join(new TableFunction("flatten").call(args),
 *   new Column[] {df.col("col2")},
 *   new Column[] {df.col("col1")});
 * }</pre>
 *
 * @since 1.10.0
 * @param func Column object, which can be one of the values in the TableFunctions class or an
 *     object that you create from the `new TableFunction("name").call()`.
 * @param partitionBy An array of columns partitioned by.
 * @param orderBy An array of columns ordered by.
 * @return The result DataFrame
 */
public DataFrame join(Column func, Column[] partitionBy, Column[] orderBy) {
  // Convert both Java column arrays up front so the delegation call stays readable.
  com.snowflake.snowpark.Column[] scalaPartitionBy = Column.toScalaColumnArray(partitionBy);
  com.snowflake.snowpark.Column[] scalaOrderBy = Column.toScalaColumnArray(orderBy);
  return new DataFrame(
      this.df.join(
          func.toScalaColumn(),
          JavaUtils.columnArrayToSeq(scalaPartitionBy),
          JavaUtils.columnArrayToSeq(scalaOrderBy)));
}

// Package-private accessor: exposes the wrapped Scala DataFrame to other classes
// in this package that need to delegate to the Scala implementation.
com.snowflake.snowpark.DataFrame getScalaDataFrame() {
  return this.df;
}
Expand Down
21 changes: 21 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,27 @@ public DataFrame tableFunction(TableFunction func, Map<String, Column> args) {
func.getScalaTableFunction(), JavaUtils.javaStringColumnMapToScala(scalaArgs)));
}

/**
 * Creates a new DataFrame from the given table function and arguments.
 *
 * <p>Example
 *
 * <pre>{@code
 * session.tableFunction(TableFunctions.flatten(
 *   Functions.parse_json(df.col("col")),
 *   "path", true, true, "both"
 * ));
 * }</pre>
 *
 * @since 1.10.0
 * @param func Column object, which can be one of the values in the TableFunctions class or an
 *     object that you create from the `new TableFunction("name").call()`.
 * @return The result DataFrame
 */
public DataFrame tableFunction(Column func) {
  // Unwrap to the Scala column, delegate, and re-wrap as a Java DataFrame.
  com.snowflake.snowpark.Column scalaFunc = func.toScalaColumn();
  return new DataFrame(session.tableFunction(scalaFunc));
}

/**
* Returns a SProcRegistration object that you can use to register Stored Procedures.
*
Expand Down
31 changes: 31 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/TableFunction.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package com.snowflake.snowpark_java;

import com.snowflake.snowpark.internal.JavaUtils;
import java.util.HashMap;
import java.util.Map;

/**
* Looks up table functions by funcName and returns tableFunction object which can be used in {@code
* DataFrame.join} and {@code Session.tableFunction} methods.
Expand Down Expand Up @@ -38,4 +42,31 @@ com.snowflake.snowpark.TableFunction getScalaTableFunction() {
public String funcName() {
  return func.funcName(); // delegates to the wrapped Scala TableFunction
}

/**
 * Create a Column reference by passing arguments in the TableFunction object.
 *
 * @param args A list of Column objects representing the arguments of the given table function
 * @return A Column reference
 * @since 1.10.0
 */
public Column call(Column... args) {
  // Convert the Java varargs to a Scala Seq before applying the table function.
  com.snowflake.snowpark.Column[] scalaArgs = Column.toScalaColumnArray(args);
  return new Column(this.func.apply(JavaUtils.columnArrayToSeq(scalaArgs)));
}

/**
 * Create a Column reference by passing arguments in the TableFunction object.
 *
 * @param args function arguments map of the given table function. Some functions, like flatten,
 *     have named parameters. use this map to assign values to the corresponding parameters.
 * @return A Column reference
 * @since 1.10.0
 */
public Column call(Map<String, Column> args) {
  // Rebuild the map with Scala columns as values before converting it to a Scala map.
  Map<String, com.snowflake.snowpark.Column> scalaArgs = new HashMap<>();
  for (String argName : args.keySet()) {
    scalaArgs.put(argName, args.get(argName).toScalaColumn());
  }
  return new Column(this.func.apply(JavaUtils.javaStringColumnMapToScala(scalaArgs)));
}
}
73 changes: 73 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/TableFunctions.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,27 @@ public static TableFunction split_to_table() {
return new TableFunction(com.snowflake.snowpark.tableFunctions.split_to_table());
}

/**
 * This table function splits a string (based on a specified delimiter) and flattens the results
 * into rows.
 *
 * <p>Example
 *
 * <pre>{@code
 * session.tableFunction(TableFunctions.split_to_table(
 *     Functions.lit("split by space"), Functions.lit(" ")));
 * }</pre>
 *
 * @since 1.10.0
 * @param str Text to be split.
 * @param delimiter Text to split string by.
 * @return The result Column reference
 */
public static Column split_to_table(Column str, String delimiter) {
  return new Column(
      com.snowflake.snowpark.tableFunctions.split_to_table(str.toScalaColumn(), delimiter));
}

/**
* Flattens (explodes) compound values into multiple rows.
*
Expand Down Expand Up @@ -77,4 +98,56 @@ public static TableFunction split_to_table() {
public static TableFunction flatten() {
  // Wrap the pre-defined Scala flatten table function in the Java TableFunction type.
  com.snowflake.snowpark.TableFunction scalaFlatten =
      com.snowflake.snowpark.tableFunctions.flatten();
  return new TableFunction(scalaFlatten);
}

/**
 * Flattens (explodes) compound values into multiple rows.
 *
 * <p>Example
 *
 * <pre>{@code
 * df.join(TableFunctions.flatten(
 *   Functions.parse_json(df.col("col")), "path", true, true, "both"));
 * }</pre>
 *
 * @since 1.10.0
 * @param input The expression that will be unnested into rows. The expression must be of data
 *     type VariantType, MapType or ArrayType.
 * @param path The path to the element within a VariantType data structure which needs to be
 *     flattened. Can be a zero-length string (i.e. empty path) if the outermost element is to be
 *     flattened. Default: Zero-length string (i.e. empty path)
 * @param outer If FALSE, any input rows that cannot be expanded, either because they cannot be
 *     accessed in the path or because they have zero fields or entries, are completely omitted
 *     from the output. If TRUE, exactly one row is generated for zero-row expansions (with NULL
 *     in the KEY, INDEX, and VALUE columns).
 * @param recursive If FALSE, only the element referenced by PATH is expanded. If TRUE, the
 *     expansion is performed for all sub-elements recursively. Default: FALSE
 * @param mode ("object", "array", or "both") Specifies whether only objects, arrays, or both
 *     should be flattened.
 * @return The result Column reference
 */
public static Column flatten(
    Column input, String path, boolean outer, boolean recursive, String mode) {
  return new Column(
      com.snowflake.snowpark.tableFunctions.flatten(
          input.toScalaColumn(), path, outer, recursive, mode));
}

/**
 * Flattens (explodes) compound values into multiple rows.
 *
 * <p>Example
 *
 * <pre>{@code
 * df.join(TableFunctions.flatten(
 *   Functions.parse_json(df.col("col"))));
 * }</pre>
 *
 * @since 1.10.0
 * @param input The expression that will be unnested into rows. The expression must be of data
 *     type VariantType, MapType or ArrayType.
 * @return The result Column reference
 */
public static Column flatten(Column input) {
  return new Column(com.snowflake.snowpark.tableFunctions.flatten(input.toScalaColumn()));
}
}
65 changes: 60 additions & 5 deletions src/main/scala/com/snowflake/snowpark/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ import com.snowflake.snowpark.internal.{Logging, Utils}
import com.snowflake.snowpark.internal.analyzer._
import com.snowflake.snowpark.types._
import com.github.vertical_blank.sqlformatter.SqlFormatter
import com.snowflake.snowpark.internal.Utils.{TempObjectType, randomNameForTempObject}
import com.snowflake.snowpark.internal.Utils.{
TempObjectType,
getTableFunctionExpression,
randomNameForTempObject
}

import javax.xml.bind.DatatypeConverter
import scala.collection.JavaConverters._
Expand Down Expand Up @@ -1785,7 +1789,7 @@ class DataFrame private[snowpark] (
* @param args A list of arguments to pass to the specified table function.
*/
def join(func: TableFunction, args: Seq[Column]): DataFrame = withPlan {
  // Build the table-function column from the positional args, then join without a window spec.
  val funcColumn = func.call(args: _*)
  TableFunctionJoin(this.plan, funcColumn, None)
}

/**
Expand Down Expand Up @@ -1821,7 +1825,7 @@ class DataFrame private[snowpark] (
orderBy: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(
this.plan,
func(args: _*),
func.call(args: _*),
Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))
}

Expand Down Expand Up @@ -1856,7 +1860,7 @@ class DataFrame private[snowpark] (
* Use this map to specify the parameter names and their corresponding values.
*/
def join(func: TableFunction, args: Map[String, Column]): DataFrame = withPlan {
  // Build the table-function column from the named args, then join without a window spec.
  val namedCall = func.call(args)
  TableFunctionJoin(this.plan, namedCall, None)
}

/**
Expand Down Expand Up @@ -1899,7 +1903,58 @@ class DataFrame private[snowpark] (
orderBy: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(
this.plan,
func(args),
func.call(args),
Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))
}

/**
 * Joins the current DataFrame with the output of the specified table function `func`.
 *
 *
 * For example:
 * {{{
 *   // The following example uses the flatten function to explode compound values from
 *   // column 'a' in this DataFrame into multiple columns.
 *
 *   import com.snowflake.snowpark.functions._
 *   import com.snowflake.snowpark.tableFunctions._
 *
 *   df.join(
 *     tableFunctions.flatten(parse_json(df("a")))
 *   )
 * }}}
 *
 * @group transform
 * @since 1.10.0
 * @param func [[Column]] object that represents a table function call, such as one produced by
 *             the functions in the [[tableFunctions]] object or by calling a [[TableFunction]]
 *             created from [[TableFunction.apply()]].
 */
def join(func: Column): DataFrame = withPlan {
  TableFunctionJoin(this.plan, getTableFunctionExpression(func), None)
}

/**
 * Joins the current DataFrame with the output of the specified user-defined table function
 * (UDTF) `func`.
 *
 * To specify a PARTITION BY or ORDER BY clause, use the `partitionBy` and `orderBy` arguments.
 *
 * For example:
 * {{{
 *   val tf = session.udtf.registerTemporary(TableFunc1)
 *   df.join(tf(Map("arg1" -> df("col1"))), Seq(df("col2")), Seq(df("col1")))
 * }}}
 *
 * @group transform
 * @since 1.10.0
 * @param func [[Column]] object that represents a table function call (for example, the result
 *             of invoking a registered user-defined table function).
 * @param partitionBy A list of columns partitioned by.
 * @param orderBy A list of columns ordered by.
 */
def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = withPlan {
  TableFunctionJoin(
    this.plan,
    getTableFunctionExpression(func),
    Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))
}

Expand Down
Loading
Loading