You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/17 18:32:32 UTC

spark git commit: [SPARK-8209][SQL] Add function conv

Repository: spark
Updated Branches:
  refs/heads/master 59d24c226 -> 305e77cd8


[SPARK-8209][SQL] Add function conv

cc chenghao-intel  adrian-wang

Author: zhichao.li <zh...@intel.com>

Closes #6872 from zhichao-li/conv and squashes the following commits:

6ef3b37 [zhichao.li] add unittest and comments
78d9836 [zhichao.li] polish dataframe api and add unittest
e2bace3 [zhichao.li] update to use ImplicitCastInputTypes
cbcad3f [zhichao.li] add function conv


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/305e77cd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/305e77cd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/305e77cd

Branch: refs/heads/master
Commit: 305e77cd83f3dbe680a920d5329c2e8c58452d5b
Parents: 59d24c2
Author: zhichao.li <zh...@intel.com>
Authored: Fri Jul 17 09:32:27 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Fri Jul 17 09:32:27 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../spark/sql/catalyst/expressions/math.scala   | 191 +++++++++++++++++++
 .../expressions/MathFunctionsSuite.scala        |  21 +-
 .../scala/org/apache/spark/sql/functions.scala  |  18 ++
 .../apache/spark/sql/MathExpressionsSuite.scala |  13 ++
 5 files changed, 242 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/305e77cd/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e0beafe..a451817 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -99,6 +99,7 @@ object FunctionRegistry {
     expression[Ceil]("ceil"),
     expression[Ceil]("ceiling"),
     expression[Cos]("cos"),
+    expression[Conv]("conv"),
     expression[EulerNumber]("e"),
     expression[Exp]("exp"),
     expression[Expm1]("expm1"),

http://git-wip-us.apache.org/repos/asf/spark/blob/305e77cd/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 84b289c..7a543ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.{lang => jl}
+import java.util.Arrays
 
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure}
@@ -139,6 +140,211 @@ case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")

 case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH")

