You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/04 20:55:22 UTC

spark git commit: [SPARK-8822][SQL] clean up type checking in math.scala.

Repository: spark
Updated Branches:
  refs/heads/master 347cab85c -> c991ef5ab


[SPARK-8822][SQL] clean up type checking in math.scala.

Author: Reynold Xin <rx...@databricks.com>

Closes #7220 from rxin/SPARK-8822 and squashes the following commits:

0cda076 [Reynold Xin] Test cases.
22d0463 [Reynold Xin] Fixed type precedence.
beb2a97 [Reynold Xin] [SPARK-8822][SQL] clean up type checking in math.scala.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c991ef5a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c991ef5a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c991ef5a

Branch: refs/heads/master
Commit: c991ef5abbb501933b2a68eea1987cf8d88794a5
Parents: 347cab8
Author: Reynold Xin <rx...@databricks.com>
Authored: Sat Jul 4 11:55:20 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sat Jul 4 11:55:20 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/math.scala   | 260 ++++++++-----------
 .../expressions/MathFunctionsSuite.scala        |  31 ++-
 2 files changed, 123 insertions(+), 168 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c991ef5a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 0fc320f..45b7e4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -17,10 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import java.lang.{Long => JLong}
-import java.util.Arrays
+import java.{lang => jl}
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -206,7 +204,7 @@ case class Factorial(child: Expression) extends UnaryExpression with ExpectsInpu
     if (evalE == null) {
       null
     } else {
-      val input = evalE.asInstanceOf[Integer]
+      val input = evalE.asInstanceOf[jl.Integer]
       if (input > 20 || input < 0) {
         null
       } else {
@@ -290,7 +288,7 @@ case class Bin(child: Expression)
     if (evalE == null) {
       null
     } else {
-      UTF8String.fromString(JLong.toBinaryString(evalE.asInstanceOf[Long]))
+      UTF8String.fromString(jl.Long.toBinaryString(evalE.asInstanceOf[Long]))
     }
   }
 
@@ -300,27 +298,18 @@ case class Bin(child: Expression)
   }
 }
 
-
 /**
  * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format.
  * Otherwise if the number is a STRING, it converts each character into its hex representation
  * and returns the resulting STRING. Negative numbers would be treated as two's complement.
  */
-case class Hex(child: Expression) extends UnaryExpression with Serializable  {
+case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+  // TODO: Create code-gen version.
 
-  override def dataType: DataType = StringType
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(LongType, StringType, BinaryType))
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (child.dataType.isInstanceOf[StringType]
-      || child.dataType.isInstanceOf[IntegerType]
-      || child.dataType.isInstanceOf[LongType]
-      || child.dataType.isInstanceOf[BinaryType]
-      || child.dataType == NullType) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      TypeCheckResult.TypeCheckFailure(s"hex doesn't accepts ${child.dataType} type")
-    }
-  }
+  override def dataType: DataType = StringType
 
   override def eval(input: InternalRow): Any = {
     val num = child.eval(input)
@@ -329,7 +318,6 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable  {
     } else {
       child.dataType match {
         case LongType => hex(num.asInstanceOf[Long])
-        case IntegerType => hex(num.asInstanceOf[Integer].toLong)
         case BinaryType => hex(num.asInstanceOf[Array[Byte]])
         case StringType => hex(num.asInstanceOf[UTF8String])
       }
@@ -371,7 +359,55 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable  {
         Character.toUpperCase(Character.forDigit((numBuf & 0xF).toInt, 16)).toByte
       numBuf >>>= 4
     } while (numBuf != 0)
-    UTF8String.fromBytes(Arrays.copyOfRange(value, value.length - len, value.length))
+    UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length))
+  }
+}
+
+
+/**
+ * Performs the inverse operation of HEX, decoding a hex string into a byte array.
+ * Odd-length input is left-padded with '0'; an invalid ASCII hex digit yields null.
+ */
+case class UnHex(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+  // TODO: Create code-gen version.
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
+
+  override def dataType: DataType = BinaryType
+
+  override def eval(input: InternalRow): Any = {
+    val num = child.eval(input)
+    if (num == null) {
+      null
+    } else {
+      unhex(num.asInstanceOf[UTF8String].getBytes)
+    }
+  }
+
+  private val unhexDigits = {
+    val array = Array.fill[Byte](128)(-1)
+    (0 to 9).foreach(i => array('0' + i) = i.toByte)
+    (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte)
+    (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte)
+    array
+  }
+
+  private def unhex(inputBytes: Array[Byte]): Array[Byte] = {
+    var bytes = inputBytes
+    if ((bytes.length & 0x01) != 0) {
+      bytes = '0'.toByte +: bytes
+    }
+    val out = new Array[Byte](bytes.length >> 1)
+    // two characters form the hex value.
+    var i = 0
+    while (i < bytes.length) {
+      val first = unhexDigits(bytes(i))
+      val second = unhexDigits(bytes(i + 1))
+      if (first == -1 || second == -1) { return null}
+      out(i / 2) = (((first << 4) | second) & 0xFF).toByte
+      i += 2
+    }
+    out
   }
 }
 
