You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/21 07:08:17 UTC

spark git commit: [SPARK-9132][SPARK-9163][SQL] codegen conv

Repository: spark
Updated Branches:
  refs/heads/master 1cbdd8991 -> a3c7a3ce3


[SPARK-9132][SPARK-9163][SQL] codegen conv

Jira: https://issues.apache.org/jira/browse/SPARK-9132
https://issues.apache.org/jira/browse/SPARK-9163

rxin, as you proposed in the Jira ticket, I just moved the logic to a separate object. I haven't changed any of the logic of `NumberConverter`.

Author: Tarek Auel <ta...@googlemail.com>

Closes #7552 from tarekauel/SPARK-9163 and squashes the following commits:

40dcde9 [Tarek Auel] [SPARK-9132][SPARK-9163][SQL] style fix
fa985bd [Tarek Auel] [SPARK-9132][SPARK-9163][SQL] codegen conv


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a3c7a3ce
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a3c7a3ce
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a3c7a3ce

Branch: refs/heads/master
Commit: a3c7a3ce32697ad293b8bcaf29f9384c8255b37f
Parents: 1cbdd89
Author: Tarek Auel <ta...@googlemail.com>
Authored: Mon Jul 20 22:08:12 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Mon Jul 20 22:08:12 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/math.scala   | 204 ++++---------------
 .../sql/catalyst/util/NumberConverter.scala     | 176 ++++++++++++++++
 .../expressions/MathFunctionsSuite.scala        |   4 +-
 .../catalyst/util/NumberConverterSuite.scala    |  40 ++++
 4 files changed, 263 insertions(+), 161 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a3c7a3ce/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 7a9be02..68cca0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.NumberConverter
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -164,7 +165,7 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH"
  * @param toBaseExpr to which base
  */
 case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
