You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ue...@apache.org on 2018/05/31 18:38:28 UTC

spark git commit: [SPARK-23900][SQL] format_number support user specified format as argument

Repository: spark
Updated Branches:
  refs/heads/master 223df5d9d -> cc976f6cb


[SPARK-23900][SQL] format_number support user specified format as argument

## What changes were proposed in this pull request?

`format_number` now supports a user-specified format string as its second argument. For example:
```sql
spark-sql> SELECT format_number(12332.123456, '##################.###');
12332.123
```

## How was this patch tested?

unit test

Author: Yuming Wang <yu...@ebay.com>

Closes #21010 from wangyum/SPARK-23900.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cc976f6c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cc976f6c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cc976f6c

Branch: refs/heads/master
Commit: cc976f6cb858adb5f52987b56dda54769915ce50
Parents: 223df5d
Author: Yuming Wang <yu...@ebay.com>
Authored: Thu May 31 11:38:23 2018 -0700
Committer: Takuya UESHIN <ue...@databricks.com>
Committed: Thu May 31 11:38:23 2018 -0700

----------------------------------------------------------------------
 .../expressions/stringExpressions.scala         | 142 ++++++++++++-------
 .../expressions/StringExpressionsSuite.scala    |  24 ++++
 2 files changed, 116 insertions(+), 50 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cc976f6c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 9823b2f..bedad7d 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -1916,12 +1916,15 @@ case class Encode(value: Expression, charset: Expression)
   usage = """
     _FUNC_(expr1, expr2) - Formats the number `expr1` like '#,###,###.##', rounded to `expr2`
       decimal places. If `expr2` is 0, the result has no decimal point or fractional part.
+      `expr2` also accept a user specified format.
       This is supposed to function like MySQL's FORMAT.
   """,
   examples = """
     Examples:
       > SELECT _FUNC_(12332.123456, 4);
        12,332.1235
+      > SELECT _FUNC_(12332.123456, '##################.###');
+       12332.123
   """)
 case class FormatNumber(x: Expression, d: Expression)
   extends BinaryExpression with ExpectsInputTypes {
@@ -1930,14 +1933,20 @@ case class FormatNumber(x: Expression, d: Expression)
   override def right: Expression = d
   override def dataType: DataType = StringType
   override def nullable: Boolean = true
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(NumericType, TypeCollection(IntegerType, StringType))
+
+  private val defaultFormat = "#,###,###,###,###,###,##0"
 
   // Associated with the pattern, for the last d value, and we will update the
   // pattern (DecimalFormat) once the new coming d value differ with the last one.
   // This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after
   // serialization (numberFormat has not been updated for dValue = 0).
   @transient
-  private var lastDValue: Option[Int] = None
+  private var lastDIntValue: Option[Int] = None
+
+  @transient
+  private var lastDStringValue: Option[String] = None
 
   // A cached DecimalFormat, for performance concern, we will change it
   // only if the d value changed.
@@ -1950,33 +1959,49 @@ case class FormatNumber(x: Expression, d: Expression)
   private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US))
 
   override protected def nullSafeEval(xObject: Any, dObject: Any): Any = {
-    val dValue = dObject.asInstanceOf[Int]
-    if (dValue < 0) {
-      return null
-    }
-
-    lastDValue match {
-      case Some(last) if last == dValue =>
-        // use the current pattern
-      case _ =>
-        // construct a new DecimalFormat only if a new dValue
-        pattern.delete(0, pattern.length)
-        pattern.append("#,###,###,###,###,###,##0")
-
-        // decimal place
-        if (dValue > 0) {
-          pattern.append(".")
-
-          var i = 0
-          while (i < dValue) {
-            i += 1
-            pattern.append("0")
-          }
+    right.dataType match {
+      case IntegerType =>
+        val dValue = dObject.asInstanceOf[Int]
+        if (dValue < 0) {
+          return null
         }
 
-        lastDValue = Some(dValue)
+        lastDIntValue match {
+          case Some(last) if last == dValue =>
+          // use the current pattern
+          case _ =>
+            // construct a new DecimalFormat only if a new dValue
+            pattern.delete(0, pattern.length)
+            pattern.append(defaultFormat)
+
+            // decimal place
+            if (dValue > 0) {
+              pattern.append(".")
+
+              var i = 0
+              while (i < dValue) {
+                i += 1
+                pattern.append("0")
+              }
+            }
+
+            lastDIntValue = Some(dValue)
 
-        numberFormat.applyLocalizedPattern(pattern.toString)
+            numberFormat.applyLocalizedPattern(pattern.toString)
+        }
+      case StringType =>
+        val dValue = dObject.asInstanceOf[UTF8String].toString
+        lastDStringValue match {
+          case Some(last) if last == dValue =>
+          case _ =>
+            pattern.delete(0, pattern.length)
+            lastDStringValue = Some(dValue)
+            if (dValue.isEmpty) {
+              numberFormat.applyLocalizedPattern(defaultFormat)
+            } else {
+              numberFormat.applyLocalizedPattern(dValue)
+            }
+        }
     }
 
     x.dataType match {
@@ -2008,35 +2033,52 @@ case class FormatNumber(x: Expression, d: Expression)
       // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.')
       // as a decimal separator.
       val usLocale = "US"
-      val i = ctx.freshName("i")
-      val dFormat = ctx.freshName("dFormat")
-      val lastDValue =
-        ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;")
-      val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();")
       val numberFormat = ctx.addMutableState(df, "numberFormat",
         v => s"""$v = new $df("", new $dfs($l.$usLocale));""")
 
-      s"""
-        if ($d >= 0) {
-          $pattern.delete(0, $pattern.length());
-          if ($d != $lastDValue) {
-            $pattern.append("#,###,###,###,###,###,##0");
-
-            if ($d > 0) {
-              $pattern.append(".");
-              for (int $i = 0; $i < $d; $i++) {
-                $pattern.append("0");
+      right.dataType match {
+        case IntegerType =>
+          val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();")
+          val i = ctx.freshName("i")
+          val lastDValue =
+            ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;")
+          s"""
+            if ($d >= 0) {
+              $pattern.delete(0, $pattern.length());
+              if ($d != $lastDValue) {
+                $pattern.append("$defaultFormat");
+
+                if ($d > 0) {
+                  $pattern.append(".");
+                  for (int $i = 0; $i < $d; $i++) {
+                    $pattern.append("0");
+                  }
+                }
+                $lastDValue = $d;
+                $numberFormat.applyLocalizedPattern($pattern.toString());
               }
+              ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
+            } else {
+              ${ev.value} = null;
+              ${ev.isNull} = true;
             }
-            $lastDValue = $d;
-            $numberFormat.applyLocalizedPattern($pattern.toString());
-          }
-          ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
-        } else {
-          ${ev.value} = null;
-          ${ev.isNull} = true;
-        }
-       """
+           """
+        case StringType =>
+          val lastDValue = ctx.addMutableState("String", "lastDValue", v => s"""$v = null;""")
+          val dValue = ctx.freshName("dValue")
+          s"""
+            String $dValue = $d.toString();
+            if (!$dValue.equals($lastDValue)) {
+              $lastDValue = $dValue;
+              if ($dValue.isEmpty()) {
+                $numberFormat.applyLocalizedPattern("$defaultFormat");
+              } else {
+                $numberFormat.applyLocalizedPattern($dValue);
+              }
+            }
+            ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
+           """
+      }
     })
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/cc976f6c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index f1a6f9b..aa334e0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -706,6 +706,30 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       "15,159,339,180,002,773.2778")
     checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null)
     assert(FormatNumber(Literal.create(null, NullType), Literal(3)).resolved === false)