+/**
+ * Convert a num from one base to another.
+ *
+ * The string is parsed in `fromBase`, stopping at the first invalid digit, so only the longest
+ * valid prefix is converted. The result is rendered in `abs(toBase)`: if toBase > 0 the result
+ * is unsigned, otherwise it is signed. A value that does not fit in 64 bits becomes -1 (as
+ * MySQL does). The result is null if any input is null, the string is empty, or a base is
+ * outside [Character.MIN_RADIX, Character.MAX_RADIX].
+ *
+ * @param numExpr the number to be converted
+ * @param fromBaseExpr from which base
+ * @param toBaseExpr to which base
+ */
+case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
+  extends Expression with ImplicitCastInputTypes {
+
+  override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable
+
+  // Not derived from the children: an empty string or an out-of-range base also yields null.
+  override def nullable: Boolean = true
+
+  override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType)
+
+  /** Returns the result of evaluating this expression on a given input Row */
+  override def eval(input: InternalRow): Any = {
+    val num = numExpr.eval(input)
+    val fromBase = fromBaseExpr.eval(input)
+    val toBase = toBaseExpr.eval(input)
+    if (num == null || fromBase == null || toBase == null) {
+      null
+    } else {
+      conv(num.asInstanceOf[UTF8String].getBytes,
+        fromBase.asInstanceOf[Int], toBase.asInstanceOf[Int])
+    }
+  }
+
+  /**
+   * Returns the [[DataType]] of the result of evaluating this expression.  It is
+   * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false).
+   */
+  override def dataType: DataType = StringType
+
+  /**
+   * Divide x by m as if x is an unsigned 64-bit integer. Examples:
+   * unsignedLongDiv(-1, 2) == Long.MaxValue, unsignedLongDiv(6, 3) == 2,
+   * unsignedLongDiv(0, 5) == 0
+   *
+   * @param x is treated as unsigned
+   * @param m the divisor; must be positive (every caller passes a radix)
+   */
+  private def unsignedLongDiv(x: Long, m: Int): Long = {
+    if (x >= 0) {
+      x / m
+    } else {
+      // Halve x with an unsigned shift so it is non-negative, divide, then double the
+      // quotient. The leftover remainder lies in [0, 2m), so at most one correction is
+      // needed. (Summing truncated per-term quotients is off by one for some divisors,
+      // e.g. x = -3, m = 7, and would make decode emit a negative digit.)
+      val quotient = ((x >>> 1) / m) << 1
+      val remainder = x - quotient * m
+      if (remainder >= m) quotient + 1 else quotient
+    }
+  }
+
+  /**
+   * Decode v into value[] as digit values in the given radix, right-aligned; every slot
+   * not holding a digit is set to 0.
+   *
+   * @param value receives the digits; must have at least 64 slots (the base-2 worst case)
+   * @param v is treated as an unsigned 64-bit integer
+   * @param radix must be between MIN_RADIX and MAX_RADIX
+   */
+  private def decode(value: Array[Byte], v: Long, radix: Int): Unit = {
+    Arrays.fill(value, 0.toByte)
+    var remaining = v
+    var i = value.length - 1
+    while (remaining != 0) {
+      val q = unsignedLongDiv(remaining, radix)
+      value(i) = (remaining - q * radix).toByte
+      remaining = q
+      i -= 1
+    }
+  }
+
+  /**
+   * Convert digit values in value[] into a long. On overflow, return -1 (as MySQL does). If
+   * a negative digit is found, ignore the suffix starting there.
+   *
+   * @param radix must be between MIN_RADIX and MAX_RADIX
+   * @param fromPos is the first element that should be considered
+   * @return the result should be treated as an unsigned 64-bit integer.
+   */
+  private def encode(value: Array[Byte], radix: Int, fromPos: Int): Long = {
+    var v: Long = 0L
+    // Below this (unsigned) bound, v * radix + digit cannot exceed 2^64 - 1.
+    val bound = unsignedLongDiv(-1 - radix, radix)
+    var i = fromPos
+    while (i < value.length && value(i) >= 0) {
+      // A negative v is >= 2^63 as unsigned, so one more digit always overflows; the signed
+      // comparison with `bound` below cannot detect that case.
+      if (v < 0) {
+        return -1
+      }
+      // v * radix + value(i) overflows iff v > (2^64 - 1 - value(i)) / radix.
+      if (v >= bound && unsignedLongDiv(-1 - value(i), radix) < v) {
+        return -1
+      }
+      v = v * radix + value(i)
+      i += 1
+    }
+    v
+  }
+
+  /**
+   * Convert the digit values in value[] to the corresponding upper-case chars.
+   *
+   * @param radix must be between MIN_RADIX and MAX_RADIX
+   * @param fromPos is the first nonzero element
+   */
+  private def byte2char(value: Array[Byte], radix: Int, fromPos: Int): Unit = {
+    var i = fromPos
+    while (i < value.length) {
+      value(i) = Character.toUpperCase(Character.forDigit(value(i), radix)).toByte
+      i += 1
+    }
+  }
+
+  /**
+   * Convert the chars in value[] to the corresponding digit values. Convert invalid
+   * characters to -1.
+   *
+   * @param radix must be between MIN_RADIX and MAX_RADIX
+   * @param fromPos is the first element holding an input character
+   */
+  private def char2byte(value: Array[Byte], radix: Int, fromPos: Int): Unit = {
+    var i = fromPos
+    while (i < value.length) {
+      value(i) = Character.digit(value(i), radix).toByte
+      i += 1
+    }
+  }
+
+  /**
+   * Convert numbers between different number bases. If toBase>0 the result is
+   * unsigned, otherwise it is signed.
+   * NB: This logic is borrowed from org.apache.hadoop.hive.ql.udf.UDFConv
+   */
+  private def conv(n: Array[Byte], fromBase: Int, toBase: Int): UTF8String = {
+    if (n == null || n.isEmpty) {
+      return null
+    }
+
+    if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX
+      || Math.abs(toBase) < Character.MIN_RADIX
+      || Math.abs(toBase) > Character.MAX_RADIX) {
+      return null
+    }
+
+    var negative = n(0) == '-'
+    val first = if (negative) 1 else 0
+
+    // A fresh buffer per call avoids sharing mutable state between evaluations. It is at
+    // least as long as the input, so long strings cannot index out of bounds, and one slot
+    // longer than the 64 digits a base-2 result may need, leaving room for a leading '-'.
+    val value = new Array[Byte](Math.max(n.length, 64) + 1)
+
+    // Copy the digits in the right side of the array
+    val digitsStart = value.length - n.length + first
+    System.arraycopy(n, first, value, digitsStart, n.length - first)
+    char2byte(value, fromBase, digitsStart)
+
+    // Do the conversion by going through a 64 bit integer
+    var v = encode(value, fromBase, digitsStart)
+    if (negative && toBase > 0) {
+      if (v < 0) {
+        v = -1
+      } else {
+        v = -v
+      }
+    }
+    if (toBase < 0 && v < 0) {
+      v = -v
+      negative = true
+    }
+    decode(value, v, Math.abs(toBase))
+
+    // Find the first non-zero digit or the last digits if all are zero.
+    val firstNonZeroPos = {
+      val firstNonZero = value.indexWhere(_ != 0)
+      if (firstNonZero != -1) firstNonZero else value.length - 1
+    }
+
+    byte2char(value, Math.abs(toBase), firstNonZeroPos)
+
+    var resultStartPos = firstNonZeroPos
+    if (negative && toBase < 0) {
+      resultStartPos = firstNonZeroPos - 1
+      value(resultStartPos) = '-'
+    }
+    UTF8String.fromBytes(Arrays.copyOfRange(value, resultStartPos, value.length))
+  }
+}
+
 case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP")

 case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1")