@@ -423,22 +459,19 @@ case class Pow(left: Expression, right: Expression)
   }
 }
 
-case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression {
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    (left.dataType, right.dataType) match {
-      case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
-      case (_, IntegerType) => left.dataType match {
-        case LongType | IntegerType | ShortType | ByteType =>
-          return TypeCheckResult.TypeCheckSuccess
-        case _ => // failed
-      }
-      case _ => // failed
-    }
-    TypeCheckResult.TypeCheckFailure(
-        s"ShiftLeft expects long, integer, short or byte value as first argument and an " +
-          s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
-  }
+/**
+ * Bitwise left shift.
+ * @param left the base number to shift.
+ * @param right number of bits to left shift.
+ */
+case class ShiftLeft(left: Expression, right: Expression)
+  extends BinaryExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(IntegerType, LongType), IntegerType)
+
+  override def dataType: DataType = left.dataType
 
   override def eval(input: InternalRow): Any = {
     val valueLeft = left.eval(input)
@@ -446,10 +479,8 @@ case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpressi
       val valueRight = right.eval(input)
       if (valueRight != null) {
         valueLeft match {
-          case l: Long => l << valueRight.asInstanceOf[Integer]
-          case i: Integer => i << valueRight.asInstanceOf[Integer]
-          case s: Short => s << valueRight.asInstanceOf[Integer]
-          case b: Byte => b << valueRight.asInstanceOf[Integer]
+          case l: jl.Long => l << valueRight.asInstanceOf[jl.Integer]
+          case i: jl.Integer => i << valueRight.asInstanceOf[jl.Integer]
         }
       } else {
         null
@@ -459,35 +490,24 @@ case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpressi
     }
   }
 
-  override def dataType: DataType = {
-    left.dataType match {
-      case LongType => LongType
-      case IntegerType | ShortType | ByteType => IntegerType
-      case _ => NullType
-    }
-  }
-
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;")
   }
 }
 
-case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression {
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    (left.dataType, right.dataType) match {
-      case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
-      case (_, IntegerType) => left.dataType match {
-        case LongType | IntegerType | ShortType | ByteType =>
-          return TypeCheckResult.TypeCheckSuccess
-        case _ => // failed
-      }
-      case _ => // failed
-    }
-    TypeCheckResult.TypeCheckFailure(
-          s"ShiftRight expects long, integer, short or byte value as first argument and an " +
-            s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
-  }
+/**
+ * Bitwise signed (arithmetic) right shift.
+ * @param left the base number to shift.
+ * @param right number of bits to right shift.
+ */
+case class ShiftRight(left: Expression, right: Expression)
+  extends BinaryExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(IntegerType, LongType), IntegerType)
+
+  override def dataType: DataType = left.dataType
 
   override def eval(input: InternalRow): Any = {
     val valueLeft = left.eval(input)
@@ -495,10 +515,8 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress
       val valueRight = right.eval(input)
       if (valueRight != null) {
         valueLeft match {
-          case l: Long => l >> valueRight.asInstanceOf[Integer]
-          case i: Integer => i >> valueRight.asInstanceOf[Integer]
-          case s: Short => s >> valueRight.asInstanceOf[Integer]
-          case b: Byte => b >> valueRight.asInstanceOf[Integer]
+          case l: jl.Long => l >> valueRight.asInstanceOf[jl.Integer]
+          case i: jl.Integer => i >> valueRight.asInstanceOf[jl.Integer]
         }
       } else {
         null
@@ -508,35 +526,24 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress
     }
   }
 
