You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/07/04 00:39:19 UTC

spark git commit: [SPARK-8226] [SQL] Add function shiftrightunsigned

Repository: spark
Updated Branches:
  refs/heads/master 2848f4da4 -> ab535b9a1


[SPARK-8226] [SQL] Add function shiftrightunsigned

Author: zhichao.li <zh...@intel.com>

Closes #7035 from zhichao-li/shiftRightUnsigned and squashes the following commits:

6bcca5a [zhichao.li] change coding style
3e9f5ae [zhichao.li] python style
d85ae0b [zhichao.li] add shiftrightunsigned


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ab535b9a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ab535b9a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ab535b9a

Branch: refs/heads/master
Commit: ab535b9a1dab40ea7335ff9abb9b522fc2b5ed66
Parents: 2848f4d
Author: zhichao.li <zh...@intel.com>
Authored: Fri Jul 3 15:39:16 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Fri Jul 3 15:39:16 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 13 ++++++
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../spark/sql/catalyst/expressions/math.scala   | 49 ++++++++++++++++++++
 .../expressions/MathFunctionsSuite.scala        | 13 ++++++
 .../scala/org/apache/spark/sql/functions.scala  | 20 ++++++++
 .../apache/spark/sql/MathExpressionsSuite.scala | 17 +++++++
 6 files changed, 113 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ab535b9a/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 12263e6..69e563e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -436,6 +436,19 @@ def shiftRight(col, numBits):
     return Column(jc)
 
 
+@since(1.5)
+def shiftRightUnsigned(col, numBits):
+    """Unsigned shift the the given value numBits right.
+
+    >>> sqlContext.createDataFrame([(-42,)], ['a']).select(shiftRightUnsigned('a', 1).alias('r'))\
+    .collect()
+    [Row(r=9223372036854775787)]
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.shiftRightUnsigned(_to_java_column(col), numBits)
+    return Column(jc)
+
+
 @since(1.4)
 def sparkPartitionId():
     """A column for partition ID of the Spark task.

http://git-wip-us.apache.org/repos/asf/spark/blob/ab535b9a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 9163b03..cd5ba12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -129,6 +129,7 @@ object FunctionRegistry {
     expression[Rint]("rint"),
     expression[ShiftLeft]("shiftleft"),
     expression[ShiftRight]("shiftright"),
+    expression[ShiftRightUnsigned]("shiftrightunsigned"),
     expression[Signum]("sign"),
     expression[Signum]("signum"),
     expression[Sin]("sin"),

http://git-wip-us.apache.org/repos/asf/spark/blob/ab535b9a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 273a6c5..0fc320f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -521,6 +521,55 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress
   }
 }
 
