Posted to commits@spark.apache.org by da...@apache.org on 2015/07/02 19:02:26 UTC

spark git commit: [SPARK-8223] [SPARK-8224] [SQL] shift left and shift right

Repository: spark
Updated Branches:
  refs/heads/master 0a468a46b -> 5b3338130


[SPARK-8223] [SPARK-8224] [SQL] shift left and shift right

Jira:
https://issues.apache.org/jira/browse/SPARK-8223
https://issues.apache.org/jira/browse/SPARK-8224

~~I am aware of #7174 and will update this PR if it's merged.~~ Done
I don't know whether #7034 can simplify this, but we can take a look at it if it gets merged.

rxin In the Jira ticket the function has no second argument. I added a `numBits` argument that allows specifying the number of bits; I think this improves usability. I wanted to add `shiftleft(value)` as well, but the `selectExpr` DataFrame tests crash if I have both. In order to do this, I added the following to functions.scala: `def shiftRight(e: Column): Column = ShiftRight(e.expr, lit(1).expr)`, but as mentioned this doesn't pass tests like `df.selectExpr("shiftRight(a)", ...` (not-enough-arguments exception).

If we need the bitwise shift in order to be Hive compatible, I suggest adding `shiftLeft` and something like `shiftLeftX`.
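
For context, a minimal usage sketch of the two-argument API this patch adds, assuming a `SQLContext` named `sqlContext` with its implicits imported; the DataFrame, column names and values below are illustrative only:

```scala
import sqlContext.implicits._
import org.apache.spark.sql.functions.{shiftLeft, shiftRight}

// An Int column "a" and a Long column "b"; names and values are made up.
val df = Seq((21, 42L)).toDF("a", "b")

// Column API: the second argument is the number of bits to shift by.
df.select(shiftLeft(df("a"), 1), shiftRight(df("b"), 1)).collect()
// -> Array(Row(42, 21L)): an Int input stays Int, a Long input stays Long

// Expression form, as exercised by df.selectExpr and the function registry.
df.selectExpr("shiftLeft(a, 1)", "shiftRight(b, 1)").collect()
// -> Array(Row(42, 21L))
```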

Author: Tarek Auel <ta...@googlemail.com>

Closes #7178 from tarekauel/8223 and squashes the following commits:

8023bb5 [Tarek Auel] [SPARK-8223][SPARK-8224] fixed test
f3f64e6 [Tarek Auel] [SPARK-8223][SPARK-8224] Integer -> Int
f628706 [Tarek Auel] [SPARK-8223][SPARK-8224] removed toString; updated function description
3b56f2a [Tarek Auel] Merge remote-tracking branch 'origin/master' into 8223
5189690 [Tarek Auel] [SPARK-8223][SPARK-8224] minor fix and style fix
9434a28 [Tarek Auel] Merge remote-tracking branch 'origin/master' into 8223
44ee324 [Tarek Auel] [SPARK-8223][SPARK-8224] docu fix
ac7fe9d [Tarek Auel] [SPARK-8223][SPARK-8224] right and left bit shift


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5b333813
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5b333813
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5b333813

Branch: refs/heads/master
Commit: 5b3338130dfd9db92c4894a348839a62ebb57ef3
Parents: 0a468a4
Author: Tarek Auel <ta...@googlemail.com>
Authored: Thu Jul 2 10:02:19 2015 -0700
Committer: Davies Liu <da...@databricks.com>
Committed: Thu Jul 2 10:02:19 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 24 +++++
 .../catalyst/analysis/FunctionRegistry.scala    |  2 +
 .../spark/sql/catalyst/expressions/math.scala   | 98 ++++++++++++++++++++
 .../expressions/MathFunctionsSuite.scala        | 28 +++++-
 .../scala/org/apache/spark/sql/functions.scala  | 38 ++++++++
 .../apache/spark/sql/MathExpressionsSuite.scala | 34 +++++++
 6 files changed, 223 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5b333813/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f9a15d4..bccde60 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -412,6 +412,30 @@ def sha2(col, numBits):
     return Column(jc)
 
 
+@since(1.5)
+def shiftLeft(col, numBits):
+    """Shift the the given value numBits left.
+
+    >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect()
+    [Row(r=42)]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.shiftLeft(_to_java_column(col), numBits)
+    return Column(jc)
+
+
+@since(1.5)
+def shiftRight(col, numBits):
+    """Shift the the given value numBits right.
+
+    >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect()
+    [Row(r=21)]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.shiftRight(_to_java_column(col), numBits)
+    return Column(jc)
+
+
 @since(1.4)
 def sparkPartitionId():
     """A column for partition ID of the Spark task.

http://git-wip-us.apache.org/repos/asf/spark/blob/5b333813/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 6f04298..aa051b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -125,6 +125,8 @@ object FunctionRegistry {
     expression[Pow]("power"),
     expression[UnaryPositive]("positive"),
     expression[Rint]("rint"),
+    expression[ShiftLeft]("shiftleft"),
+    expression[ShiftRight]("shiftright"),
     expression[Signum]("sign"),
     expression[Signum]("signum"),
     expression[Sin]("sin"),

http://git-wip-us.apache.org/repos/asf/spark/blob/5b333813/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 8633eb0..7504c6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -351,6 +351,104 @@ case class Pow(left: Expression, right: Expression)
   }
 }
 
+case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression {
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (left.dataType, right.dataType) match {
+      case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
+      case (_, IntegerType) => left.dataType match {
+        case LongType | IntegerType | ShortType | ByteType =>
+          return TypeCheckResult.TypeCheckSuccess
+        case _ => // failed
+      }
+      case _ => // failed
+    }
+    TypeCheckResult.TypeCheckFailure(
+        s"ShiftLeft expects long, integer, short or byte value as first argument and an " +
+          s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val valueLeft = left.eval(input)
+    if (valueLeft != null) {
+      val valueRight = right.eval(input)
+      if (valueRight != null) {
+        valueLeft match {
+          case l: Long => l << valueRight.asInstanceOf[Integer]
+          case i: Integer => i << valueRight.asInstanceOf[Integer]
+          case s: Short => s << valueRight.asInstanceOf[Integer]
+          case b: Byte => b << valueRight.asInstanceOf[Integer]
+        }
+      } else {
+        null
+      }
+    } else {
+      null
+    }
+  }
+
+  override def dataType: DataType = {
+    left.dataType match {
+      case LongType => LongType
+      case IntegerType | ShortType | ByteType => IntegerType
+      case _ => NullType
+    }
+  }
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;")
+  }
+}
+
+case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression {
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (left.dataType, right.dataType) match {
+      case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
+      case (_, IntegerType) => left.dataType match {
+        case LongType | IntegerType | ShortType | ByteType =>
+          return TypeCheckResult.TypeCheckSuccess
+        case _ => // failed
+      }
+      case _ => // failed
+    }
+    TypeCheckResult.TypeCheckFailure(
+          s"ShiftRight expects long, integer, short or byte value as first argument and an " +
+            s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val valueLeft = left.eval(input)
+    if (valueLeft != null) {
+      val valueRight = right.eval(input)
+      if (valueRight != null) {
+        valueLeft match {
+          case l: Long => l >> valueRight.asInstanceOf[Integer]
+          case i: Integer => i >> valueRight.asInstanceOf[Integer]
+          case s: Short => s >> valueRight.asInstanceOf[Integer]
+          case b: Byte => b >> valueRight.asInstanceOf[Integer]
+        }
+      } else {
+        null
+      }
+    } else {
+      null
+    }
+  }
+
+  override def dataType: DataType = {
+    left.dataType match {
+      case LongType => LongType
+      case IntegerType | ShortType | ByteType => IntegerType
+      case _ => NullType
+    }
+  }
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;")
+  }
+}
+
 /**
  * Performs the inverse operation of HEX.
  * Resulting characters are returned as a byte array.

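A side note on why `dataType` above returns IntegerType for ByteType/ShortType/IntegerType inputs and LongType only for LongType inputs: the `<<`/`>>` used in eval and codegen are the JVM's arithmetic shifts, which promote sub-int operands to int and keep the sign bit on right shifts. A small plain-Scala sketch (values are illustrative):

```scala
val b: Byte = 21
val widened: Int = b << 1   // Byte/Short operands are promoted to Int => 42

val l: Long = -42L
val signed: Long = l >> 1   // >> is an arithmetic shift; the sign is preserved => -21L

assert(widened == 42)
assert(signed == -21L)
```
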
http://git-wip-us.apache.org/repos/asf/spark/blob/5b333813/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index b3345d7..aa27fe3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{DataType, DoubleType, LongType}
+import org.apache.spark.sql.types.{IntegerType, DataType, DoubleType, LongType}
 
 class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -225,6 +225,32 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true)
   }
 
+  test("shift left") {
+    checkEvaluation(ShiftLeft(Literal.create(null, IntegerType), Literal(1)), null)
+    checkEvaluation(ShiftLeft(Literal(21), Literal.create(null, IntegerType)), null)
+    checkEvaluation(
+      ShiftLeft(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
+    checkEvaluation(ShiftLeft(Literal(21), Literal(1)), 42)
+    checkEvaluation(ShiftLeft(Literal(21.toByte), Literal(1)), 42)
+    checkEvaluation(ShiftLeft(Literal(21.toShort), Literal(1)), 42)
+    checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong)
+
+    checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong)
+  }
+
+  test("shift right") {
+    checkEvaluation(ShiftRight(Literal.create(null, IntegerType), Literal(1)), null)
+    checkEvaluation(ShiftRight(Literal(42), Literal.create(null, IntegerType)), null)
+    checkEvaluation(
+      ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
+    checkEvaluation(ShiftRight(Literal(42), Literal(1)), 21)
+    checkEvaluation(ShiftRight(Literal(42.toByte), Literal(1)), 21)
+    checkEvaluation(ShiftRight(Literal(42.toShort), Literal(1)), 21)
+    checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong)
+
+    checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
+  }
+
   test("hex") {
     checkEvaluation(Hex(Literal(28)), "1C")
     checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4")

http://git-wip-us.apache.org/repos/asf/spark/blob/5b333813/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index e6f623b..a5b6828 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1299,6 +1299,44 @@ object functions {
   def rint(columnName: String): Column = rint(Column(columnName))
 
   /**
+   * Shift the given value numBits left. If the given value is a long value, this function
+   * will return a long value else it will return an integer value.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def shiftLeft(e: Column, numBits: Int): Column = ShiftLeft(e.expr, lit(numBits).expr)
+
+  /**
+   * Shift the given value numBits left. If the given value is a long value, this function
+   * will return a long value else it will return an integer value.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def shiftLeft(columnName: String, numBits: Int): Column =
+    shiftLeft(Column(columnName), numBits)
+
+  /**
+   * Shift the given value numBits right. If the given value is a long value, it will return
+   * a long value else it will return an integer value.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr)
+
+  /**
+   * Shift the given value numBits right. If the given value is a long value, it will return
+   * a long value else it will return an integer value.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def shiftRight(columnName: String, numBits: Int): Column =
+    shiftRight(Column(columnName), numBits)
+
+  /**
    * Computes the signum of the given value.
    *
    * @group math_funcs

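The two overloads added to functions.scala above differ only in how the input column is named: one takes a Column, the other a column-name String, and both wrap `numBits` in a literal expression. A hypothetical snippet (the column name "a" is made up):

```scala
import org.apache.spark.sql.functions.{col, shiftLeft, shiftRight}

val byColumn = shiftLeft(col("a"), 2)   // Column-based overload
val byName   = shiftRight("a", 2)       // String-based convenience overload
```
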
http://git-wip-us.apache.org/repos/asf/spark/blob/5b333813/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index c03cde3..4c5696d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -259,6 +259,40 @@ class MathExpressionsSuite extends QueryTest {
     testOneToOneNonNegativeMathFunction(log1p, math.log1p)
   }
 
+  test("shift left") {
+    val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((21, 21, 21, 21, 21, null))
+      .toDF("a", "b", "c", "d", "e", "f")
+
+    checkAnswer(
+      df.select(
+        shiftLeft('a, 1), shiftLeft('b, 1), shiftLeft('c, 1), shiftLeft('d, 1),
+        shiftLeft('f, 1)),
+        Row(42.toLong, 42, 42.toShort, 42.toByte, null))
+
+    checkAnswer(
+      df.selectExpr(
+        "shiftLeft(a, 1)", "shiftLeft(b, 1)", "shiftLeft(b, 1)", "shiftLeft(d, 1)",
+        "shiftLeft(f, 1)"),
+      Row(42.toLong, 42, 42.toShort, 42.toByte, null))
+  }
+
+  test("shift right") {
+    val df = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((42, 42, 42, 42, 42, null))
+      .toDF("a", "b", "c", "d", "e", "f")
+
+    checkAnswer(
+      df.select(
+        shiftRight('a, 1), shiftRight('b, 1), shiftRight('c, 1), shiftRight('d, 1),
+        shiftRight('f, 1)),
+      Row(21.toLong, 21, 21.toShort, 21.toByte, null))
+
+    checkAnswer(
+      df.selectExpr(
+        "shiftRight(a, 1)", "shiftRight(b, 1)", "shiftRight(c, 1)", "shiftRight(d, 1)",
+        "shiftRight(f, 1)"),
+      Row(21.toLong, 21, 21.toShort, 21.toByte, null))
+  }
+
   test("binary log") {
     val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
     checkAnswer(

