You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ue...@apache.org on 2018/05/31 18:38:28 UTC
spark git commit: [SPARK-23900][SQL] format_number supports user
specified format as argument
Repository: spark
Updated Branches:
refs/heads/master 223df5d9d -> cc976f6cb
[SPARK-23900][SQL] format_number supports user specified format as argument
## What changes were proposed in this pull request?
`format_number` supports a user specified format as argument. For example:
```sql
spark-sql> SELECT format_number(12332.123456, '##################.###');
12332.123
```
## How was this patch tested?
unit test
Author: Yuming Wang <yu...@ebay.com>
Closes #21010 from wangyum/SPARK-23900.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cc976f6c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cc976f6c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cc976f6c
Branch: refs/heads/master
Commit: cc976f6cb858adb5f52987b56dda54769915ce50
Parents: 223df5d
Author: Yuming Wang <yu...@ebay.com>
Authored: Thu May 31 11:38:23 2018 -0700
Committer: Takuya UESHIN <ue...@databricks.com>
Committed: Thu May 31 11:38:23 2018 -0700
----------------------------------------------------------------------
.../expressions/stringExpressions.scala | 142 ++++++++++++-------
.../expressions/StringExpressionsSuite.scala | 24 ++++
2 files changed, 116 insertions(+), 50 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/cc976f6c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 9823b2f..bedad7d 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -1916,12 +1916,15 @@ case class Encode(value: Expression, charset: Expression)
usage = """
_FUNC_(expr1, expr2) - Formats the number `expr1` like '#,###,###.##', rounded to `expr2`
decimal places. If `expr2` is 0, the result has no decimal point or fractional part.
+ `expr2` also accept a user specified format.
This is supposed to function like MySQL's FORMAT.
""",
examples = """
Examples:
> SELECT _FUNC_(12332.123456, 4);
12,332.1235
+ > SELECT _FUNC_(12332.123456, '##################.###');
+ 12332.123
""")
case class FormatNumber(x: Expression, d: Expression)
extends BinaryExpression with ExpectsInputTypes {
@@ -1930,14 +1933,20 @@ case class FormatNumber(x: Expression, d: Expression)
override def right: Expression = d
override def dataType: DataType = StringType
override def nullable: Boolean = true
- override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(NumericType, TypeCollection(IntegerType, StringType))
+
+ private val defaultFormat = "#,###,###,###,###,###,##0"
// Associated with the pattern, for the last d value, and we will update the
// pattern (DecimalFormat) once the new coming d value differ with the last one.
// This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after
// serialization (numberFormat has not been updated for dValue = 0).
@transient
- private var lastDValue: Option[Int] = None
+ private var lastDIntValue: Option[Int] = None
+
+ @transient
+ private var lastDStringValue: Option[String] = None
// A cached DecimalFormat, for performance concern, we will change it
// only if the d value changed.
@@ -1950,33 +1959,49 @@ case class FormatNumber(x: Expression, d: Expression)
private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US))
override protected def nullSafeEval(xObject: Any, dObject: Any): Any = {
- val dValue = dObject.asInstanceOf[Int]
- if (dValue < 0) {
- return null
- }
-
- lastDValue match {
- case Some(last) if last == dValue =>
- // use the current pattern
- case _ =>
- // construct a new DecimalFormat only if a new dValue
- pattern.delete(0, pattern.length)
- pattern.append("#,###,###,###,###,###,##0")
-
- // decimal place
- if (dValue > 0) {
- pattern.append(".")
-
- var i = 0
- while (i < dValue) {
- i += 1
- pattern.append("0")
- }
+ right.dataType match {
+ case IntegerType =>
+ val dValue = dObject.asInstanceOf[Int]
+ if (dValue < 0) {
+ return null
}
- lastDValue = Some(dValue)
+ lastDIntValue match {
+ case Some(last) if last == dValue =>
+ // use the current pattern
+ case _ =>
+ // construct a new DecimalFormat only if a new dValue
+ pattern.delete(0, pattern.length)
+ pattern.append(defaultFormat)
+
+ // decimal place
+ if (dValue > 0) {
+ pattern.append(".")
+
+ var i = 0
+ while (i < dValue) {
+ i += 1
+ pattern.append("0")
+ }
+ }
+
+ lastDIntValue = Some(dValue)
- numberFormat.applyLocalizedPattern(pattern.toString)
+ numberFormat.applyLocalizedPattern(pattern.toString)
+ }
+ case StringType =>
+ val dValue = dObject.asInstanceOf[UTF8String].toString
+ lastDStringValue match {
+ case Some(last) if last == dValue =>
+ case _ =>
+ pattern.delete(0, pattern.length)
+ lastDStringValue = Some(dValue)
+ if (dValue.isEmpty) {
+ numberFormat.applyLocalizedPattern(defaultFormat)
+ } else {
+ numberFormat.applyLocalizedPattern(dValue)
+ }
+ }
}
x.dataType match {
@@ -2008,35 +2033,52 @@ case class FormatNumber(x: Expression, d: Expression)
// SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.')
// as a decimal separator.
val usLocale = "US"
- val i = ctx.freshName("i")
- val dFormat = ctx.freshName("dFormat")
- val lastDValue =
- ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;")
- val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();")
val numberFormat = ctx.addMutableState(df, "numberFormat",
v => s"""$v = new $df("", new $dfs($l.$usLocale));""")
- s"""
- if ($d >= 0) {
- $pattern.delete(0, $pattern.length());
- if ($d != $lastDValue) {
- $pattern.append("#,###,###,###,###,###,##0");
-
- if ($d > 0) {
- $pattern.append(".");
- for (int $i = 0; $i < $d; $i++) {
- $pattern.append("0");
+ right.dataType match {
+ case IntegerType =>
+ val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();")
+ val i = ctx.freshName("i")
+ val lastDValue =
+ ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;")
+ s"""
+ if ($d >= 0) {
+ $pattern.delete(0, $pattern.length());
+ if ($d != $lastDValue) {
+ $pattern.append("$defaultFormat");
+
+ if ($d > 0) {
+ $pattern.append(".");
+ for (int $i = 0; $i < $d; $i++) {
+ $pattern.append("0");
+ }
+ }
+ $lastDValue = $d;
+ $numberFormat.applyLocalizedPattern($pattern.toString());
}
+ ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
+ } else {
+ ${ev.value} = null;
+ ${ev.isNull} = true;
}
- $lastDValue = $d;
- $numberFormat.applyLocalizedPattern($pattern.toString());
- }
- ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
- } else {
- ${ev.value} = null;
- ${ev.isNull} = true;
- }
- """
+ """
+ case StringType =>
+ val lastDValue = ctx.addMutableState("String", "lastDValue", v => s"""$v = null;""")
+ val dValue = ctx.freshName("dValue")
+ s"""
+ String $dValue = $d.toString();
+ if (!$dValue.equals($lastDValue)) {
+ $lastDValue = $dValue;
+ if ($dValue.isEmpty()) {
+ $numberFormat.applyLocalizedPattern("$defaultFormat");
+ } else {
+ $numberFormat.applyLocalizedPattern($dValue);
+ }
+ }
+ ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
+ """
+ }
})
}
http://git-wip-us.apache.org/repos/asf/spark/blob/cc976f6c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index f1a6f9b..aa334e0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -706,6 +706,30 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
"15,159,339,180,002,773.2778")
checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null)
assert(FormatNumber(Literal.create(null, NullType), Literal(3)).resolved === false)
+
+ checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##############.###")), "12332.123")
+ checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##.###")), "12332.123")
+ checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal("##.####")), "4")
+ checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal("##.####")), "4")
+ checkEvaluation(FormatNumber(Literal(4.0f), Literal("##.###")), "4")
+ checkEvaluation(FormatNumber(Literal(4), Literal("##.###")), "4")
+ checkEvaluation(FormatNumber(Literal(12831273.23481d),
+ Literal("###,###,###,###,###.###")), "12,831,273.235")
+ checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal("")), "12,831,274")
+ checkEvaluation(FormatNumber(Literal(123123324123L), Literal("###,###,###,###,###.###")),
+ "123,123,324,123")
+ checkEvaluation(
+ FormatNumber(Literal(Decimal(123123324123L) * Decimal(123123.21234d)),
+ Literal("###,###,###,###,###.####")), "15,159,339,180,002,773.2778")
+ checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal("##.###")), null)
+ assert(FormatNumber(Literal.create(null, NullType), Literal("##.###")).resolved === false)
+
+ checkEvaluation(FormatNumber(Literal(12332.123456), Literal("#,###,###,###,###,###,##0")),
+ "12,332")
+ checkEvaluation(FormatNumber(
+ Literal.create(null, IntegerType), Literal.create(null, StringType)), null)
+ checkEvaluation(FormatNumber(
+ Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
}
test("find in set") {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org