-  override def dataType: DataType = {
-    left.dataType match {
-      case LongType => LongType
-      case IntegerType | ShortType | ByteType => IntegerType
-      case _ => NullType
-    }
-  }
-
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;")
   }
 }
 
-case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression {
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    (left.dataType, right.dataType) match {
-      case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
-      case (_, IntegerType) => left.dataType match {
-        case LongType | IntegerType | ShortType | ByteType =>
-          return TypeCheckResult.TypeCheckSuccess
-        case _ => // failed
-      }
-      case _ => // failed
-    }
-    TypeCheckResult.TypeCheckFailure(
-      s"ShiftRightUnsigned expects long, integer, short or byte value as first argument and an " +
-        s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
-  }
+/**
+ * Bitwise unsigned (logical) right shift, for integer and long data types.
+ * @param left the base number.
+ * @param right the number of bits to right shift.
+ */
+case class ShiftRightUnsigned(left: Expression, right: Expression)
+  extends BinaryExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(IntegerType, LongType), IntegerType)
+
+  override def dataType: DataType = left.dataType
 
   override def eval(input: InternalRow): Any = {
     val valueLeft = left.eval(input)
@@ -544,10 +551,8 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) extends Binar
       val valueRight = right.eval(input)
       if (valueRight != null) {
         valueLeft match {
-          case l: Long => l >>> valueRight.asInstanceOf[Integer]
-          case i: Integer => i >>> valueRight.asInstanceOf[Integer]
-          case s: Short => s >>> valueRight.asInstanceOf[Integer]
-          case b: Byte => b >>> valueRight.asInstanceOf[Integer]
+          case l: jl.Long => l >>> valueRight.asInstanceOf[jl.Integer]
+          case i: jl.Integer => i >>> valueRight.asInstanceOf[jl.Integer]
         }
       } else {
         null
@@ -557,74 +562,21 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) extends Binar
     }
   }
 
-  override def dataType: DataType = {
-    left.dataType match {
-      case LongType => LongType
-      case IntegerType | ShortType | ByteType => IntegerType
-      case _ => NullType
-    }
-  }
-
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;")
   }
 }
 
-/**
- * Performs the inverse operation of HEX.
- * Resulting characters are returned as a byte array.
- */
-case class UnHex(child: Expression) extends UnaryExpression with Serializable {
-
-  override def dataType: DataType = BinaryType
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (child.dataType.isInstanceOf[StringType] || child.dataType == NullType) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      TypeCheckResult.TypeCheckFailure(s"unHex accepts String type, not ${child.dataType}")
-    }
-  }
-
-  override def eval(input: InternalRow): Any = {
-    val num = child.eval(input)
-    if (num == null) {
-      null
-    } else {
-      unhex(num.asInstanceOf[UTF8String].getBytes)
-    }
-  }
-
-  private val unhexDigits = {
-    val array = Array.fill[Byte](128)(-1)
-    (0 to 9).foreach(i => array('0' + i) = i.toByte)
-    (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte)
-    (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte)
-    array
-  }
-
-  private def unhex(inputBytes: Array[Byte]): Array[Byte] = {
-    var bytes = inputBytes
-    if ((bytes.length & 0x01) != 0) {
-      bytes = '0'.toByte +: bytes
-    }
-    val out = new Array[Byte](bytes.length >> 1)
-    // two characters form the hex value.
-    var i = 0
-    while (i < bytes.length) {
-        val first = unhexDigits(bytes(i))
-        val second = unhexDigits(bytes(i + 1))
-        if (first == -1 || second == -1) { return null}
-        out(i / 2) = (((first << 4) | second) & 0xFF).toByte
-        i += 2
-    }
-    out
-  }
-}
 
 case class Hypot(left: Expression, right: Expression)
   extends BinaryMathExpression(math.hypot, "HYPOT")
 
