SNOW-966358 Support TableFunction in DataFrame Select() (#65)
* add error code

tf in select

refactor dataframe join

rename df join

* fix error

* add test

* fix test

* address comments
sfc-gh-bli authored Nov 21, 2023
1 parent 84e12fe commit 327882f
Showing 5 changed files with 127 additions and 35 deletions.
95 changes: 63 additions & 32 deletions src/main/scala/com/snowflake/snowpark/DataFrame.scala
@@ -561,18 +561,35 @@ class DataFrame private[snowpark] (
"Provide at least one column expression for select(). " +
s"This DataFrame has column names (${output.length}): " +
s"${output.map(_.name).mkString(", ")}\n")

val resultDF = withPlan { Project(columns.map(_.named), plan) }
// do not rename back if this project contains internal alias.
// because no named duplicated if just renamed.
val hasInternalAlias: Boolean = columns.map(_.expr).exists {
case Alias(_, _, true) => true
case _ => false
}
if (hasInternalAlias) {
resultDF
} else {
renameBackIfDeduped(resultDF)
// todo: error message
val tf = columns.filter(_.expr.isInstanceOf[TableFunctionExpression])
tf.size match {
case 0 => // no table function
val resultDF = withPlan {
Project(columns.map(_.named), plan)
}
// Do not rename back if this projection contains an internal alias,
// because no names are duplicated when the columns were just renamed.
val hasInternalAlias: Boolean = columns.map(_.expr).exists {
case Alias(_, _, true) => true
case _ => false
}
if (hasInternalAlias) {
resultDF
} else {
renameBackIfDeduped(resultDF)
}
case 1 => // 1 table function
val base = this.join(tf.head)
val baseColumns = base.schema.map(field => base(field.name))
val inputDFColumnSize = this.schema.size
val tfColumns = baseColumns.splitAt(inputDFColumnSize)._2
val (beforeTf, afterTf) = columns.span(_ != tf.head)
val resultColumns = beforeTf ++ tfColumns ++ afterTf.tail
base.select(resultColumns)
case _ =>
// more than 1 TF
throw ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT()
}
}
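
For context, a minimal usage sketch of the new select() path (this mirrors the tests added in TableFunctionSuite below; an active session with its implicits in scope is assumed so that toDF works, and split_to_table comes from the tableFunctions object):

import com.snowflake.snowpark.tableFunctions

// Assumed: session implicits in scope, as in the test suites.
val df = Seq((1, "1,2"), (2, "3,4")).toDF("idx", "data")

// At most one table function per select(): its output columns (SEQ, INDEX, VALUE)
// are spliced into the column list at the position where the call appears.
val expanded = df.select(df("idx"), tableFunctions.split_to_table(df("data"), ","))
// expanded columns: IDX, SEQ, INDEX, VALUE

// Passing two table functions to a single select() throws error code 0131.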

@@ -1788,9 +1805,8 @@ class DataFrame private[snowpark] (
* object or an object that you create from the [[TableFunction]] class.
* @param args A list of arguments to pass to the specified table function.
*/
def join(func: TableFunction, args: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(this.plan, func.call(args: _*), None)
}
def join(func: TableFunction, args: Seq[Column]): DataFrame =
joinTableFunction(func.call(args: _*), None)

/**
* Joins the current DataFrame with the output of the specified user-defined table
@@ -1822,12 +1838,10 @@
func: TableFunction,
args: Seq[Column],
partitionBy: Seq[Column],
orderBy: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(
this.plan,
orderBy: Seq[Column]): DataFrame =
joinTableFunction(
func.call(args: _*),
Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))
}

/**
* Joins the current DataFrame with the output of the specified table function `func` that takes
@@ -1859,9 +1873,8 @@
* Some functions, like `flatten`, have named parameters.
* Use this map to specify the parameter names and their corresponding values.
*/
def join(func: TableFunction, args: Map[String, Column]): DataFrame = withPlan {
TableFunctionJoin(this.plan, func.call(args), None)
}
def join(func: TableFunction, args: Map[String, Column]): DataFrame =
joinTableFunction(func.call(args), None)

/**
* Joins the current DataFrame with the output of the specified user-defined table function
@@ -1900,12 +1913,10 @@
func: TableFunction,
args: Map[String, Column],
partitionBy: Seq[Column],
orderBy: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(
this.plan,
orderBy: Seq[Column]): DataFrame =
joinTableFunction(
func.call(args),
Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))
}

/**
* Joins the current DataFrame with the output of the specified table function `func`.
@@ -1929,9 +1940,8 @@
* @param func [[TableFunction]] object, which can be one of the values in the [[tableFunctions]]
* object or an object that you create from the [[TableFunction.apply()]].
*/
def join(func: Column): DataFrame = withPlan {
TableFunctionJoin(this.plan, getTableFunctionExpression(func), None)
}
def join(func: Column): DataFrame =
joinTableFunction(getTableFunctionExpression(func), None)

/**
* Joins the current DataFrame with the output of the specified user-defined table function
@@ -1951,11 +1961,32 @@
* @param partitionBy A list of columns to partition by.
* @param orderBy A list of columns to order by.
*/
def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(
this.plan,
def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame =
joinTableFunction(
getTableFunctionExpression(func),
Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))

private def joinTableFunction(
func: TableFunctionExpression,
partitionByOrderBy: Option[WindowSpecDefinition]): DataFrame = {
val originalResult = withPlan {
TableFunctionJoin(this.plan, func, partitionByOrderBy)
}
val resultSchema = originalResult.schema
val columnNames = resultSchema.map(_.name)
// duplicated names
val dup = columnNames.diff(columnNames.distinct).distinct.map(quoteName)
// guarantee no duplicated names in the result
if (dup.nonEmpty) {
val dfPrefix = DataFrame.generatePrefix('o')
val renamedDf =
this.select(this.output.map(_.name).map(aliasIfNeeded(this, _, dfPrefix, dup.toSet)))
withPlan {
TableFunctionJoin(renamedDf.plan, func, partitionByOrderBy)
}
} else {
originalResult
}
}
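
The duplicate-name handling above can be seen in a minimal sketch (mirroring the duplicated-column-name tests added below; session implicits, lit, and split_to_table are assumed to be in scope as in the test suite):

val df = Seq((1, "1,2"), (2, "3,4")).toDF("idx", "value")

// Both df and split_to_table expose a VALUE column. joinTableFunction detects the
// clash, re-aliases the input DataFrame's columns with an internal prefix, and keeps
// exactly one unqualified VALUE (the table function's) in the joined result.
val joined = df.join(tableFunctions.split_to_table(df("value"), lit(",")))
joined.select("value")     // table function output: "1", "2", "3", "4"
joined.select(df("value")) // original column, still reachable via the df reference: "1,2", "1,2", ...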

/**
@@ -68,6 +68,7 @@ private[snowpark] object ErrorMessage {
"0128" -> "DataFrameWriter doesn't support to set option '%s' as '%s' in '%s' mode when writing to a %s.",
"0129" -> "DataFrameWriter doesn't support mode '%s' when writing to a %s.",
"0130" -> "Unsupported join operations, Dataframes can join with other Dataframes or TableFunctions only",
"0131" -> "At most one table function can be called inside select() function",
// Begin to define UDF related messages
"0200" -> "Incorrect number of arguments passed to the UDF: Expected: %d, Found: %d",
"0201" -> "Attempted to call an unregistered UDF. You must register the UDF before calling it.",
@@ -244,6 +245,9 @@ private[snowpark] object ErrorMessage {
def DF_JOIN_WITH_WRONG_ARGUMENT(): SnowparkClientException =
createException("0130")

def DF_MORE_THAN_ONE_TF_IN_SELECT(): SnowparkClientException =
createException("0131")

/*
* 2NN: UDF error code
*/
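
A hedged sketch of how the new error code surfaces to callers (the DataFrame and its column names here are hypothetical; getMessage is the standard java.lang.Exception accessor and is assumed to carry the same text that the suite below checks via the exception's message field):

import com.snowflake.snowpark.SnowparkClientException

// Hypothetical: df has columns "a" and "b".
try {
  df.select(
    tableFunctions.split_to_table(df("a"), ","),
    tableFunctions.split_to_table(df("b"), ","))
} catch {
  case e: SnowparkClientException =>
    // Expected text: "Error Code: 0131, Error message: At most one table function
    // can be called inside select() function"
    println(e.getMessage)
}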
6 changes: 3 additions & 3 deletions src/test/java/com/snowflake/snowpark_test/JavaUDTFSuite.java
@@ -46,7 +46,7 @@ public void basicTypes() {
"create or replace temp table "
+ tableName
+ "(i1 smallint, i2 int, l1 bigint, f1 float, d1 double, "
+ "decimal number(38, 18), b boolean, s string, bi binary)";
+ "de number(38, 18), b boolean, s string, bi binary)";
runQuery(crt);
String insert =
"insert into "
@@ -68,7 +68,7 @@ public void basicTypes() {
col("l1"),
col("f1"),
col("d1"),
col("decimal"),
col("de"),
col("b"),
col("s"),
col("bi"));
@@ -82,7 +82,7 @@ public void basicTypes() {
.append("|--L1: Long (nullable = true)")
.append("|--F1: Double (nullable = true)")
.append("|--D1: Double (nullable = true)")
.append("|--DECIMAL: Decimal(38, 18) (nullable = true)")
.append("|--DE: Decimal(38, 18) (nullable = true)")
.append("|--B: Boolean (nullable = true)")
.append("|--S: String (nullable = true)")
.append("|--BI: Binary (nullable = true)")
8 changes: 8 additions & 0 deletions src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
@@ -296,6 +296,14 @@ class ErrorMessageSuite extends FunSuite {
" or TableFunctions only"))
}

test("DF_MORE_THAN_ONE_TF_IN_SELECT") {
val ex = ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT()
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0131")))
assert(
ex.message.startsWith("Error Code: 0131, Error message: " +
"At most one table function can be called inside select() function"))
}

test("UDF_INCORRECT_ARGS_NUMBER") {
val ex = ErrorMessage.UDF_INCORRECT_ARGS_NUMBER(1, 2)
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0200")))
@@ -337,4 +337,53 @@ class TableFunctionSuite extends TestData {
.select("value"),
Seq(Row("77"), Row("88")))
}

test("table function in select") {
val df = Seq((1, "1,2"), (2, "3,4")).toDF("idx", "data")
// only tf
val result1 = df.select(tableFunctions.split_to_table(df("data"), ","))
assert(result1.schema.map(_.name) == Seq("SEQ", "INDEX", "VALUE"))
checkAnswer(result1, Seq(Row(1, 1, "1"), Row(1, 2, "2"), Row(2, 1, "3"), Row(2, 2, "4")))

// columns + tf
val result2 = df.select(df("idx"), tableFunctions.split_to_table(df("data"), ","))
assert(result2.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE"))
checkAnswer(
result2,
Seq(Row(1, 1, 1, "1"), Row(1, 1, 2, "2"), Row(2, 2, 1, "3"), Row(2, 2, 2, "4")))

// columns + tf + columns
val result3 = df.select(df("idx"), tableFunctions.split_to_table(df("data"), ","), df("idx"))
assert(result3.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE", "IDX"))
checkAnswer(
result3,
Seq(Row(1, 1, 1, "1", 1), Row(1, 1, 2, "2", 1), Row(2, 2, 1, "3", 2), Row(2, 2, 2, "4", 2)))

// tf + other expressions
val result4 = df.select(tableFunctions.split_to_table(df("data"), ","), df("idx") + 100)
checkAnswer(
result4,
Seq(Row(1, 1, "1", 101), Row(1, 2, "2", 101), Row(2, 1, "3", 102), Row(2, 2, "4", 102)))
}

test("table function join with duplicated column name") {
val df = Seq((1, "1,2"), (2, "3,4")).toDF("idx", "value")
val result = df.join(tableFunctions.split_to_table(df("value"), lit(",")))
// only one VALUE in the result
checkAnswer(result.select("value"), Seq(Row("1"), Row("2"), Row("3"), Row("4")))
checkAnswer(result.select(result("value")), Seq(Row("1"), Row("2"), Row("3"), Row("4")))
checkAnswer(result.select(df("value")), Seq(Row("1,2"), Row("1,2"), Row("3,4"), Row("3,4")))
}

test("table function select with duplicated column name") {
val df = Seq((1, "1,2"), (2, "3,4")).toDF("idx", "value")
val result1 = df.select(tableFunctions.split_to_table(df("value"), lit(",")))
checkAnswer(result1, Seq(Row(1, 1, "1"), Row(1, 2, "2"), Row(2, 1, "3"), Row(2, 2, "4")))
val result = df.select(df("value"), tableFunctions.split_to_table(df("value"), lit(",")))
// only one VALUE in the result
checkAnswer(result.select("value"), Seq(Row("1"), Row("2"), Row("3"), Row("4")))
checkAnswer(result.select(result("value")), Seq(Row("1"), Row("2"), Row("3"), Row("4")))
checkAnswer(result.select(df("value")), Seq(Row("1,2"), Row("1,2"), Row("3,4"), Row("3,4")))
}

}
