diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index 1ca240fc..f84e6082 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -4494,6 +4494,97 @@ public static Column randn(long seed) { return new Column(functions.randn(seed)); } + /** + * Shift the given value numBits left. If the given value is a long value, this function will + * return a long value else it will return an integer value. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
+   * df.select(Functions.shiftleft(Functions.col("a"), 1).as("shiftleft")).show();
+   * ---------------
+   * |"SHIFTLEFT"  |
+   * ---------------
+   * |2            |
+   * |4            |
+   * |6            |
+   * ---------------
+   * }
+ * + * @since 1.14.0 + * @return Column object. + */ + public static Column shiftleft(Column c, int numBits) { + return new Column(functions.shiftleft(c.toScalaColumn(), numBits)); + } + + /** + * Shift the given value numBits right. If the given value is a long value, it will return a long + * value else it will return an integer value. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
+   * df.select(Functions.shiftright(Functions.col("a"), 1).as("shiftright")).show();
+   * ---------------
+   * |"SHIFTRIGHT"  |
+   * ---------------
+   * |0            |
+   * |1            |
+   * |1            |
+   * ---------------
+   * }
+ * + * @since 1.14.0 + * @return Column object. + */ + public static Column shiftright(Column c, int numBits) { + return new Column(functions.shiftright(c.toScalaColumn(), numBits)); + } + + /** + * Computes hex value of the given column. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
+   * df.select(Functions.hex(Functions.col("a")).as("hex")).show();
+   * ---------
+   * |"HEX"  |
+   * ---------
+   * |31     |
+   * |32     |
+   * |33     |
+   * ---------
+   * }
+ * + * @since 1.14.0 + * @return Column object. + */ + public static Column hex(Column c) { + return new Column(functions.hex(c.toScalaColumn())); + } + + /** + * Inverse of hex. Interprets each pair of characters as a hexadecimal number and converts to the + * byte representation of number. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values(31),(32),(33) as T(a)");
+   * df.select(Functions.unhex(Functions.col("a")).as("unhex")).show();
+   * -----------
+   * |"UNHEX"  |
+   * -----------
+   * |1        |
+   * |2        |
+   * |3        |
+   * -----------
+   * }
+ * + * @since 1.14.0 + * @return Column object. + */ + public static Column unhex(Column c) { + return new Column(functions.unhex(c.toScalaColumn())); + } + /** * Calls a user-defined function (UDF) by name. * diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index fd8b7ab1..fdfc3189 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -3790,7 +3790,7 @@ object functions { * from the standard normal distribution. * Calls to the Snowflake RANDOM function. * NOTE: Snowflake returns integers of 17-19 digits. - *Example + * Example * {{{ * val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a") * df.withColumn("randn_with_seed", randn(123L)).select("randn_with_seed").show() @@ -3810,6 +3810,99 @@ object functions { def randn(seed: Long): Column = builtin("RANDOM")(seed) + /** + * Shift the given value numBits left. If the given value is a long value, + * this function will return a long value else it will return an integer value. + * Example + * {{{ + * val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a") + * df.select(shiftleft(col("A"), 1).as("shiftleft")).show() + * --------------- + * |"SHIFTLEFT" | + * --------------- + * |2 | + * |4 | + * |6 | + * --------------- + * }}} + * + * @since 1.14.0 + * @param c Column to modify. + * @param numBits Number of bits to shift. + * @return Column object. + */ + def shiftleft(c: Column, numBits: Int): Column = + bitshiftleft(c, lit(numBits)) + + /** + * Shift the given value numBits right. If the given value is a long value, + * it will return a long value else it will return an integer value. + * Example + * {{{ + * val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a") + * df.select(shiftright(col("A"), 1).as("shiftright")).show() + * ---------------- + * |"SHIFTRIGHT" | + * ---------------- + * |0 | + * |1 | + * |1 | + * ---------------- + * }}} + * + * @since 1.14.0 + * @param c Column to modify. + * @param numBits Number of bits to shift. + * @return Column object. + */ + def shiftright(c: Column, numBits: Int): Column = + bitshiftright(c, lit(numBits)) + + /** + * Computes hex value of the given column. + * Example + * {{{ + * val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a") + * df.withColumn("hex_col", hex(col("A"))).select("hex_col").show() + * ------------- + * |"HEX_COL" | + * ------------- + * |31 | + * |32 | + * |33 | + * ------------- + * }}} + * + * @since 1.14.0 + * @param c Column to encode. + * @return Encoded string. + */ + def hex(c: Column): Column = + builtin("HEX_ENCODE")(c) + + /** + * Inverse of hex. Interprets each pair of characters as a hexadecimal number + * and converts to the byte representation of number. + * Example + * {{{ + * val df = session.createDataFrame(Seq((31), (32), (33))).toDF("a") + * df.withColumn("unhex_col", unhex(col("A"))).select("unhex_col").show() + * --------------- + * |"UNHEX_COL" | + * --------------- + * |1 | + * |2 | + * |3 | + * --------------- + * }}} + * + * @param c Column to encode. + * @since 1.14.0 + * @return Encoded string. + */ + def unhex(c: Column): Column = + builtin("HEX_DECODE_STRING")(c) + /** * Invokes a built-in snowflake function with the specified name and arguments. diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index 8b3369f2..e34cee94 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -3007,4 +3007,33 @@ public void randn_seed() { expected, false); } + + @Test + public void shiftLeft() { + DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)"); + Row[] expected = {Row.create(2), Row.create(4), Row.create(6)}; + checkAnswer(df.select(Functions.shiftleft(Functions.col("a"), 1)), expected, false); + } + + @Test + public void shiftRight() { + DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)"); + Row[] expected = {Row.create(0), Row.create(1), Row.create(1)}; + checkAnswer(df.select(Functions.shiftright(Functions.col("a"), 1)), expected, false); + } + + @Test + public void hex() { + DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)"); + df.select(Functions.hex(Functions.col("a")).as("hex")).show(); + Row[] expected = {Row.create("31"), Row.create("32"), Row.create("33")}; + checkAnswer(df.select(Functions.hex(Functions.col("a"))), expected, false); + } + + @Test + public void unhex() { + DataFrame df = getSession().sql("select * from values(31),(32),(33) as T(a)"); + Row[] expected = {Row.create("1"), Row.create("2"), Row.create("3")}; + checkAnswer(df.select(Functions.unhex(Functions.col("a"))), expected, false); + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index b2d21880..3ae6372f 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -2379,6 +2379,32 @@ trait FunctionSuite extends TestData { assert(input.withColumn("randn", randn()).select("randn").first() != null) } + test("shiftleft") { + val input = session.createDataFrame(Seq((1), (2), (3))).toDF("a") + checkAnswer(input.select(shiftleft(col("A"), 1)), Seq(Row(2), Row(4), Row(6)), sort = false) + } + + test("shiftright") { + val input = session.createDataFrame(Seq((1), (2), (3))).toDF("a") + checkAnswer(input.select(shiftright(col("A"), 1)), Seq(Row(0), Row(1), Row(1)), sort = false) + } + + test("hex") { + val input = session.createDataFrame(Seq((1), (2), (3))).toDF("a") + checkAnswer( + input.withColumn("hex_col", hex(col("A"))).select("hex_col"), + Seq(Row("31"), Row("32"), Row("33")), + sort = false) + } + + test("unhex") { + val input = session.createDataFrame(Seq((31), (32), (33))).toDF("a") + checkAnswer( + input.withColumn("unhex_col", unhex(col("A"))).select("unhex_col"), + Seq(Row("1"), Row("2"), Row("3")), + sort = false) + } + } class EagerFunctionSuite extends FunctionSuite with EagerSession