You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/07/02 19:02:26 UTC
spark git commit: [SPARK-8223] [SPARK-8224] [SQL] shift left and
shift right
Repository: spark
Updated Branches:
refs/heads/master 0a468a46b -> 5b3338130
[SPARK-8223] [SPARK-8224] [SQL] shift left and shift right
Jira:
https://issues.apache.org/jira/browse/SPARK-8223
https://issues.apache.org/jira/browse/SPARK-8224
~~I am aware of #7174 and will update this pr, if it's merged.~~ Done
I don't know if #7034 can simplify this, but we can have a look on it, if it gets merged
rxin In the Jira ticket the function has no second argument. I added a `numBits` argument that allows to specify the number of bits. I guess this improves the usability. I wanted to add `shiftleft(value)` as well, but the `selectExpr` dataframe tests crash, if I have both. In order to do this, I added the following to the functions.scala `def shiftRight(e: Column): Column = ShiftRight(e.expr, lit(1).expr)`, but as I mentioned this doesn't pass tests like `df.selectExpr("shiftRight(a)", ...` (not enough arguments exception).
If we need the bitwise shift in order to be hive compatible, I suggest adding `shiftLeft` and something like `shiftLeftX`
Author: Tarek Auel <ta...@googlemail.com>
Closes #7178 from tarekauel/8223 and squashes the following commits:
8023bb5 [Tarek Auel] [SPARK-8223][SPARK-8224] fixed test
f3f64e6 [Tarek Auel] [SPARK-8223][SPARK-8224] Integer -> Int
f628706 [Tarek Auel] [SPARK-8223][SPARK-8224] removed toString; updated function description
3b56f2a [Tarek Auel] Merge remote-tracking branch 'origin/master' into 8223
5189690 [Tarek Auel] [SPARK-8223][SPARK-8224] minor fix and style fix
9434a28 [Tarek Auel] Merge remote-tracking branch 'origin/master' into 8223
44ee324 [Tarek Auel] [SPARK-8223][SPARK-8224] docu fix
ac7fe9d [Tarek Auel] [SPARK-8223][SPARK-8224] right and left bit shift
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5b333813
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5b333813
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5b333813
Branch: refs/heads/master
Commit: 5b3338130dfd9db92c4894a348839a62ebb57ef3
Parents: 0a468a4
Author: Tarek Auel <ta...@googlemail.com>
Authored: Thu Jul 2 10:02:19 2015 -0700
Committer: Davies Liu <da...@databricks.com>
Committed: Thu Jul 2 10:02:19 2015 -0700
----------------------------------------------------------------------
python/pyspark/sql/functions.py | 24 +++++
.../catalyst/analysis/FunctionRegistry.scala | 2 +
.../spark/sql/catalyst/expressions/math.scala | 98 ++++++++++++++++++++
.../expressions/MathFunctionsSuite.scala | 28 +++++-
.../scala/org/apache/spark/sql/functions.scala | 38 ++++++++
.../apache/spark/sql/MathExpressionsSuite.scala | 34 +++++++
6 files changed, 223 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/5b333813/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f9a15d4..bccde60 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -412,6 +412,30 @@ def sha2(col, numBits):
return Column(jc)
+@since(1.5)
+def shiftLeft(col, numBits):
+ """Shift the the given value numBits left.
+
+ >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect()
+ [Row(r=42)]
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.shiftLeft(_to_java_column(col), numBits)
+ return Column(jc)
+
+
+@since(1.5)
+def shiftRight(col, numBits):
+ """Shift the the given value numBits right.
+
+ >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect()
+ [Row(r=21)]
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.shiftRight(_to_java_column(col), numBits)
+ return Column(jc)
+
+
@since(1.4)
def sparkPartitionId():
"""A column for partition ID of the Spark task.
http://git-wip-us.apache.org/repos/asf/spark/blob/5b333813/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 6f04298..aa051b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -125,6 +125,8 @@ object FunctionRegistry {
expression[Pow]("power"),
expression[UnaryPositive]("positive"),
expression[Rint]("rint"),
+ expression[ShiftLeft]("shiftleft"),
+ expression[ShiftRight]("shiftright"),
expression[Signum]("sign"),
expression[Signum]("signum"),
expression[Sin]("sin"),
http://git-wip-us.apache.org/repos/asf/spark/blob/5b333813/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 8633eb0..7504c6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -351,6 +351,104 @@ case class Pow(left: Expression, right: Expression)
}
}
+case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression {
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (left.dataType, right.dataType) match {
+ case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
+ case (_, IntegerType) => left.dataType match {
+ case LongType | IntegerType | ShortType | ByteType =>
+ return TypeCheckResult.TypeCheckSuccess
+ case _ => // failed
+ }
+ case _ => // failed
+ }
+ TypeCheckResult.TypeCheckFailure(
+ s"ShiftLeft expects long, integer, short or byte value as first argument and an " +
+ s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val valueLeft = left.eval(input)
+ if (valueLeft != null) {
+ val valueRight = right.eval(input)
+ if (valueRight != null) {
+ valueLeft match {
+ case l: Long => l << valueRight.asInstanceOf[Integer]
+ case i: Integer => i << valueRight.asInstanceOf[Integer]
+ case s: Short => s << valueRight.asInstanceOf[Integer]
+ case b: Byte => b << valueRight.asInstanceOf[Integer]
+ }
+ } else {
+ null
+ }
+ } else {
+ null
+ }
+ }
+
+ override def dataType: DataType = {
+ left.dataType match {
+ case LongType => LongType
+ case IntegerType | ShortType | ByteType => IntegerType
+ case _ => NullType
+ }
+ }
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;")
+ }
+}
+
+case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression {
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (left.dataType, right.dataType) match {
+ case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
+ case (_, IntegerType) => left.dataType match {
+ case LongType | IntegerType | ShortType | ByteType =>
+ return TypeCheckResult.TypeCheckSuccess
+ case _ => // failed
+ }
+ case _ => // failed
+ }
+ TypeCheckResult.TypeCheckFailure(
+ s"ShiftRight expects long, integer, short or byte value as first argument and an " +
+ s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val valueLeft = left.eval(input)
+ if (valueLeft != null) {
+ val valueRight = right.eval(input)
+ if (valueRight != null) {
+ valueLeft match {
+ case l: Long => l >> valueRight.asInstanceOf[Integer]
+ case i: Integer => i >> valueRight.asInstanceOf[Integer]
+ case s: Short => s >> valueRight.asInstanceOf[Integer]
+ case b: Byte => b >> valueRight.asInstanceOf[Integer]
+ }
+ } else {
+ null
+ }
+ } else {
+ null
+ }
+ }
+
+ override def dataType: DataType = {
+ left.dataType match {
+ case LongType => LongType
+ case IntegerType | ShortType | ByteType => IntegerType
+ case _ => NullType
+ }
+ }
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;")
+ }
+}
+
/**
* Performs the inverse operation of HEX.
* Resulting characters are returned as a byte array.
http://git-wip-us.apache.org/repos/asf/spark/blob/5b333813/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index b3345d7..aa27fe3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{DataType, DoubleType, LongType}
+import org.apache.spark.sql.types.{IntegerType, DataType, DoubleType, LongType}
class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -225,6 +225,32 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)
}
+ test("shift left") {
+ checkEvaluation(ShiftLeft(Literal.create(null, IntegerType), Literal(1)), null)
+ checkEvaluation(ShiftLeft(Literal(21), Literal.create(null, IntegerType)), null)
+ checkEvaluation(
+ ShiftLeft(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
+ checkEvaluation(ShiftLeft(Literal(21), Literal(1)), 42)
+ checkEvaluation(ShiftLeft(Literal(21.toByte), Literal(1)), 42)
+ checkEvaluation(ShiftLeft(Literal(21.toShort), Literal(1)), 42)
+ checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong)
+
+ checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong)
+ }
+
+ test("shift right") {
+ checkEvaluation(ShiftRight(Literal.create(null, IntegerType), Literal(1)), null)
+ checkEvaluation(ShiftRight(Literal(42), Literal.create(null, IntegerType)), null)
+ checkEvaluation(
+ ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
+ checkEvaluation(ShiftRight(Literal(42), Literal(1)), 21)
+ checkEvaluation(ShiftRight(Literal(42.toByte), Literal(1)), 21)
+ checkEvaluation(ShiftRight(Literal(42.toShort), Literal(1)), 21)
+ checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong)
+
+ checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
+ }
+
test("hex") {
checkEvaluation(Hex(Literal(28)), "1C")
checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4")
http://git-wip-us.apache.org/repos/asf/spark/blob/5b333813/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index e6f623b..a5b6828 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1299,6 +1299,44 @@ object functions {
def rint(columnName: String): Column = rint(Column(columnName))
/**
+ * Shift the the given value numBits left. If the given value is a long value, this function
+ * will return a long value else it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftLeft(e: Column, numBits: Int): Column = ShiftLeft(e.expr, lit(numBits).expr)
+
+ /**
+ * Shift the the given value numBits left. If the given value is a long value, this function
+ * will return a long value else it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftLeft(columnName: String, numBits: Int): Column =
+ shiftLeft(Column(columnName), numBits)
+
+ /**
+ * Shift the the given value numBits right. If the given value is a long value, it will return
+ * a long value else it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr)
+
+ /**
+ * Shift the the given value numBits right. If the given value is a long value, it will return
+ * a long value else it will return an integer value.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
+ def shiftRight(columnName: String, numBits: Int): Column =
+ shiftRight(Column(columnName), numBits)
+
+ /**
* Computes the signum of the given value.
*
* @group math_funcs
http://git-wip-us.apache.org/repos/asf/spark/blob/5b333813/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index c03cde3..4c5696d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -259,6 +259,40 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneNonNegativeMathFunction(log1p, math.log1p)
}
+ test("shift left") {
+ val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((21, 21, 21, 21, 21, null))
+ .toDF("a", "b", "c", "d", "e", "f")
+
+ checkAnswer(
+ df.select(
+ shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1),
+ shiftLeft('f, 1)),
+ Row(42.toLong, 42, 42.toShort, 42.toByte, null))
+
+ checkAnswer(
+ df.selectExpr(
+ "shiftLeft(a, 1)", "shiftLeft(b, 1)", "shiftLeft(b, 1)", "shiftLeft(d, 1)",
+ "shiftLeft(f, 1)"),
+ Row(42.toLong, 42, 42.toShort, 42.toByte, null))
+ }
+
+ test("shift right") {
+ val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((42, 42, 42, 42, 42, null))
+ .toDF("a", "b", "c", "d", "e", "f")
+
+ checkAnswer(
+ df.select(
+ shiftRight('a, 1), shiftRight('b, 1), shiftRight('c, 1), shiftRight('d, 1),
+ shiftRight('f, 1)),
+ Row(21.toLong, 21, 21.toShort, 21.toByte, null))
+
+ checkAnswer(
+ df.selectExpr(
+ "shiftRight(a, 1)", "shiftRight(b, 1)", "shiftRight(c, 1)", "shiftRight(d, 1)",
+ "shiftRight(f, 1)"),
+ Row(21.toLong, 21, 21.toShort, 21.toByte, null))
+ }
+
test("binary log") {
val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
checkAnswer(
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org