You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/06/26 07:07:48 UTC

spark git commit: [SPARK-8237] [SQL] Add misc function sha2

Repository: spark
Updated Branches:
  refs/heads/master c392a9efa -> 47c874bab


[SPARK-8237] [SQL] Add misc function sha2

JIRA: https://issues.apache.org/jira/browse/SPARK-8237

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #6934 from viirya/expr_sha2 and squashes the following commits:

35e0bb3 [Liang-Chi Hsieh] For comments.
68b5284 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_sha2
8573aff [Liang-Chi Hsieh] Remove unnecessary Product.
ee61e06 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_sha2
59e41aa [Liang-Chi Hsieh] Add misc function: sha2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/47c874ba
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/47c874ba
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/47c874ba

Branch: refs/heads/master
Commit: 47c874babe7779c7a2f32e0b891503ef6bebcab0
Parents: c392a9e
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Thu Jun 25 22:07:37 2015 -0700
Committer: Davies Liu <da...@databricks.com>
Committed: Thu Jun 25 22:07:37 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 19 ++++
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../spark/sql/catalyst/expressions/misc.scala   | 98 +++++++++++++++++++-
 .../expressions/MiscFunctionsSuite.scala        | 14 ++-
 .../scala/org/apache/spark/sql/functions.scala  | 20 ++++
 .../spark/sql/DataFrameFunctionsSuite.scala     | 17 ++++
 6 files changed, 165 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/47c874ba/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index cfa87ae..7d3d036 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -42,6 +42,7 @@ __all__ = [
     'monotonicallyIncreasingId',
     'rand',
     'randn',
+    'sha2',
     'sparkPartitionId',
     'struct',
     'udf',
@@ -363,6 +364,24 @@ def randn(seed=None):
     return Column(jc)
 
 
+@ignore_unicode_prefix
+@since(1.5)
+def sha2(col, numBits):
+    """Computes a SHA-2 family digest (SHA-224, SHA-256, SHA-384 or SHA-512) of the column and
+    returns it as a hex string. `numBits` is the digest length in bits and must be 224, 256, 384,
+    512, or 0 (which is equivalent to 256).
+
+    >>> digests = df.select(sha2(df.name, 256).alias('s')).collect()
+    >>> digests[0]
+    Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043')
+    >>> digests[1]
+    Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961')
+    """
+    # Delegate to the JVM-side ``functions.sha2`` and wrap the resulting Java column.
+    jvm_functions = SparkContext._active_spark_context._jvm.functions
+    return Column(jvm_functions.sha2(_to_java_column(col), numBits))
+
+
 @since(1.4)
 def sparkPartitionId():
     """A column for partition ID of the Spark task.

http://git-wip-us.apache.org/repos/asf/spark/blob/47c874ba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 5fb3369..457948a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -135,6 +135,7 @@ object FunctionRegistry {
 
     // misc functions
     expression[Md5]("md5"),
+    expression[Sha2]("sha2"),
 
     // aggregate functions
     expression[Average]("avg"),

http://git-wip-us.apache.org/repos/asf/spark/blob/47c874ba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 4bee8cb..e80706f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -17,9 +17,12 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.security.MessageDigest
+import java.security.NoSuchAlgorithmException
+
 import org.apache.commons.codec.digest.DigestUtils
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types.{BinaryType, StringType, DataType}
+import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, DataType}
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -44,7 +47,96 @@ case class Md5(child: Expression)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     defineCodeGen(ctx, ev, c =>
-      "org.apache.spark.unsafe.types.UTF8String.fromString" +
-        s"(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
+      s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
+  }
+}
+
+/**
+ * A function that calculates the SHA-2 family of functions (SHA-224, SHA-256, SHA-384, and SHA-512)
+ * and returns it as a hex string. The first argument is the string or binary to be hashed. The
+ * second argument is the desired bit length of the result: 224, 256, 384, 512, or 0 (equivalent to
+ * 256). SHA-224 is supported starting from Java 8. The return value is NULL if either argument is
+ * NULL, the bit length is not a permitted value, or the JVM does not provide the SHA function.
+ */
+case class Sha2(left: Expression, right: Expression)
+  extends BinaryExpression with Serializable with ExpectsInputTypes {
+
+  override def dataType: DataType = StringType
+
+  override def toString: String = s"SHA2($left, $right)"
+
+  override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
+
+  override def eval(input: InternalRow): Any = {
+    val evalE1 = left.eval(input)
+    if (evalE1 == null) {
+      null
+    } else {
+      val evalE2 = right.eval(input)
+      if (evalE2 == null) {
+        null
+      } else {
+        val bitLength = evalE2.asInstanceOf[Int]
+        val bytes = evalE1.asInstanceOf[Array[Byte]]
+        bitLength match {
+          case 224 =>
+            // DigestUtils doesn't support SHA-224 now, so digest and hex-encode by hand
+            try {
+              val md = MessageDigest.getInstance("SHA-224")
+              val hex = org.apache.commons.codec.binary.Hex.encodeHexString(md.digest(bytes))
+              UTF8String.fromString(hex)
+            } catch {
+              // SHA-224 is not supported on the system, return null
+              case noa: NoSuchAlgorithmException => null
+            }
+          case 256 | 0 =>
+            UTF8String.fromString(DigestUtils.sha256Hex(bytes))
+          case 384 =>
+            UTF8String.fromString(DigestUtils.sha384Hex(bytes))
+          case 512 =>
+            UTF8String.fromString(DigestUtils.sha512Hex(bytes))
+          case _ => null
+        }
+      }
+    }
+  }
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val eval1 = left.gen(ctx)
+    val eval2 = right.gen(ctx)
+    val digestUtils = "org.apache.commons.codec.digest.DigestUtils"
+    val hex = "org.apache.commons.codec.binary.Hex"
+
+    s"""
+      ${eval1.code}
+      boolean ${ev.isNull} = ${eval1.isNull};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${eval2.code}
+        if (!${eval2.isNull}) {
+          if (${eval2.primitive} == 224) {
+            try {
+              java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224");
+              md.update(${eval1.primitive});
+              ${ev.primitive} = ${ctx.stringType}.fromString(${hex}.encodeHexString(md.digest()));
+            } catch (java.security.NoSuchAlgorithmException e) {
+              ${ev.isNull} = true;
+            }
+          } else if (${eval2.primitive} == 256 || ${eval2.primitive} == 0) {
+            ${ev.primitive} =
+              ${ctx.stringType}.fromString(${digestUtils}.sha256Hex(${eval1.primitive}));
+          } else if (${eval2.primitive} == 384) {
+            ${ev.primitive} =
+              ${ctx.stringType}.fromString(${digestUtils}.sha384Hex(${eval1.primitive}));
+          } else if (${eval2.primitive} == 512) {
+            ${ev.primitive} =
+              ${ctx.stringType}.fromString(${digestUtils}.sha512Hex(${eval1.primitive}));
+          } else {
+            ${ev.isNull} = true;
+          }
+        } else {
+          ${ev.isNull} = true;
+        }
+      }
+    """
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/47c874ba/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
index 48b8413..38482c5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.commons.codec.digest.DigestUtils
+
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.{StringType, BinaryType}
+import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType}
 
 class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -29,4 +31,14 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
   }
 
+  test("sha2") {
+    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC"))
+    checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
+      DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6)))
+    // unsupported bit length: use a non-null input so the null comes from the bit-length check
+    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(1024)), null)
+    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null)
+    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null)
+    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/47c874ba/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 38d9085..355ce0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1414,6 +1414,26 @@ object functions {
    */
   def md5(columnName: String): Column = md5(Column(columnName))
 
+  /**
+   * Computes the SHA-2 family digest of the given column and returns it as a hex string.
+   * `numBits` must be 0, 224, 256, 384 or 512; otherwise an IllegalArgumentException is thrown.
+   * @group misc_funcs
+   * @since 1.5.0
+   */
+  def sha2(e: Column, numBits: Int): Column = {
+    val permitted = Set(0, 224, 256, 384, 512).contains(numBits)
+    require(permitted, s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)")
+    Sha2(e.expr, lit(numBits).expr)
+  }
+
+  /**
+   * Computes the SHA-2 family digest of the named column and returns it as a hex string.
+   * @group misc_funcs
+   * @since 1.5.0
+   */
+  def sha2(columnName: String, numBits: Int): Column =
+    sha2(Column(columnName), numBits)
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // String functions
   //////////////////////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/47c874ba/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 8b53b38..8baed57 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -144,6 +144,23 @@ class DataFrameFunctionsSuite extends QueryTest {
       Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c"))
   }
 
+  test("misc sha2 function") {
+    val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")
+    // Expected SHA-256 hex digests of columns "a" and "b", shared by both query forms below.
+    val expected = Row(
+      "b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78",
+      "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")
+
+    // The DataFrame function and the SQL expression must produce the same row.
+    checkAnswer(df.select(sha2($"a", 256), sha2("b", 256)), expected)
+    checkAnswer(df.selectExpr("sha2(a, 256)", "sha2(b, 256)"), expected)
+
+    // A bit length outside the permitted set is rejected when the column is built.
+    intercept[IllegalArgumentException] {
+      df.select(sha2($"a", 1024))
+    }
+  }
+
   test("string length function") {
     checkAnswer(
       nullStrings.select(strlen($"s"), strlen("s")),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org