+
+    checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##############.###")), "12332.123")
+    checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##.###")), "12332.123")
+    checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal("##.####")), "4")
+    checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal("##.####")), "4")
+    checkEvaluation(FormatNumber(Literal(4.0f), Literal("##.###")), "4")
+    checkEvaluation(FormatNumber(Literal(4), Literal("##.###")), "4")
+    checkEvaluation(FormatNumber(Literal(12831273.23481d),
+      Literal("###,###,###,###,###.###")), "12,831,273.235")
+    checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal("")), "12,831,274")
+    checkEvaluation(FormatNumber(Literal(123123324123L), Literal("###,###,###,###,###.###")),
+      "123,123,324,123")
+    checkEvaluation(
+      FormatNumber(Literal(Decimal(123123324123L) * Decimal(123123.21234d)),
+        Literal("###,###,###,###,###.####")), "15,159,339,180,002,773.2778")
+    checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal("##.###")), null)
+    assert(FormatNumber(Literal.create(null, NullType), Literal("##.###")).resolved === false)
+
+    checkEvaluation(FormatNumber(Literal(12332.123456), Literal("#,###,###,###,###,###,##0")),
+      "12,332")
+    checkEvaluation(FormatNumber(
+      Literal.create(null, IntegerType), Literal.create(null, StringType)), null)
+    checkEvaluation(FormatNumber(
+      Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
   }
 
   test("find in set") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org