+
+/**
+ * Computes the logarithm of a number, returning null when the result is NaN.
+ * @param left the logarithm base; defaults to e.
+ * @param right the number to compute the logarithm of.
+ */
 case class Logarithm(left: Expression, right: Expression)
   extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") {
 
@@ -642,7 +594,7 @@ case class Logarithm(left: Expression, right: Expression)
       defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)")
     }
     logCode + s"""
-      if (Double.valueOf(${ev.primitive}).isNaN()) {
+      if (Double.isNaN(${ev.primitive})) {
         ${ev.isNull} = true;
       }
     """

http://git-wip-us.apache.org/repos/asf/spark/blob/c991ef5a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 20839c8..03d8400 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -161,11 +161,10 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("factorial") {
-    val dataLong = (0 to 20)
-    dataLong.foreach { value =>
+    (0 to 20).foreach { value =>
       checkEvaluation(Factorial(Literal(value)), LongMath.factorial(value), EmptyRow)
     }
-    checkEvaluation((Literal.create(null, IntegerType)), null, create_row(null))
+    checkEvaluation(Literal.create(null, IntegerType), null, create_row(null))
     checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow)
     checkEvaluation(Factorial(Literal(21)), null, EmptyRow)
   }
@@ -244,10 +243,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(
       ShiftLeft(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
     checkEvaluation(ShiftLeft(Literal(21), Literal(1)), 42)
-    checkEvaluation(ShiftLeft(Literal(21.toByte), Literal(1)), 42)
-    checkEvaluation(ShiftLeft(Literal(21.toShort), Literal(1)), 42)
-    checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong)
 
+    checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong)
     checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong)
   }
 
@@ -257,10 +254,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(
       ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
     checkEvaluation(ShiftRight(Literal(42), Literal(1)), 21)
-    checkEvaluation(ShiftRight(Literal(42.toByte), Literal(1)), 21)
-    checkEvaluation(ShiftRight(Literal(42.toShort), Literal(1)), 21)
-    checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong)
 
+    checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong)
     checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
   }
 
@@ -270,16 +265,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(
       ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
     checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21)
-    checkEvaluation(ShiftRightUnsigned(Literal(42.toByte), Literal(1)), 21)
-    checkEvaluation(ShiftRightUnsigned(Literal(42.toShort), Literal(1)), 21)
-    checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong)
 
+    checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong)
     checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L)
   }
 
   test("hex") {
-    checkEvaluation(Hex(Literal(28)), "1C")
-    checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4")
     checkEvaluation(Hex(Literal(100800200404L)), "177828FED4")
     checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C")
     checkEvaluation(Hex(Literal("helloHex")), "68656C6C6F486578")
@@ -313,6 +304,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow)
       checkEvaluation(new Logarithm(Literal(v1)), f(math.E, v1 + 0.0), EmptyRow)
     }
+
+    // null input should yield null output
     checkEvaluation(
       Logarithm(Literal.create(null, DoubleType), Literal(1.0)),
       null,
@@ -321,5 +314,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       Logarithm(Literal(1.0), Literal.create(null, DoubleType)),
       null,
       create_row(null))
+
+    // negative input should yield null output
+    checkEvaluation(
+      Logarithm(Literal(-1.0), Literal(1.0)),
+      null,
+      create_row(null))
+    checkEvaluation(
+      Logarithm(Literal(1.0), Literal(-1.0)),
+      null,
+      create_row(null))
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org