+/**
+ * Bitwise unsigned (zero-fill) right shift of `left` by `right` bits, i.e. `left >>> right`.
+ *
+ * A long input gives a long result; integer, short and byte inputs give an integer result.
+ * Interpreted evaluation returns null as soon as either operand evaluates to null.
+ */
+case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression {
+
+  // Accepted operands: a null-typed argument on either side, or a long/integer/short/byte
+  // value shifted by an integer amount. Everything else is reported as a type-check failure.
+  override def checkInputDataTypes(): TypeCheckResult = (left.dataType, right.dataType) match {
+    case (NullType, _) | (_, NullType) => TypeCheckResult.TypeCheckSuccess
+    case (LongType | IntegerType | ShortType | ByteType, IntegerType) =>
+      TypeCheckResult.TypeCheckSuccess
+    case _ => TypeCheckResult.TypeCheckFailure(
+      s"ShiftRightUnsigned expects long, integer, short or byte value as first argument and an " +
+        s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val value = left.eval(input)
+    // The shift amount is evaluated only when the value being shifted is non-null.
+    val numBits = if (value == null) null else right.eval(input)
+    if (numBits == null) {
+      null
+    } else {
+      // Short and byte values are widened to Int by `>>>`, hence IntegerType in `dataType`.
+      value match {
+        case l: Long => l >>> numBits.asInstanceOf[Integer]
+        case i: Integer => i >>> numBits.asInstanceOf[Integer]
+        case s: Short => s >>> numBits.asInstanceOf[Integer]
+        case b: Byte => b >>> numBits.asInstanceOf[Integer]
+      }
+    }
+  }
+
+  // Long stays long; integer, short and byte all produce an integer.
+  override def dataType: DataType = left.dataType match {
+    case LongType => LongType
+    case IntegerType | ShortType | ByteType => IntegerType
+    case _ => NullType
+  }
+
+  // Emits `result = left >>> right;` as the snippet handed to nullSafeCodeGen.
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (out, value, numBits) => s"$out = $value >>> $numBits;")
+  }
+}
+
 /**
  * Performs the inverse operation of HEX.
  * Resulting characters are returned as a byte array.

http://git-wip-us.apache.org/repos/asf/spark/blob/ab535b9a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 8457864..20839c8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -264,6 +264,19 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
   }
 
+  test("shift right unsigned") {
+    checkEvaluation(ShiftRightUnsigned(Literal.create(null, IntegerType), Literal(1)), null)
+    checkEvaluation(ShiftRightUnsigned(Literal(42), Literal.create(null, IntegerType)), null)
+    checkEvaluation(ShiftRightUnsigned(
+      Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
+    checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21)
+    checkEvaluation(ShiftRightUnsigned(Literal(42.toByte), Literal(1)), 21)
+    checkEvaluation(ShiftRightUnsigned(Literal(42.toShort), Literal(1)), 21)
+    checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong)
+    // A negative long is zero-filled from the left: -42L >>> 1 == 2^63 - 21.
+    checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L)
+  }
+
   test("hex") {
     checkEvaluation(Hex(Literal(28)), "1C")
     checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4")

http://git-wip-us.apache.org/repos/asf/spark/blob/ab535b9a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 0d5d49c..4b70dc5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1344,6 +1344,26 @@ object functions {
   def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr)
 
   /**
+   * Unsigned shift the given value numBits right. If the given value is a long value,
+   * it will return a long value; otherwise it will return an integer value.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def shiftRightUnsigned(columnName: String, numBits: Int): Column =
+    shiftRightUnsigned(Column(columnName), numBits)
+
+  /**
+   * Unsigned shift the given value numBits right. If the given value is a long value,
+   * it will return a long value; otherwise it will return an integer value.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def shiftRightUnsigned(e: Column, numBits: Int): Column =
+    ShiftRightUnsigned(e.expr, lit(numBits).expr)
+
+  /**
    * Shift the the given value numBits right. If the given value is a long value, it will return
    * a long value else it will return an integer value.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/ab535b9a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index dc8f994..24bef21 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -304,6 +304,23 @@ class MathExpressionsSuite extends QueryTest {
       Row(21.toLong, 21, 21.toShort, 21.toByte, null))
   }
 
+  test("shift right unsigned") {
+    val input = Seq[(Long, Integer, Short, Byte, Integer, Integer)]((-42, 42, 42, 42, 42, null))
+      .toDF("a", "b", "c", "d", "e", "f")
+    // -42L >>> 1 == 2^63 - 21; shifting the null column "f" yields null.
+    val expected = Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)
+    checkAnswer(
+      input.select(
+        shiftRightUnsigned('a, 1), shiftRightUnsigned('b, 1), shiftRightUnsigned('c, 1),
+        shiftRightUnsigned('d, 1), shiftRightUnsigned('f, 1)),
+      expected)
+    checkAnswer(
+      input.selectExpr(
+        "shiftRightUnsigned(a, 1)", "shiftRightUnsigned(b, 1)", "shiftRightUnsigned(c, 1)",
+        "shiftRightUnsigned(d, 1)", "shiftRightUnsigned(f, 1)"),
+      expected)
+  }
+
   test("binary log") {
     val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b")
     checkAnswer(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org