http://git-wip-us.apache.org/repos/asf/spark/blob/305e77cd/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 52a874a..ca35c7e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -17,14 +17,13 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import scala.math.BigDecimal.RoundingMode
-
 import com.google.common.math.LongMath
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
 
+
 class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   /**
@@ -95,6 +94,24 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null))
   }
 
+  test("conv") {
+    checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11")
+    checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") // signed result
+    checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") // unsigned
+    checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") // base-36 letters
+    checkEvaluation(Conv(Literal(null), Literal(36), Literal(16)), null)
+    checkEvaluation(Conv(Literal("3"), Literal(null), Literal(16)), null)
+    checkEvaluation(
+      Conv(Literal("1234"), Literal(10), Literal(37)), null) // toBase above Character.MAX_RADIX
+    checkEvaluation(
+      Conv(Literal(""), Literal(10), Literal(16)), null) // empty input
+    checkEvaluation( // overflows 64 bits in base 36, so the value is -1 (all ones)
+      Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF")
+    // If there is an invalid digit in the number, the longest valid prefix should be converted.
+    checkEvaluation(
+      Conv(Literal("11abc"), Literal(10), Literal(16)), "B") // only "11" is converted
+  }
+
   test("e") {
     testLeaf(EulerNumber, math.E)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/305e77cd/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d6da284..fe511c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -69,6 +69,24 @@ object functions {
   def column(colName: String): Column = Column(colName)
 
   /**
+   * Converts the number in `num` (read as a string) from base `fromBase` to base `toBase`.
+   * A negative `toBase` gives a signed result; an out-of-range base gives null.
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def conv(num: Column, fromBase: Int, toBase: Int): Column =
+    Conv(num.expr, lit(fromBase).expr, lit(toBase).expr)
+
+  /**
+   * Converts the number in column `numColName` (read as a string) from base `fromBase` to
+   * base `toBase`. A negative `toBase` gives a signed result; an out-of-range base gives null.
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def conv(numColName: String, fromBase: Int, toBase: Int): Column =
+    conv(Column(numColName), fromBase, toBase)
+
+  /**
    * Creates a [[Column]] of literal value.
    *
    * The passed in object is returned directly if it is already a [[Column]].

http://git-wip-us.apache.org/repos/asf/spark/blob/305e77cd/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 087126b..8eb3fec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -178,6 +178,19 @@ class MathExpressionsSuite extends QueryTest {
       Row(0.0, 1.0, 2.0))
   }
 
+  test("conv") {
+    val df = Seq(("333", 10, 2)).toDF("num", "fromBase", "toBase")
+    checkAnswer(df.select(conv('num, 10, 16)), Row("14D"))
+    checkAnswer(df.select(conv("num", 10, 16)), Row("14D"))
+    checkAnswer(df.select(conv(lit(100), 2, 16)), Row("4")) // Int implicitly cast to string
+    checkAnswer(df.select(conv(lit(3122234455L), 10, 16)), Row("BA198457"))
+    checkAnswer(df.selectExpr("conv(num, fromBase, toBase)"), Row("101001101")) // base columns
+    checkAnswer(df.selectExpr("""conv("100", 2, 10)"""), Row("4"))
+    checkAnswer(df.selectExpr("""conv("-10", 16, -10)"""), Row("-16")) // signed result
+    checkAnswer(
+      df.selectExpr("""conv("9223372036854775807", 36, -16)"""), Row("-1")) // overflow yields -1
+  }
+
   test("floor") {
     testOneToOneMathFunction(floor, math.floor)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org