-  extends Expression with ImplicitCastInputTypes with CodegenFallback {
+  extends Expression with ImplicitCastInputTypes {
 
   override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable
 
@@ -179,169 +180,54 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
   /** Returns the result of evaluating this expression on a given input Row */
   override def eval(input: InternalRow): Any = {
     val num = numExpr.eval(input)
-    val fromBase = fromBaseExpr.eval(input)
-    val toBase = toBaseExpr.eval(input)
-    if (num == null || fromBase == null || toBase == null) {
-      null
-    } else {
-      conv(
-        num.asInstanceOf[UTF8String].getBytes,
-        fromBase.asInstanceOf[Int],
-        toBase.asInstanceOf[Int])
-    }
-  }
-
-  private val value = new Array[Byte](64)
-
-  /**
-   * Divide x by m as if x is an unsigned 64-bit integer. Examples:
-   * unsignedLongDiv(-1, 2) == Long.MAX_VALUE unsignedLongDiv(6, 3) == 2
-   * unsignedLongDiv(0, 5) == 0
-   *
-   * @param x is treated as unsigned
-   * @param m is treated as signed
-   */
-  private def unsignedLongDiv(x: Long, m: Int): Long = {
-    if (x >= 0) {
-      x / m
-    } else {
-      // Let uval be the value of the unsigned long with the same bits as x
-      // Two's complement => x = uval - 2*MAX - 2
-      // => uval = x + 2*MAX + 2
-      // Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c
-      x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m
-    }
-  }
-
-  /**
-   * Decode v into value[].
-   *
-   * @param v is treated as an unsigned 64-bit integer
-   * @param radix must be between MIN_RADIX and MAX_RADIX
-   */
-  private def decode(v: Long, radix: Int): Unit = {
-    var tmpV = v
-    java.util.Arrays.fill(value, 0.asInstanceOf[Byte])
-    var i = value.length - 1
-    while (tmpV != 0) {
-      val q = unsignedLongDiv(tmpV, radix)
-      value(i) = (tmpV - q * radix).asInstanceOf[Byte]
-      tmpV = q
-      i -= 1
-    }
-  }
-
-  /**
-   * Convert value[] into a long. On overflow, return -1 (as mySQL does). If a
-   * negative digit is found, ignore the suffix starting there.
-   *
-   * @param radix  must be between MIN_RADIX and MAX_RADIX
-   * @param fromPos is the first element that should be conisdered
-   * @return the result should be treated as an unsigned 64-bit integer.
-   */
-  private def encode(radix: Int, fromPos: Int): Long = {
-    var v: Long = 0L
-    val bound = unsignedLongDiv(-1 - radix, radix) // Possible overflow once
-    // val
-    // exceeds this value
-    var i = fromPos
-    while (i < value.length && value(i) >= 0) {
-      if (v >= bound) {
-        // Check for overflow
-        if (unsignedLongDiv(-1 - value(i), radix) < v) {
-          return -1
+    if (num != null) {
+      val fromBase = fromBaseExpr.eval(input)
+      if (fromBase != null) {
+        val toBase = toBaseExpr.eval(input)
+        if (toBase != null) {
+          NumberConverter.convert(
+            num.asInstanceOf[UTF8String].getBytes,
+            fromBase.asInstanceOf[Int],
+            toBase.asInstanceOf[Int])
+        } else {
+          null
         }
-      }
-      v = v * radix + value(i)
-      i += 1
-    }
-    v
-  }
-
-  /**
-   * Convert the bytes in value[] to the corresponding chars.
-   *
-   * @param radix must be between MIN_RADIX and MAX_RADIX
-   * @param fromPos is the first nonzero element
-   */
-  private def byte2char(radix: Int, fromPos: Int): Unit = {
-    var i = fromPos
-    while (i < value.length) {
-      value(i) = Character.toUpperCase(Character.forDigit(value(i), radix)).asInstanceOf[Byte]
-      i += 1
-    }
-  }
-
-  /**
-   * Convert the chars in value[] to the corresponding integers. Convert invalid
-   * characters to -1.
-   *
-   * @param radix must be between MIN_RADIX and MAX_RADIX
-   * @param fromPos is the first nonzero element
-   */
-  private def char2byte(radix: Int, fromPos: Int): Unit = {
-    var i = fromPos
-    while ( i < value.length) {
-      value(i) = Character.digit(value(i), radix).asInstanceOf[Byte]
-      i += 1
-    }
-  }
-
-  /**
-   * Convert numbers between different number bases. If toBase>0 the result is
-   * unsigned, otherwise it is signed.
-   * NB: This logic is borrowed from org.apache.hadoop.hive.ql.ud.UDFConv
-   */
-  private def conv(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = {
-    if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX
-      || Math.abs(toBase) < Character.MIN_RADIX
-      || Math.abs(toBase) > Character.MAX_RADIX) {
-      return null
-    }
-
-    if (n.length == 0) {
-      return null
-    }
-
-    var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0)
-
-    // Copy the digits in the right side of the array
-    var i = 1
-    while (i <= n.length - first) {
-      value(value.length - i) = n(n.length - i)
-      i += 1
-    }
-    char2byte(fromBase, value.length - n.length + first)
-
-    // Do the conversion by going through a 64 bit integer
-    var v = encode(fromBase, value.length - n.length + first)
-    if (negative && toBase > 0) {
-      if (v < 0) {
-        v = -1
       } else {
-        v = -v
+        null
       }
+    } else {
+      null
     }
-    if (toBase < 0 && v < 0) {
-      v = -v
-      negative = true
-    }
-    decode(v, Math.abs(toBase))
-
-    // Find the first non-zero digit or the last digits if all are zero.
-    val firstNonZeroPos = {
-      val firstNonZero = value.indexWhere( _ != 0)
-      if (firstNonZero != -1) firstNonZero else value.length - 1
-    }
-
-    byte2char(Math.abs(toBase), firstNonZeroPos)
+  }
 
-    var resultStartPos = firstNonZeroPos
-    if (negative && toBase < 0) {
-      resultStartPos = firstNonZeroPos - 1
-      value(resultStartPos) = '-'
-    }
-    UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, resultStartPos, value.length))
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val numGen = numExpr.gen(ctx)
+    val from = fromBaseExpr.gen(ctx)
+    val to = toBaseExpr.gen(ctx)
+
+    val numconv = NumberConverter.getClass.getName.stripSuffix("$")
+    s"""
+       ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+       ${numGen.code}
+       boolean ${ev.isNull} = ${numGen.isNull};
+       if (!${ev.isNull}) {
+         ${from.code}
+         if (!${from.isNull}) {
+           ${to.code}
+           if (!${to.isNull}) {
+             ${ev.primitive} = $numconv.convert(${numGen.primitive}.getBytes(),
+               ${from.primitive}, ${to.primitive});
+             if (${ev.primitive} == null) {
+               ${ev.isNull} = true;
+             }
+           } else {
+             ${ev.isNull} = true;
+           }
+         } else {
+           ${ev.isNull} = true;
+         }
+       }
+     """
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a3c7a3ce/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala
new file mode 100644
index 0000000..9fefc56
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.unsafe.types.UTF8String
+
+object NumberConverter {
+
+  private val value = new Array[Byte](64)
+
+  /**
+   * Divide x by m as if x is an unsigned 64-bit integer. Examples:
+   * unsignedLongDiv(-1, 2) == Long.MaxValue, unsignedLongDiv(6, 3) == 2,
+   * unsignedLongDiv(0, 5) == 0
+   *
+   * @param x is treated as unsigned
+   * @param m is treated as signed
+   */
+  private def unsignedLongDiv(x: Long, m: Int): Long = {
+    if (x >= 0) {
+      x / m
+    } else {
+      // Let uval be the value of the unsigned long with the same bits as x
+      // Two's complement => x = uval - 2*MAX - 2
+      // => uval = x + 2*MAX + 2
+      // Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c
+      x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m
+    }
+  }
+
+  /**
+   * Decode v into value[].
+   *
+   * @param v is treated as an unsigned 64-bit integer
+   * @param radix must be between MIN_RADIX and MAX_RADIX
+   */
+  private def decode(v: Long, radix: Int): Unit = {
+    var tmpV = v
+    java.util.Arrays.fill(value, 0.asInstanceOf[Byte])
+    var i = value.length - 1
+    while (tmpV != 0) {
+      val q = unsignedLongDiv(tmpV, radix)
+      value(i) = (tmpV - q * radix).asInstanceOf[Byte]
+      tmpV = q
+      i -= 1
+    }
+  }
+
+  /**
+   * Convert value[] into a long. On overflow, return -1 (as MySQL does). If a
+   * negative digit is found, ignore the suffix starting there.
+   *
+   * @param radix  must be between MIN_RADIX and MAX_RADIX
+   * @param fromPos is the first element that should be considered
+   * @return the result should be treated as an unsigned 64-bit integer.
+   */
+  private def encode(radix: Int, fromPos: Int): Long = {
+    var v: Long = 0L
+    val bound = unsignedLongDiv(-1 - radix, radix) // Possible overflow once
+    // v
+    // exceeds this value
+    var i = fromPos
+    while (i < value.length && value(i) >= 0) {
+      if (v >= bound) {
+        // Check for overflow
+        if (unsignedLongDiv(-1 - value(i), radix) < v) {
+          return -1
+        }
+      }
+      v = v * radix + value(i)
+      i += 1
+    }
+    v
+  }
+
+  /**
+   * Convert the bytes in value[] to the corresponding chars.
+   *
+   * @param radix must be between MIN_RADIX and MAX_RADIX
+   * @param fromPos is the first nonzero element
+   */
+  private def byte2char(radix: Int, fromPos: Int): Unit = {
+    var i = fromPos
+    while (i < value.length) {
+      value(i) = Character.toUpperCase(Character.forDigit(value(i), radix)).asInstanceOf[Byte]
+      i += 1
+    }
+  }
+
+  /**
+   * Convert the chars in value[] to the corresponding integers. Convert invalid
+   * characters to -1.
+   *
+   * @param radix must be between MIN_RADIX and MAX_RADIX
+   * @param fromPos is the first element that should be converted
+   */
+  private def char2byte(radix: Int, fromPos: Int): Unit = {
+    var i = fromPos
+    while ( i < value.length) {
+      value(i) = Character.digit(value(i), radix).asInstanceOf[Byte]
+      i += 1
+    }
+  }
+
+  /**
+   * Convert numbers between different number bases. If toBase>0 the result is
+   * unsigned, otherwise it is signed.
+   * NB: This logic is borrowed from org.apache.hadoop.hive.ql.udf.UDFConv
+   */
+  def convert(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = {
+    if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX
+      || Math.abs(toBase) < Character.MIN_RADIX
+      || Math.abs(toBase) > Character.MAX_RADIX) {
+      return null
+    }
+
+    if (n.length == 0) {
+      return null
+    }
+
+    var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0)
+
+    // Copy the digits in the right side of the array
+    var i = 1
+    while (i <= n.length - first) {
+      value(value.length - i) = n(n.length - i)
+      i += 1
+    }
+    char2byte(fromBase, value.length - n.length + first)
+
+    // Do the conversion by going through a 64 bit integer
+    var v = encode(fromBase, value.length - n.length + first)
+    if (negative && toBase > 0) {
+      if (v < 0) {
+        v = -1
+      } else {
+        v = -v
+      }
+    }
+    if (toBase < 0 && v < 0) {
+      v = -v
+      negative = true
+    }
+    decode(v, Math.abs(toBase))
+
+    // Find the first non-zero digit or the last digit if all are zero.
+    val firstNonZeroPos = {
+      val firstNonZero = value.indexWhere( _ != 0)
+      if (firstNonZero != -1) firstNonZero else value.length - 1
+    }
+
+    byte2char(Math.abs(toBase), firstNonZeroPos)
+
+    var resultStartPos = firstNonZeroPos
+    if (negative && toBase < 0) {
+      resultStartPos = firstNonZeroPos - 1
+      value(resultStartPos) = '-'
+    }
+    UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, resultStartPos, value.length))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a3c7a3ce/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 04acd5b..a2b0fad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -115,8 +115,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F")
     checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1")
     checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48")
-    checkEvaluation(Conv(Literal(null), Literal(36), Literal(16)), null)
-    checkEvaluation(Conv(Literal("3"), Literal(null), Literal(16)), null)
+    checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null)
+    checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null)
     checkEvaluation(
       Conv(Literal("1234"), Literal(10), Literal(37)), null)
     checkEvaluation(

http://git-wip-us.apache.org/repos/asf/spark/blob/a3c7a3ce/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala
new file mode 100644
index 0000000..13265a1
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.NumberConverter.convert
+import org.apache.spark.unsafe.types.UTF8String
+
+class NumberConverterSuite extends SparkFunSuite {
+
+  private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = {
+    assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) ===
+      UTF8String.fromString(expected))
+  }
+
+  test("convert") {
+    checkConv("3", 10, 2, "11")
+    checkConv("-15", 10, -16, "-F")
+    checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1")
+    checkConv("big", 36, 16, "3A48")
+    checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF")
+    checkConv("11abc", 10, 16, "B")
+  }
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org