Skip to content

Commit

Permalink
add java and scala functions for shiftleft, shiftright, hex and unhex (
Browse files Browse the repository at this point in the history
…#152)

* add java and scala shiftleft, shiftright, hex and unhex functions

* rebuild

* update comments
  • Loading branch information
sfc-gh-gmahadevan authored Aug 27, 2024
1 parent 3ab8a52 commit 2258be0
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 1 deletion.
91 changes: 91 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -4494,6 +4494,97 @@ public static Column randn(long seed) {
return new Column(functions.randn(seed));
}

/**
* Shift the given value numBits left. If the given value is a long value, this function will
* return a long value else it will return an integer value.
*
* <pre>{@code
* DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
* df.select(Functions.shiftleft(Functions.col("a"), 1).as("shiftleft")).show();
* ---------------
* |"SHIFTLEFT" |
* ---------------
* |2 |
* |4 |
* |6 |
* ---------------
* }</pre>
*
* @since 1.14.0
* @return Column object.
*/
public static Column shiftleft(Column c, int numBits) {
return new Column(functions.shiftleft(c.toScalaColumn(), numBits));
}

/**
* Shift the given value numBits right. If the given value is a long value, it will return a long
* value else it will return an integer value.
*
* <pre>{@code
* DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
* df.select(Functions.shiftright(Functions.col("a"), 1).as("shiftright")).show();
* ---------------
* |"SHIFTRIGHT" |
* ---------------
* |0 |
* |1 |
* |1 |
* ---------------
* }</pre>
*
* @since 1.14.0
* @return Column object.
*/
public static Column shiftright(Column c, int numBits) {
return new Column(functions.shiftright(c.toScalaColumn(), numBits));
}

/**
* Computes hex value of the given column.
*
* <pre>{@code
* DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
* df.select(Functions.hex(Functions.col("a")).as("hex")).show();
* ---------
* |"HEX" |
* ---------
* |31 |
* |32 |
* |33 |
* ---------
* }</pre>
*
* @since 1.14.0
* @return Column object.
*/
public static Column hex(Column c) {
return new Column(functions.hex(c.toScalaColumn()));
}

/**
* Inverse of hex. Interprets each pair of characters as a hexadecimal number and converts to the
* byte representation of number.
*
* <pre>{@code
* DataFrame df = getSession().sql("select * from values(31),(32),(33) as T(a)");
* df.select(Functions.unhex(Functions.col("a")).as("unhex")).show();
* -----------
* |"UNHEX" |
* -----------
* |1 |
* |2 |
* |3 |
* -----------
* }</pre>
*
* @since 1.14.0
* @return Column object.
*/
public static Column unhex(Column c) {
return new Column(functions.unhex(c.toScalaColumn()));
}

/**
* Calls a user-defined function (UDF) by name.
*
Expand Down
95 changes: 94 additions & 1 deletion src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3790,7 +3790,7 @@ object functions {
* from the standard normal distribution.
* Calls to the Snowflake RANDOM function.
* NOTE: Snowflake returns integers of 17-19 digits.
*Example
* Example
* {{{
* val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a")
* df.withColumn("randn_with_seed", randn(123L)).select("randn_with_seed").show()
Expand All @@ -3810,6 +3810,99 @@ object functions {
def randn(seed: Long): Column =
builtin("RANDOM")(seed)

/**
* Shift the given value numBits left. If the given value is a long value,
* this function will return a long value else it will return an integer value.
* Example
* {{{
* val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a")
* df.select(shiftleft(col("A"), 1).as("shiftleft")).show()
* ---------------
* |"SHIFTLEFT" |
* ---------------
* |2 |
* |4 |
* |6 |
* ---------------
* }}}
*
* @since 1.14.0
* @param c Column to modify.
* @param numBits Number of bits to shift.
* @return Column object.
*/
def shiftleft(c: Column, numBits: Int): Column =
bitshiftleft(c, lit(numBits))

/**
* Shift the given value numBits right. If the given value is a long value,
* it will return a long value else it will return an integer value.
* Example
* {{{
* val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a")
* df.select(shiftright(col("A"), 1).as("shiftright")).show()
* ----------------
* |"SHIFTRIGHT" |
* ----------------
* |0 |
* |1 |
* |1 |
* ----------------
* }}}
*
* @since 1.14.0
* @param c Column to modify.
* @param numBits Number of bits to shift.
* @return Column object.
*/
def shiftright(c: Column, numBits: Int): Column =
bitshiftright(c, lit(numBits))

/**
* Computes hex value of the given column.
* Example
* {{{
* val df = session.createDataFrame(Seq((1), (2), (3))).toDF("a")
* df.withColumn("hex_col", hex(col("A"))).select("hex_col").show()
* -------------
* |"HEX_COL" |
* -------------
* |31 |
* |32 |
* |33 |
* -------------
* }}}
*
* @since 1.14.0
* @param c Column to encode.
* @return Encoded string.
*/
def hex(c: Column): Column =
builtin("HEX_ENCODE")(c)

/**
* Inverse of hex. Interprets each pair of characters as a hexadecimal number
* and converts to the byte representation of number.
* Example
* {{{
* val df = session.createDataFrame(Seq((31), (32), (33))).toDF("a")
* df.withColumn("unhex_col", unhex(col("A"))).select("unhex_col").show()
* ---------------
* |"UNHEX_COL" |
* ---------------
* |1 |
* |2 |
* |3 |
* ---------------
* }}}
*
* @param c Column to encode.
* @since 1.14.0
* @return Encoded string.
*/
def unhex(c: Column): Column =
builtin("HEX_DECODE_STRING")(c)

/**
* Invokes a built-in snowflake function with the specified name and arguments.
Expand Down
29 changes: 29 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -3007,4 +3007,33 @@ public void randn_seed() {
expected,
false);
}

@Test
public void shiftLeft() {
DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
Row[] expected = {Row.create(2), Row.create(4), Row.create(6)};
checkAnswer(df.select(Functions.shiftleft(Functions.col("a"), 1)), expected, false);
}

@Test
public void shiftRight() {
DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
Row[] expected = {Row.create(0), Row.create(1), Row.create(1)};
checkAnswer(df.select(Functions.shiftright(Functions.col("a"), 1)), expected, false);
}

@Test
public void hex() {
DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
df.select(Functions.hex(Functions.col("a")).as("hex")).show();
Row[] expected = {Row.create("31"), Row.create("32"), Row.create("33")};
checkAnswer(df.select(Functions.hex(Functions.col("a"))), expected, false);
}

@Test
public void unhex() {
DataFrame df = getSession().sql("select * from values(31),(32),(33) as T(a)");
Row[] expected = {Row.create("1"), Row.create("2"), Row.create("3")};
checkAnswer(df.select(Functions.unhex(Functions.col("a"))), expected, false);
}
}
26 changes: 26 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2379,6 +2379,32 @@ trait FunctionSuite extends TestData {
assert(input.withColumn("randn", randn()).select("randn").first() != null)
}

test("shiftleft") {
val input = session.createDataFrame(Seq((1), (2), (3))).toDF("a")
checkAnswer(input.select(shiftleft(col("A"), 1)), Seq(Row(2), Row(4), Row(6)), sort = false)
}

test("shiftright") {
val input = session.createDataFrame(Seq((1), (2), (3))).toDF("a")
checkAnswer(input.select(shiftright(col("A"), 1)), Seq(Row(0), Row(1), Row(1)), sort = false)
}

test("hex") {
val input = session.createDataFrame(Seq((1), (2), (3))).toDF("a")
checkAnswer(
input.withColumn("hex_col", hex(col("A"))).select("hex_col"),
Seq(Row("31"), Row("32"), Row("33")),
sort = false)
}

test("unhex") {
val input = session.createDataFrame(Seq((31), (32), (33))).toDF("a")
checkAnswer(
input.withColumn("unhex_col", unhex(col("A"))).select("unhex_col"),
Seq(Row("1"), Row("2"), Row("3")),
sort = false)
}

}

class EagerFunctionSuite extends FunctionSuite with EagerSession
Expand Down

0 comments on commit 2258be0

Please sign in to comment.