Posted to commits@spark.apache.org by we...@apache.org on 2022/01/20 09:23:38 UTC
[spark] branch master updated: [SPARK-28137][SQL] Data Type Formatting Functions: `to_number`
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9c02dd4 [SPARK-28137][SQL] Data Type Formatting Functions: `to_number`
9c02dd4 is described below
commit 9c02dd4035c9412ca03e5a5f4721ee223953c004
Author: Jiaan Geng <be...@163.com>
AuthorDate: Thu Jan 20 17:22:44 2022 +0800
[SPARK-28137][SQL] Data Type Formatting Functions: `to_number`
### What changes were proposed in this pull request?
Many databases support the `to_number` function, which converts a string to a number.
The implementations of `to_number` differ considerably among `PostgreSQL`, `Oracle`, and `Phoenix`.
This PR follows the `Oracle` implementation of `to_number` in applying strict parameter verification,
and follows the `Phoenix` implementation in using BigDecimal.
This PR supports the following patterns for numeric formatting:
Pattern | Description
-- | --
9 | digit position
0 | digit position
. (period) | decimal point (only allowed once)
, (comma) | group (thousands) separator
S | sign anchored to number (only allowed once)
$ | value with a leading dollar sign (only allowed once)
D | decimal point (only allowed once)
G | group (thousands) separator
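
As a rough illustration of how these pattern characters relate to a `java.text.DecimalFormat` pattern (which the `NumberFormatter` added in this commit builds on), here is a minimal standalone sketch of the parse-side mapping. The names `PatternSketch` and `toDecimalFormatPattern` are hypothetical and not part of the patch. The sketch also simplifies the patch's `normalize`/`transform` logic: it does no validation and does not strip leading or trailing commas.

```scala
import java.util.Locale

object PatternSketch {
  // Hypothetical helper (not in the patch): maps a to_number pattern to a
  // java.text.DecimalFormat pattern using simplified parse-side rules.
  def toDecimalFormatPattern(format: String): String = {
    var seenDecimalPoint = false
    val normalized = format.toUpperCase(Locale.ROOT).map {
      // Before the decimal point, both '9' and '0' become optional digits.
      case '9' | '0' if !seenDecimalPoint => '#'
      // After the decimal point, '9' becomes a required digit ('0' stays '0').
      case '9' => '0'
      case 'G' => ','
      case 'D' | '.' =>
        seenDecimalPoint = true
        '.'
      case 'S' => '-'
      case other => other
    }
    if (normalized.contains("-")) {
      // DecimalFormat takes "positive;negative" subpatterns, so the sign is
      // dropped from the positive half and kept in the negative half.
      s"${normalized.replace("-", "")};$normalized"
    } else {
      normalized
    }
  }
}
```

Under these assumptions, `PatternSketch.toDecimalFormatPattern("99G999D9S")` produces `##,###.0;##,###.0-`.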
Several mainstream databases support this syntax.
**PostgreSQL:**
https://www.postgresql.org/docs/12/functions-formatting.html
**Oracle:**
https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/TO_NUMBER.html#GUID-D4807212-AFD7-48A7-9AED-BEC3E8809866
**Vertica**
https://www.vertica.com/docs/10.0.x/HTML/Content/Authoring/SQLReferenceManual/Functions/Formatting/TO_NUMBER.htm?tocpath=SQL%20Reference%20Manual%7CSQL%20Functions%7CFormatting%20Functions%7C_____7
**Redshift**
https://docs.aws.amazon.com/redshift/latest/dg/r_TO_NUMBER.html
**DB2**
https://www.ibm.com/support/knowledgecenter/SSGU8G_14.1.0/com.ibm.sqls.doc/ids_sqs_1544.htm
**Teradata**
https://docs.teradata.com/r/kmuOwjp1zEYg98JsB8fu_A/TH2cDXBn6tala29S536nqg
**Snowflake:**
https://docs.snowflake.net/manuals/sql-reference/functions/to_decimal.html
**Exasol**
https://docs.exasol.com/sql_references/functions/alphabeticallistfunctions/to_number.htm#TO_NUMBER
**Phoenix**
http://phoenix.incubator.apache.org/language/functions.html#to_number
**Singlestore**
https://docs.singlestore.com/v7.3/reference/sql-reference/numeric-functions/to-number/
**Intersystems**
https://docs.intersystems.com/latest/csp/docbook/DocBook.UI.Page.cls?KEY=RSQL_TONUMBER
The syntax looks like this:
> select to_number('12,454.8-', '99G999D9S');
-12454.8
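
The result above can be reproduced outside Spark with plain `java.text.DecimalFormat`, which is what the patch relies on for parsing. The following is only a sketch: `ParseSketch` is a hypothetical name, and the pattern string `##,###.0;##,###.0-` is assumed to be what the patch's normalization yields for `'99G999D9S'`. The sketch also pins the symbols to `Locale.ROOT` so that `.` and `,` are always used, whereas the patch constructs `DecimalFormat` with the default locale.

```scala
import java.math.BigDecimal
import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
import java.util.Locale

object ParseSketch {
  // Hypothetical helper: parses `input` with an already-normalized
  // DecimalFormat pattern and returns the value as a BigDecimal.
  def parse(input: String, decimalFormatPattern: String): BigDecimal = {
    val symbols = DecimalFormatSymbols.getInstance(Locale.ROOT)
    val decimalFormat = new DecimalFormat(decimalFormatPattern, symbols)
    // Ask DecimalFormat to return BigDecimal instead of Long/Double.
    decimalFormat.setParseBigDecimal(true)
    decimalFormat.parse(input.trim, new ParsePosition(0)).asInstanceOf[BigDecimal]
  }
}
```

For example, `ParseSketch.parse("12,454.8-", "##,###.0;##,###.0-")` should compare equal to `-12454.8`, because the trailing `-` matches the suffix of the negative subpattern.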
### Why are the changes needed?
`to_number` is very useful for formatted currency to number conversion.
### Does this PR introduce any user-facing change?
Yes. New feature.
### How was this patch tested?
New tests
Closes #35060 from beliefer/SPARK-28137-new.
Authored-by: Jiaan Geng <be...@163.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
docs/_data/menu-sql.yaml | 2 +
docs/sql-ref-number-pattern.md | 22 ++
docs/sql-ref.md | 1 +
.../sql/catalyst/analysis/FunctionRegistry.scala | 1 +
.../expressions/numberFormatExpressions.scala | 105 +++++++++
.../spark/sql/catalyst/util/NumberFormatter.scala | 243 +++++++++++++++++++++
.../spark/sql/catalyst/util/NumberUtils.scala | 189 ----------------
.../spark/sql/errors/QueryCompilationErrors.scala | 8 -
.../spark/sql/errors/QueryExecutionErrors.scala | 5 +-
.../expressions/StringExpressionsSuite.scala | 167 ++++++++++++++
...UtilsSuite.scala => NumberFormatterSuite.scala} | 154 +++++++------
.../sql-functions/sql-expression-schema.md | 3 +-
.../sql-tests/inputs/postgreSQL/numeric.sql | 18 +-
.../sql-tests/inputs/string-functions.sql | 12 +-
.../results/ansi/string-functions.sql.out | 66 +++++-
.../sql-tests/results/postgreSQL/numeric.sql.out | 76 ++++++-
.../sql-tests/results/string-functions.sql.out | 66 +++++-
17 files changed, 846 insertions(+), 292 deletions(-)
diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml
index 22e01df..7d9e6f4 100644
--- a/docs/_data/menu-sql.yaml
+++ b/docs/_data/menu-sql.yaml
@@ -79,6 +79,8 @@
url: sql-ref-datatypes.html
- text: Datetime Pattern
url: sql-ref-datetime-pattern.html
+ - text: Number Pattern
+ url: sql-ref-number-pattern.html
- text: Functions
url: sql-ref-functions.html
- text: Identifiers
diff --git a/docs/sql-ref-number-pattern.md b/docs/sql-ref-number-pattern.md
new file mode 100644
index 0000000..dc7d696
--- /dev/null
+++ b/docs/sql-ref-number-pattern.md
@@ -0,0 +1,22 @@
+---
+layout: global
+title: Number patterns
+displayTitle: Number Patterns for Formatting and Parsing
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+TODO: Add the content of Number Patterns for Formatting and Parsing
diff --git a/docs/sql-ref.md b/docs/sql-ref.md
index 32e7e96..026d072 100644
--- a/docs/sql-ref.md
+++ b/docs/sql-ref.md
@@ -25,6 +25,7 @@ Spark SQL is Apache Spark's module for working with structured data. This guide
* [ANSI Compliance](sql-ref-ansi-compliance.html)
* [Data Types](sql-ref-datatypes.html)
* [Datetime Pattern](sql-ref-datetime-pattern.html)
+ * [Number Pattern](sql-ref-number-pattern.html)
* [Functions](sql-ref-functions.html)
* [Built-in Functions](sql-ref-functions-builtin.html)
* [Scalar User-Defined Functions (UDFs)](sql-ref-functions-udf-scalar.html)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index c995ff8..e98759b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -474,6 +474,7 @@ object FunctionRegistry {
expression[FindInSet]("find_in_set"),
expression[FormatNumber]("format_number"),
expression[FormatString]("format_string"),
+ expression[ToNumber]("to_number"),
expression[GetJsonObject]("get_json_object"),
expression[InitCap]("initcap"),
expression[StringInstr]("instr"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
new file mode 100644
index 0000000..e29a425
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.util.Locale
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
+import org.apache.spark.sql.catalyst.util.NumberFormatter
+import org.apache.spark.sql.types.{DataType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A function that converts string to numeric.
+ */
+@ExpressionDescription(
+ usage = """
+ _FUNC_(strExpr, formatExpr) - Convert `strExpr` to a number based on the `formatExpr`.
+ The format can consist of the following characters:
+ '0' or '9': digit position
+ '.' or 'D': decimal point (only allowed once)
+ ',' or 'G': group (thousands) separator
+ '-' or 'S': sign anchored to number (only allowed once)
+ '$': value with a leading dollar sign (only allowed once)
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_('454', '999');
+ 454
+ > SELECT _FUNC_('454.00', '000D00');
+ 454.00
+ > SELECT _FUNC_('12,454', '99G999');
+ 12454
+ > SELECT _FUNC_('$78.12', '$99.99');
+ 78.12
+ > SELECT _FUNC_('12,454.8-', '99G999D9S');
+ -12454.8
+ """,
+ since = "3.3.0",
+ group = "string_funcs")
+case class ToNumber(left: Expression, right: Expression)
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
+
+ private lazy val numberFormat = right.eval().toString.toUpperCase(Locale.ROOT)
+ private lazy val numberFormatter = new NumberFormatter(numberFormat)
+
+ override def dataType: DataType = numberFormatter.parsedDecimalType
+
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val inputTypeCheck = super.checkInputDataTypes()
+ if (inputTypeCheck.isSuccess) {
+ if (right.foldable) {
+ numberFormatter.check()
+ } else {
+ TypeCheckResult.TypeCheckFailure(s"Format expression must be foldable, but got $right")
+ }
+ } else {
+ inputTypeCheck
+ }
+ }
+
+ override def prettyName: String = "to_number"
+
+ override def nullSafeEval(string: Any, format: Any): Any = {
+ val input = string.asInstanceOf[UTF8String]
+ numberFormatter.parse(input)
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val builder =
+ ctx.addReferenceObj("builder", numberFormatter, classOf[NumberFormatter].getName)
+ val eval = left.genCode(ctx)
+ ev.copy(code =
+ code"""
+ |${eval.code}
+ |boolean ${ev.isNull} = ${eval.isNull};
+ |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+ |if (!${ev.isNull}) {
+ | ${ev.value} = $builder.parse(${eval.value});
+ |}
+ """.stripMargin)
+ }
+
+ override protected def withNewChildrenInternal(
+ newLeft: Expression, newRight: Expression): ToNumber = copy(left = newLeft, right = newRight)
+}
+
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberFormatter.scala
new file mode 100644
index 0000000..a14aceb
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberFormatter.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import java.math.BigDecimal
+import java.text.{DecimalFormat, ParsePosition}
+import java.util.Locale
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types.{Decimal, DecimalType}
+import org.apache.spark.unsafe.types.UTF8String
+
+object NumberFormatter {
+ final val POINT_SIGN = '.'
+ final val POINT_LETTER = 'D'
+ final val COMMA_SIGN = ','
+ final val COMMA_LETTER = 'G'
+ final val MINUS_SIGN = '-'
+ final val MINUS_LETTER = 'S'
+ final val DOLLAR_SIGN = '$'
+ final val NINE_DIGIT = '9'
+ final val ZERO_DIGIT = '0'
+ final val POUND_SIGN = '#'
+
+ final val COMMA_SIGN_STRING = COMMA_SIGN.toString
+ final val POUND_SIGN_STRING = POUND_SIGN.toString
+
+ final val SIGN_SET = Set(POINT_SIGN, COMMA_SIGN, MINUS_SIGN, DOLLAR_SIGN)
+}
+
+class NumberFormatter(originNumberFormat: String, isParse: Boolean = true) extends Serializable {
+ import NumberFormatter._
+
+ protected val normalizedNumberFormat = normalize(originNumberFormat)
+
+ private val transformedFormat = transform(normalizedNumberFormat)
+
+ private lazy val numberDecimalFormat = {
+ val decimalFormat = new DecimalFormat(transformedFormat)
+ decimalFormat.setParseBigDecimal(true)
+ decimalFormat
+ }
+
+ private lazy val (precision, scale) = {
+ val formatSplits = normalizedNumberFormat.split(POINT_SIGN).map(_.filterNot(isSign))
+ assert(formatSplits.length <= 2)
+ val precision = formatSplits.map(_.length).sum
+ val scale = if (formatSplits.length == 2) formatSplits.last.length else 0
+ (precision, scale)
+ }
+
+ def parsedDecimalType: DecimalType = DecimalType(precision, scale)
+
+ /**
+ * DecimalFormat provides '#' and '0' as placeholder of digit, ',' as grouping separator,
+ * '.' as decimal separator, '-' as minus, '$' as dollar, but not '9', 'G', 'D', 'S'. So we need
+ * replace them show below:
+ * 1. '9' -> '#'
+ * 2. 'G' -> ','
+ * 3. 'D' -> '.'
+ * 4. 'S' -> '-'
+ *
+ * Note: When calling format, we must preserve the digits after decimal point, so the digits
+ * after decimal point should be replaced as '0'. For example: '999.9' will be normalized as
+ * '###.0' and '999.99' will be normalized as '###.00', so if the input is 454, the format
+ * output will be 454.0 and 454.00 respectively.
+ *
+ * @param format number format string
+ * @return normalized number format string
+ */
+ private def normalize(format: String): String = {
+ var notFindDecimalPoint = true
+ val normalizedFormat = format.toUpperCase(Locale.ROOT).map {
+ case NINE_DIGIT if notFindDecimalPoint => POUND_SIGN
+ case ZERO_DIGIT if isParse && notFindDecimalPoint => POUND_SIGN
+ case NINE_DIGIT if !notFindDecimalPoint => ZERO_DIGIT
+ case COMMA_LETTER => COMMA_SIGN
+ case POINT_LETTER | POINT_SIGN =>
+ notFindDecimalPoint = false
+ POINT_SIGN
+ case MINUS_LETTER => MINUS_SIGN
+ case other => other
+ }
+ // If the comma is at the beginning or end of number format, then DecimalFormat will be
+ // invalid. For example, "##,###," or ",###,###" for DecimalFormat is invalid, so we must use
+ // "##,###" or "###,###".
+ normalizedFormat.stripPrefix(COMMA_SIGN_STRING).stripSuffix(COMMA_SIGN_STRING)
+ }
+
+ private def isSign(c: Char): Boolean = {
+ SIGN_SET.contains(c)
+ }
+
+ private def transform(format: String): String = {
+ if (format.contains(MINUS_SIGN)) {
+ // For example: '#.######' represents a positive number,
+ // but '#.######;#.######-' represents a negative number.
+ val positiveFormatString = format.replaceAll("-", "")
+ s"$positiveFormatString;$format"
+ } else {
+ format
+ }
+ }
+
+ def check(): TypeCheckResult = {
+ def invalidSignPosition(c: Char): Boolean = {
+ val signIndex = normalizedNumberFormat.indexOf(c)
+ signIndex > 0 && signIndex < normalizedNumberFormat.length - 1
+ }
+
+ def multipleSignInNumberFormatError(message: String): String = {
+ s"At most one $message is allowed in the number format: '$originNumberFormat'"
+ }
+
+ def nonFistOrLastCharInNumberFormatError(message: String): String = {
+ s"$message must be the first or last char in the number format: '$originNumberFormat'"
+ }
+
+ if (normalizedNumberFormat.length == 0) {
+ TypeCheckResult.TypeCheckFailure("Number format cannot be empty")
+ } else if (normalizedNumberFormat.count(_ == POINT_SIGN) > 1) {
+ TypeCheckResult.TypeCheckFailure(
+ multipleSignInNumberFormatError(s"'$POINT_LETTER' or '$POINT_SIGN'"))
+ } else if (normalizedNumberFormat.count(_ == MINUS_SIGN) > 1) {
+ TypeCheckResult.TypeCheckFailure(
+ multipleSignInNumberFormatError(s"'$MINUS_LETTER' or '$MINUS_SIGN'"))
+ } else if (normalizedNumberFormat.count(_ == DOLLAR_SIGN) > 1) {
+ TypeCheckResult.TypeCheckFailure(multipleSignInNumberFormatError(s"'$DOLLAR_SIGN'"))
+ } else if (invalidSignPosition(MINUS_SIGN)) {
+ TypeCheckResult.TypeCheckFailure(
+ nonFistOrLastCharInNumberFormatError(s"'$MINUS_LETTER' or '$MINUS_SIGN'"))
+ } else if (invalidSignPosition(DOLLAR_SIGN)) {
+ TypeCheckResult.TypeCheckFailure(
+ nonFistOrLastCharInNumberFormatError(s"'$DOLLAR_SIGN'"))
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
+ /**
+ * Convert string to numeric based on the given number format.
+ * The format can consist of the following characters:
+ * '0' or '9': digit position
+ * '.' or 'D': decimal point (only allowed once)
+ * ',' or 'G': group (thousands) separator
+ * '-' or 'S': sign anchored to number (only allowed once)
+ * '$': value with a leading dollar sign (only allowed once)
+ *
+ * @param input the string need to converted
+ * @return decimal obtained from string parsing
+ */
+ def parse(input: UTF8String): Decimal = {
+ val inputStr = input.toString.trim
+ val inputSplits = inputStr.split(POINT_SIGN)
+ assert(inputSplits.length <= 2)
+ if (inputSplits.length == 1) {
+ if (inputStr.filterNot(isSign).length > precision - scale) {
+ throw QueryExecutionErrors.invalidNumberFormatError(input, originNumberFormat)
+ }
+ } else if (inputSplits(0).filterNot(isSign).length > precision - scale ||
+ inputSplits(1).filterNot(isSign).length > scale) {
+ throw QueryExecutionErrors.invalidNumberFormatError(input, originNumberFormat)
+ }
+
+ try {
+ val number = numberDecimalFormat.parse(inputStr, new ParsePosition(0))
+ assert(number.isInstanceOf[BigDecimal])
+ Decimal(number.asInstanceOf[BigDecimal])
+ } catch {
+ case _: IllegalArgumentException =>
+ throw QueryExecutionErrors.invalidNumberFormatError(input, originNumberFormat)
+ }
+ }
+
+ /**
+ * Convert numeric to string based on the given number format.
+ * The format can consist of the following characters:
+ * '9': digit position (can be dropped if insignificant)
+ * '0': digit position (will not be dropped, even if insignificant)
+ * '.' or 'D': decimal point (only allowed once)
+ * ',' or 'G': group (thousands) separator
+ * '-' or 'S': sign anchored to number (only allowed once)
+ * '$': value with a leading dollar sign (only allowed once)
+ *
+ * @param input the decimal to format
+ * @param numberFormat the format string
+ * @return The string after formatting input decimal
+ */
+ def format(input: Decimal): String = {
+ val bigDecimal = input.toJavaBigDecimal
+ val decimalPlainStr = bigDecimal.toPlainString
+ if (decimalPlainStr.length > transformedFormat.length) {
+ transformedFormat.replaceAll("0", POUND_SIGN_STRING)
+ } else {
+ var resultStr = numberDecimalFormat.format(bigDecimal)
+ // Since we trimmed the comma at the beginning or end of number format in function
+ // `normalize`, we restore the comma to the result here.
+ // For example, if the specified number format is "99,999," or ",999,999", function
+ // `normalize` normalize them to "##,###" or "###,###".
+ // new DecimalFormat("##,###").parse(12454) and new DecimalFormat("###,###").parse(124546)
+ // will return "12,454" and "124,546" respectively. So we add ',' at the end and head of
+ // the result, then the final output are "12,454," or ",124,546".
+ if (originNumberFormat.last == COMMA_SIGN || originNumberFormat.last == COMMA_LETTER) {
+ resultStr = resultStr + COMMA_SIGN
+ }
+ if (originNumberFormat.charAt(0) == COMMA_SIGN ||
+ originNumberFormat.charAt(0) == COMMA_LETTER) {
+ resultStr = COMMA_SIGN + resultStr
+ }
+
+ resultStr
+ }
+ }
+}
+
+// Visible for testing
+class TestNumberFormatter(originNumberFormat: String, isParse: Boolean = true)
+ extends NumberFormatter(originNumberFormat, isParse) {
+ def checkWithException(): Unit = {
+ check() match {
+ case TypeCheckResult.TypeCheckFailure(message) =>
+ throw new AnalysisException(message)
+ case _ =>
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberUtils.scala
deleted file mode 100644
index 6efde2a..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberUtils.scala
+++ /dev/null
@@ -1,189 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.util
-
-import java.math.BigDecimal
-import java.text.{DecimalFormat, NumberFormat, ParsePosition}
-import java.util.Locale
-
-import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.types.Decimal
-import org.apache.spark.unsafe.types.UTF8String
-
-object NumberUtils {
-
- private val pointSign = '.'
- private val letterPointSign = 'D'
- private val commaSign = ','
- private val letterCommaSign = 'G'
- private val minusSign = '-'
- private val letterMinusSign = 'S'
- private val dollarSign = '$'
-
- private val commaSignStr = commaSign.toString
-
- private def normalize(format: String): String = {
- var notFindDecimalPoint = true
- val normalizedFormat = format.toUpperCase(Locale.ROOT).map {
- case '9' if notFindDecimalPoint => '#'
- case '9' if !notFindDecimalPoint => '0'
- case `letterPointSign` =>
- notFindDecimalPoint = false
- pointSign
- case `letterCommaSign` => commaSign
- case `letterMinusSign` => minusSign
- case `pointSign` =>
- notFindDecimalPoint = false
- pointSign
- case other => other
- }
- // If the comma is at the beginning or end of number format, then DecimalFormat will be invalid.
- // For example, "##,###," or ",###,###" for DecimalFormat is invalid, so we must use "##,###"
- // or "###,###".
- normalizedFormat.stripPrefix(commaSignStr).stripSuffix(commaSignStr)
- }
-
- private def isSign(c: Char): Boolean = {
- Set(pointSign, commaSign, minusSign, dollarSign).contains(c)
- }
-
- private def transform(format: String): String = {
- if (format.contains(minusSign)) {
- val positiveFormatString = format.replaceAll("-", "")
- s"$positiveFormatString;$format"
- } else {
- format
- }
- }
-
- private def check(normalizedFormat: String, numberFormat: String) = {
- def invalidSignPosition(format: String, c: Char): Boolean = {
- val signIndex = format.indexOf(c)
- signIndex > 0 && signIndex < format.length - 1
- }
-
- if (normalizedFormat.count(_ == pointSign) > 1) {
- throw QueryCompilationErrors.multipleSignInNumberFormatError(
- s"'$letterPointSign' or '$pointSign'", numberFormat)
- } else if (normalizedFormat.count(_ == minusSign) > 1) {
- throw QueryCompilationErrors.multipleSignInNumberFormatError(
- s"'$letterMinusSign' or '$minusSign'", numberFormat)
- } else if (normalizedFormat.count(_ == dollarSign) > 1) {
- throw QueryCompilationErrors.multipleSignInNumberFormatError(s"'$dollarSign'", numberFormat)
- } else if (invalidSignPosition(normalizedFormat, minusSign)) {
- throw QueryCompilationErrors.nonFistOrLastCharInNumberFormatError(
- s"'$letterMinusSign' or '$minusSign'", numberFormat)
- } else if (invalidSignPosition(normalizedFormat, dollarSign)) {
- throw QueryCompilationErrors.nonFistOrLastCharInNumberFormatError(
- s"'$dollarSign'", numberFormat)
- }
- }
-
- /**
- * Convert string to numeric based on the given number format.
- * The format can consist of the following characters:
- * '9': digit position (can be dropped if insignificant)
- * '0': digit position (will not be dropped, even if insignificant)
- * '.': decimal point (only allowed once)
- * ',': group (thousands) separator
- * 'S': sign anchored to number (uses locale)
- * 'D': decimal point (uses locale)
- * 'G': group separator (uses locale)
- * '$': specifies that the input value has a leading $ (Dollar) sign.
- *
- * @param input the string need to converted
- * @param numberFormat the given number format
- * @return decimal obtained from string parsing
- */
- def parse(input: UTF8String, numberFormat: String): Decimal = {
- val normalizedFormat = normalize(numberFormat)
- check(normalizedFormat, numberFormat)
-
- val precision = normalizedFormat.filterNot(isSign).length
- val formatSplits = normalizedFormat.split(pointSign)
- val scale = if (formatSplits.length == 1) {
- 0
- } else {
- formatSplits(1).filterNot(isSign).length
- }
- val transformedFormat = transform(normalizedFormat)
- val numberFormatInstance = NumberFormat.getInstance()
- val numberDecimalFormat = numberFormatInstance.asInstanceOf[DecimalFormat]
- numberDecimalFormat.setParseBigDecimal(true)
- numberDecimalFormat.applyPattern(transformedFormat)
- val inputStr = input.toString.trim
- val inputSplits = inputStr.split(pointSign)
- if (inputSplits.length == 1) {
- if (inputStr.filterNot(isSign).length > precision - scale) {
- throw QueryExecutionErrors.invalidNumberFormatError(numberFormat)
- }
- } else if (inputSplits(0).filterNot(isSign).length > precision - scale ||
- inputSplits(1).filterNot(isSign).length > scale) {
- throw QueryExecutionErrors.invalidNumberFormatError(numberFormat)
- }
- val number = numberDecimalFormat.parse(inputStr, new ParsePosition(0))
- Decimal(number.asInstanceOf[BigDecimal])
- }
-
- /**
- * Convert numeric to string based on the given number format.
- * The format can consist of the following characters:
- * '9': digit position (can be dropped if insignificant)
- * '0': digit position (will not be dropped, even if insignificant)
- * '.': decimal point (only allowed once)
- * ',': group (thousands) separator
- * 'S': sign anchored to number (uses locale)
- * 'D': decimal point (uses locale)
- * 'G': group separator (uses locale)
- * '$': specifies that the input value has a leading $ (Dollar) sign.
- *
- * @param input the decimal to format
- * @param numberFormat the format string
- * @return The string after formatting input decimal
- */
- def format(input: Decimal, numberFormat: String): String = {
- val normalizedFormat = normalize(numberFormat)
- check(normalizedFormat, numberFormat)
-
- val transformedFormat = transform(normalizedFormat)
- val bigDecimal = input.toJavaBigDecimal
- val decimalPlainStr = bigDecimal.toPlainString
- if (decimalPlainStr.length > transformedFormat.length) {
- transformedFormat.replaceAll("0", "#")
- } else {
- val decimalFormat = new DecimalFormat(transformedFormat)
- var resultStr = decimalFormat.format(bigDecimal)
- // Since we trimmed the comma at the beginning or end of number format in function
- // `normalize`, we restore the comma to the result here.
- // For example, if the specified number format is "99,999," or ",999,999", function
- // `normalize` normalize them to "##,###" or "###,###".
- // new DecimalFormat("##,###").parse(12454) and new DecimalFormat("###,###").parse(124546)
- // will return "12,454" and "124,546" respectively. So we add ',' at the end and head of
- // the result, then the final output are "12,454," or ",124,546".
- if (numberFormat.last == commaSign || numberFormat.last == letterCommaSign) {
- resultStr = resultStr + commaSign
- }
- if (numberFormat.charAt(0) == commaSign || numberFormat.charAt(0) == letterCommaSign) {
- resultStr = commaSign + resultStr
- }
-
- resultStr
- }
- }
-
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index fcbcb54..14f8053 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -2380,12 +2380,4 @@ object QueryCompilationErrors {
def tableNotSupportTimeTravelError(tableName: Identifier): UnsupportedOperationException = {
new UnsupportedOperationException(s"Table $tableName does not support time travel.")
}
-
- def multipleSignInNumberFormatError(message: String, numberFormat: String): Throwable = {
- new AnalysisException(s"Multiple $message in '$numberFormat'")
- }
-
- def nonFistOrLastCharInNumberFormatError(message: String, numberFormat: String): Throwable = {
- new AnalysisException(s"$message must be the first or last char in '$numberFormat'")
- }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index ede4c39..975d748 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1935,9 +1935,8 @@ object QueryExecutionErrors {
s" to at least $numWrittenParts.")
}
- def invalidNumberFormatError(format: String): Throwable = {
+ def invalidNumberFormatError(input: UTF8String, format: String): Throwable = {
new IllegalArgumentException(
- s"Format '$format' used for parsing string to number or " +
- "formatting number to string is invalid")
+ s"The input string '$input' does not match the given number format: '$format'")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 443a94b..b54d0a6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.internal.SQLConf
@@ -888,6 +889,172 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
}
+ test("ToNumber") {
+ ToNumber(Literal("454"), Literal("")).checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckFailure(msg) =>
+ assert(msg.contains("Number format cannot be empty"))
+ }
+ ToNumber(Literal("454"), NonFoldableLiteral.create("999", StringType))
+ .checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckFailure(msg) =>
+ assert(msg.contains("Format expression must be foldable"))
+ }
+
+ // Test '0' and '9'
+
+ Seq("454", "054", "54", "450").foreach { input =>
+ val invalidFormat1 = 0.until(input.length - 1).map(_ => '0').mkString
+ val invalidFormat2 = 0.until(input.length - 2).map(_ => '0').mkString
+ val invalidFormat3 = 0.until(input.length - 1).map(_ => '9').mkString
+ val invalidFormat4 = 0.until(input.length - 2).map(_ => '9').mkString
+ Seq(invalidFormat1, invalidFormat2, invalidFormat3, invalidFormat4)
+ .filter(_.nonEmpty).foreach { format =>
+ checkExceptionInExpression[IllegalArgumentException](
+ ToNumber(Literal(input), Literal(format)),
+ s"The input string '$input' does not match the given number format: '$format'")
+ }
+
+ val format1 = 0.until(input.length).map(_ => '0').mkString
+ val format2 = 0.until(input.length).map(_ => '9').mkString
+ val format3 = 0.until(input.length).map(i => i % 2 * 9).mkString
+ val format4 = 0.until(input.length + 1).map(_ => '0').mkString
+ val format5 = 0.until(input.length + 1).map(_ => '9').mkString
+ val format6 = 0.until(input.length + 1).map(i => i % 2 * 9).mkString
+ Seq(format1, format2, format3, format4, format5, format6).foreach { format =>
+ checkEvaluation(ToNumber(Literal(input), Literal(format)), Decimal(input))
+ }
+ }
+
+ // Test '.' and 'D'
+ checkExceptionInExpression[IllegalArgumentException](
+ ToNumber(Literal("454.2"), Literal("999")),
+ "The input string '454.2' does not match the given number format: '999'")
+ Seq("999.9", "000.0", "99.99", "00.00", "0000.0", "9999.9", "00.000", "99.999")
+ .foreach { format =>
+ checkExceptionInExpression[IllegalArgumentException](
+ ToNumber(Literal("454.23"), Literal(format)),
+ s"The input string '454.23' does not match the given number format: '$format'")
+ val format2 = format.replace('.', 'D')
+ checkExceptionInExpression[IllegalArgumentException](
+ ToNumber(Literal("454.23"), Literal(format2)),
+ s"The input string '454.23' does not match the given number format: '$format2'")
+ }
+
+ Seq(
+ ("454.2", "000.0") -> Decimal(454.2),
+ ("454.23", "000.00") -> Decimal(454.23),
+ ("454.2", "000.00") -> Decimal(454.2),
+ ("454.0", "000.0") -> Decimal(454),
+ ("454.00", "000.00") -> Decimal(454),
+ (".4542", ".0000") -> Decimal(0.4542),
+ ("4542.", "0000.") -> Decimal(4542)
+ ).foreach { case ((str, format), expected) =>
+ checkEvaluation(ToNumber(Literal(str), Literal(format)), expected)
+ val format2 = format.replace('.', 'D')
+ checkEvaluation(ToNumber(Literal(str), Literal(format2)), expected)
+ val format3 = format.replace('0', '9')
+ checkEvaluation(ToNumber(Literal(str), Literal(format3)), expected)
+ val format4 = format3.replace('.', 'D')
+ checkEvaluation(ToNumber(Literal(str), Literal(format4)), expected)
+ }
+
+ Seq("999.9.9", "999D9D9", "999.9D9", "999D9.9").foreach { str =>
+ ToNumber(Literal("454.3.2"), Literal(str)).checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckFailure(msg) =>
+ assert(msg.contains(s"At most one 'D' or '.' is allowed in the number format: '$str'"))
+ }
+ }
+
+ // Test ',' and 'G'
+ checkExceptionInExpression[IllegalArgumentException](
+ ToNumber(Literal("123,456"), Literal("9G9")),
+ "The input string '123,456' does not match the given number format: '9G9'")
+ checkExceptionInExpression[IllegalArgumentException](
+ ToNumber(Literal("123,456,789"), Literal("999,999")),
+ "The input string '123,456,789' does not match the given number format: '999,999'")
+
+ Seq(
+ ("12,454", "99,999") -> Decimal(12454),
+ ("12,454", "99,999,999") -> Decimal(12454),
+ ("12,454,367", "99,999,999") -> Decimal(12454367),
+ ("12,454,", "99,999,") -> Decimal(12454),
+ (",454,367", ",999,999") -> Decimal(454367),
+ (",454,367", "999,999") -> Decimal(454367)
+ ).foreach { case ((str, format), expected) =>
+ checkEvaluation(ToNumber(Literal(str), Literal(format)), expected)
+ val format2 = format.replace(',', 'G')
+ checkEvaluation(ToNumber(Literal(str), Literal(format2)), expected)
+ val format3 = format.replace('9', '0')
+ checkEvaluation(ToNumber(Literal(str), Literal(format3)), expected)
+ val format4 = format3.replace(',', 'G')
+ checkEvaluation(ToNumber(Literal(str), Literal(format4)), expected)
+ val format5 = s"${format}9"
+ checkEvaluation(ToNumber(Literal(str), Literal(format5)), expected)
+ val format6 = s"${format}0"
+ checkEvaluation(ToNumber(Literal(str), Literal(format6)), expected)
+ val format7 = s"9${format}9"
+ checkEvaluation(ToNumber(Literal(str), Literal(format7)), expected)
+ val format8 = s"0${format}0"
+ checkEvaluation(ToNumber(Literal(str), Literal(format8)), expected)
+ val format9 = s"${format3}9"
+ checkEvaluation(ToNumber(Literal(str), Literal(format9)), expected)
+ val format10 = s"${format3}0"
+ checkEvaluation(ToNumber(Literal(str), Literal(format10)), expected)
+ val format11 = s"9${format3}9"
+ checkEvaluation(ToNumber(Literal(str), Literal(format11)), expected)
+ val format12 = s"0${format3}0"
+ checkEvaluation(ToNumber(Literal(str), Literal(format12)), expected)
+ }
+
+ // Test '$'
+ Seq(
+ ("$78.12", "$99.99") -> Decimal(78.12),
+ ("$78.12", "$00.00") -> Decimal(78.12),
+ ("78.12$", "99.99$") -> Decimal(78.12),
+ ("78.12$", "00.00$") -> Decimal(78.12)
+ ).foreach { case ((str, format), expected) =>
+ checkEvaluation(ToNumber(Literal(str), Literal(format)), expected)
+ }
+
+ ToNumber(Literal("$78$.12"), Literal("$99$.99")).checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckFailure(msg) =>
+ assert(msg.contains("At most one '$' is allowed in the number format: '$99$.99'"))
+ }
+ ToNumber(Literal("78$.12"), Literal("99$.99")).checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckFailure(msg) =>
+ assert(msg.contains("'$' must be the first or last char in the number format: '99$.99'"))
+ }
+
+ // Test '-' and 'S'
+ Seq(
+ ("454-", "999-") -> Decimal(-454),
+ ("-454", "-999") -> Decimal(-454),
+ ("12,454.8-", "99G999D9-") -> Decimal(-12454.8),
+ ("00,454.8-", "99G999.9-") -> Decimal(-454.8)
+ ).foreach { case ((str, format), expected) =>
+ checkEvaluation(ToNumber(Literal(str), Literal(format)), expected)
+ val format2 = format.replace('9', '0')
+ checkEvaluation(ToNumber(Literal(str), Literal(format2)), expected)
+ val format3 = format.replace('-', 'S')
+ checkEvaluation(ToNumber(Literal(str), Literal(format3)), expected)
+ val format4 = format2.replace('-', 'S')
+ checkEvaluation(ToNumber(Literal(str), Literal(format4)), expected)
+ }
+
+ ToNumber(Literal("454.3--"), Literal("999D9SS")).checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckFailure(msg) =>
+ assert(msg.contains("At most one 'S' or '-' is allowed in the number format: '999D9SS'"))
+ }
+
+ Seq("9S99", "9-99").foreach { str =>
+ ToNumber(Literal("-454"), Literal(str)).checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckFailure(msg) =>
+ assert(msg.contains(
+ s"'S' or '-' must be the first or last char in the number format: '$str'"))
+ }
+ }
+ }
+
test("find in set") {
checkEvaluation(
FindInSet(Literal.create(null, StringType), Literal.create(null, StringType)), null)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberFormatterSuite.scala
similarity index 65%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberUtilsSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberFormatterSuite.scala
index 66a17dc..81264f4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberFormatterSuite.scala
@@ -19,43 +19,37 @@ package org.apache.spark.sql.catalyst.util
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.util.NumberUtils.{format, parse}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.UTF8String
-class NumberUtilsSuite extends SparkFunSuite {
+class NumberFormatterSuite extends SparkFunSuite {
- private def failParseWithInvalidInput(
- input: UTF8String, numberFormat: String, errorMsg: String): Unit = {
- val e = intercept[IllegalArgumentException](parse(input, numberFormat))
+ private def invalidNumberFormat(numberFormat: String, errorMsg: String): Unit = {
+ val testNumberFormatter = new TestNumberFormatter(numberFormat)
+ val e = intercept[AnalysisException](testNumberFormatter.checkWithException())
assert(e.getMessage.contains(errorMsg))
}
- private def failParseWithAnalysisException(
+ private def failParseWithInvalidInput(
input: UTF8String, numberFormat: String, errorMsg: String): Unit = {
- val e = intercept[AnalysisException](parse(input, numberFormat))
- assert(e.getMessage.contains(errorMsg))
- }
-
- private def failFormatWithAnalysisException(
- input: Decimal, numberFormat: String, errorMsg: String): Unit = {
- val e = intercept[AnalysisException](format(input, numberFormat))
+ val testNumberFormatter = new TestNumberFormatter(numberFormat)
+ val e = intercept[IllegalArgumentException](testNumberFormatter.parse(input))
assert(e.getMessage.contains(errorMsg))
}
test("parse") {
- failParseWithInvalidInput(UTF8String.fromString("454"), "",
- "Format '' used for parsing string to number or formatting number to string is invalid")
+ invalidNumberFormat("", "Number format cannot be empty")
// Test '9' and '0'
failParseWithInvalidInput(UTF8String.fromString("454"), "9",
- "Format '9' used for parsing string to number or formatting number to string is invalid")
+ "The input string '454' does not match the given number format: '9'")
failParseWithInvalidInput(UTF8String.fromString("454"), "99",
- "Format '99' used for parsing string to number or formatting number to string is invalid")
+ "The input string '454' does not match the given number format: '99'")
Seq(
("454", "999") -> Decimal(454),
("054", "999") -> Decimal(54),
+ ("54", "999") -> Decimal(54),
("404", "999") -> Decimal(404),
("450", "999") -> Decimal(450),
("454", "9999") -> Decimal(454),
@@ -63,17 +57,20 @@ class NumberUtilsSuite extends SparkFunSuite {
("404", "9999") -> Decimal(404),
("450", "9999") -> Decimal(450)
).foreach { case ((str, format), expected) =>
- assert(parse(UTF8String.fromString(str), format) === expected)
+ val builder = new TestNumberFormatter(format)
+ builder.check()
+ assert(builder.parse(UTF8String.fromString(str)) === expected)
}
failParseWithInvalidInput(UTF8String.fromString("454"), "0",
- "Format '0' used for parsing string to number or formatting number to string is invalid")
+ "The input string '454' does not match the given number format: '0'")
failParseWithInvalidInput(UTF8String.fromString("454"), "00",
- "Format '00' used for parsing string to number or formatting number to string is invalid")
+ "The input string '454' does not match the given number format: '00'")
Seq(
("454", "000") -> Decimal(454),
("054", "000") -> Decimal(54),
+ ("54", "000") -> Decimal(54),
("404", "000") -> Decimal(404),
("450", "000") -> Decimal(450),
("454", "0000") -> Decimal(454),
@@ -81,14 +78,16 @@ class NumberUtilsSuite extends SparkFunSuite {
("404", "0000") -> Decimal(404),
("450", "0000") -> Decimal(450)
).foreach { case ((str, format), expected) =>
- assert(parse(UTF8String.fromString(str), format) === expected)
+ val builder = new TestNumberFormatter(format)
+ builder.check()
+ assert(builder.parse(UTF8String.fromString(str)) === expected)
}
// Test '.' and 'D'
failParseWithInvalidInput(UTF8String.fromString("454.2"), "999",
- "Format '999' used for parsing string to number or formatting number to string is invalid")
+ "The input string '454.2' does not match the given number format: '999'")
failParseWithInvalidInput(UTF8String.fromString("454.23"), "999.9",
- "Format '999.9' used for parsing string to number or formatting number to string is invalid")
+ "The input string '454.23' does not match the given number format: '999.9'")
Seq(
("454.2", "999.9") -> Decimal(454.2),
@@ -116,17 +115,19 @@ class NumberUtilsSuite extends SparkFunSuite {
("4542.", "9999D") -> Decimal(4542),
("4542.", "0000D") -> Decimal(4542)
).foreach { case ((str, format), expected) =>
- assert(parse(UTF8String.fromString(str), format) === expected)
+ val builder = new TestNumberFormatter(format)
+ builder.check()
+ assert(builder.parse(UTF8String.fromString(str)) === expected)
}
- failParseWithAnalysisException(UTF8String.fromString("454.3.2"), "999.9.9",
- "Multiple 'D' or '.' in '999.9.9'")
- failParseWithAnalysisException(UTF8String.fromString("454.3.2"), "999D9D9",
- "Multiple 'D' or '.' in '999D9D9'")
- failParseWithAnalysisException(UTF8String.fromString("454.3.2"), "999.9D9",
- "Multiple 'D' or '.' in '999.9D9'")
- failParseWithAnalysisException(UTF8String.fromString("454.3.2"), "999D9.9",
- "Multiple 'D' or '.' in '999D9.9'")
+ invalidNumberFormat(
+ "999.9.9", "At most one 'D' or '.' is allowed in the number format: '999.9.9'")
+ invalidNumberFormat(
+ "999D9D9", "At most one 'D' or '.' is allowed in the number format: '999D9D9'")
+ invalidNumberFormat(
+ "999.9D9", "At most one 'D' or '.' is allowed in the number format: '999.9D9'")
+ invalidNumberFormat(
+ "999D9.9", "At most one 'D' or '.' is allowed in the number format: '999D9.9'")
// Test ',' and 'G'
Seq(
@@ -145,9 +146,15 @@ class NumberUtilsSuite extends SparkFunSuite {
(",454,367", ",999,999") -> Decimal(454367),
(",454,367", ",000,000") -> Decimal(454367),
(",454,367", "G999G999") -> Decimal(454367),
- (",454,367", "G000G000") -> Decimal(454367)
+ (",454,367", "G000G000") -> Decimal(454367),
+ (",454,367", "999,999") -> Decimal(454367),
+ (",454,367", "000,000") -> Decimal(454367),
+ (",454,367", "999G999") -> Decimal(454367),
+ (",454,367", "000G000") -> Decimal(454367)
).foreach { case ((str, format), expected) =>
- assert(parse(UTF8String.fromString(str), format) === expected)
+ val builder = new TestNumberFormatter(format)
+ builder.check()
+ assert(builder.parse(UTF8String.fromString(str)) === expected)
}
// Test '$'
@@ -157,13 +164,14 @@ class NumberUtilsSuite extends SparkFunSuite {
("78.12$", "99.99$") -> Decimal(78.12),
("78.12$", "00.00$") -> Decimal(78.12)
).foreach { case ((str, format), expected) =>
- assert(parse(UTF8String.fromString(str), format) === expected)
+ val builder = new TestNumberFormatter(format)
+ builder.check()
+ assert(builder.parse(UTF8String.fromString(str)) === expected)
}
- failParseWithAnalysisException(UTF8String.fromString("78$.12"), "99$.99",
- "'$' must be the first or last char in '99$.99'")
- failParseWithAnalysisException(UTF8String.fromString("$78.12$"), "$99.99$",
- "Multiple '$' in '$99.99$'")
+ invalidNumberFormat(
+ "99$.99", "'$' must be the first or last char in the number format: '99$.99'")
+ invalidNumberFormat("$99.99$", "At most one '$' is allowed in the number format: '$99.99$'")
// Test '-' and 'S'
Seq(
@@ -178,19 +186,20 @@ class NumberUtilsSuite extends SparkFunSuite {
("12,454.8-", "99G999D9S") -> Decimal(-12454.8),
("00,454.8-", "99G999.9S") -> Decimal(-454.8)
).foreach { case ((str, format), expected) =>
- assert(parse(UTF8String.fromString(str), format) === expected)
+ val builder = new TestNumberFormatter(format)
+ builder.check()
+ assert(builder.parse(UTF8String.fromString(str)) === expected)
}
- failParseWithAnalysisException(UTF8String.fromString("4-54"), "9S99",
- "'S' or '-' must be the first or last char in '9S99'")
- failParseWithAnalysisException(UTF8String.fromString("4-54"), "9-99",
- "'S' or '-' must be the first or last char in '9-99'")
- failParseWithAnalysisException(UTF8String.fromString("454.3--"), "999D9SS",
- "Multiple 'S' or '-' in '999D9SS'")
+ invalidNumberFormat(
+ "9S99", "'S' or '-' must be the first or last char in the number format: '9S99'")
+ invalidNumberFormat(
+ "9-99", "'S' or '-' must be the first or last char in the number format: '9-99'")
+ invalidNumberFormat(
+ "999D9SS", "At most one 'S' or '-' is allowed in the number format: '999D9SS'")
}
test("format") {
- assert(format(Decimal(454), "") === "")
// Test '9' and '0'
Seq(
@@ -214,8 +223,10 @@ class NumberUtilsSuite extends SparkFunSuite {
(Decimal(54), "0000") -> "0054",
(Decimal(404), "0000") -> "0404",
(Decimal(450), "0000") -> "0450"
- ).foreach { case ((decimal, str), expected) =>
- assert(format(decimal, str) === expected)
+ ).foreach { case ((decimal, format), expected) =>
+ val builder = new TestNumberFormatter(format, false)
+ builder.check()
+ assert(builder.format(decimal) === expected)
}
// Test '.' and 'D'
@@ -240,19 +251,12 @@ class NumberUtilsSuite extends SparkFunSuite {
(Decimal(4542), "0000.") -> "4542.",
(Decimal(4542), "9999D") -> "4542.",
(Decimal(4542), "0000D") -> "4542."
- ).foreach { case ((decimal, str), expected) =>
- assert(format(decimal, str) === expected)
+ ).foreach { case ((decimal, format), expected) =>
+ val builder = new TestNumberFormatter(format, false)
+ builder.check()
+ assert(builder.format(decimal) === expected)
}
- failFormatWithAnalysisException(Decimal(454.32), "999.9.9",
- "Multiple 'D' or '.' in '999.9.9'")
- failFormatWithAnalysisException(Decimal(454.32), "999D9D9",
- "Multiple 'D' or '.' in '999D9D9'")
- failFormatWithAnalysisException(Decimal(454.32), "999.9D9",
- "Multiple 'D' or '.' in '999.9D9'")
- failFormatWithAnalysisException(Decimal(454.32), "999D9.9",
- "Multiple 'D' or '.' in '999D9.9'")
-
// Test ',' and 'G'
Seq(
(Decimal(12454), "99,999") -> "12,454",
@@ -271,8 +275,10 @@ class NumberUtilsSuite extends SparkFunSuite {
(Decimal(454367), ",000,000") -> ",454,367",
(Decimal(454367), "G999G999") -> ",454,367",
(Decimal(454367), "G000G000") -> ",454,367"
- ).foreach { case ((decimal, str), expected) =>
- assert(format(decimal, str) === expected)
+ ).foreach { case ((decimal, format), expected) =>
+ val builder = new TestNumberFormatter(format, false)
+ builder.check()
+ assert(builder.format(decimal) === expected)
}
// Test '$'
@@ -281,15 +287,12 @@ class NumberUtilsSuite extends SparkFunSuite {
(Decimal(78.12), "$00.00") -> "$78.12",
(Decimal(78.12), "99.99$") -> "78.12$",
(Decimal(78.12), "00.00$") -> "78.12$"
- ).foreach { case ((decimal, str), expected) =>
- assert(format(decimal, str) === expected)
+ ).foreach { case ((decimal, format), expected) =>
+ val builder = new TestNumberFormatter(format, false)
+ builder.check()
+ assert(builder.format(decimal) === expected)
}
- failFormatWithAnalysisException(Decimal(78.12), "99$.99",
- "'$' must be the first or last char in '99$.99'")
- failFormatWithAnalysisException(Decimal(78.12), "$99.99$",
- "Multiple '$' in '$99.99$'")
-
// Test '-' and 'S'
Seq(
(Decimal(-454), "999-") -> "454-",
@@ -302,16 +305,11 @@ class NumberUtilsSuite extends SparkFunSuite {
(Decimal(-454), "S000") -> "-454",
(Decimal(-12454.8), "99G999D9S") -> "12,454.8-",
(Decimal(-454.8), "99G999.9S") -> "454.8-"
- ).foreach { case ((decimal, str), expected) =>
- assert(format(decimal, str) === expected)
+ ).foreach { case ((decimal, format), expected) =>
+ val builder = new TestNumberFormatter(format, false)
+ builder.check()
+ assert(builder.format(decimal) === expected)
}
-
- failFormatWithAnalysisException(Decimal(-454), "9S99",
- "'S' or '-' must be the first or last char in '9S99'")
- failFormatWithAnalysisException(Decimal(-454), "9-99",
- "'S' or '-' must be the first or last char in '9-99'")
- failFormatWithAnalysisException(Decimal(-454.3), "999D9SS",
- "Multiple 'S' or '-' in '999D9SS'")
}
}
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 07e1d00..b742a05 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -1,6 +1,6 @@
<!-- Automatically generated by ExpressionsSchemaSuite -->
## Summary
- - Number of queries: 375
+ - Number of queries: 376
- Number of expressions that missing example: 12
- Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint
## Schema of Built-in Functions
@@ -300,6 +300,7 @@
| org.apache.spark.sql.catalyst.expressions.Tanh | tanh | SELECT tanh(0) | struct<TANH(0):double> |
| org.apache.spark.sql.catalyst.expressions.TimeWindow | window | SELECT a, window.start, window.end, count(*) as cnt FROM VALUES ('A1', '2021-01-01 00:00:00'), ('A1', '2021-01-01 00:04:30'), ('A1', '2021-01-01 00:06:00'), ('A2', '2021-01-01 00:01:00') AS tab(a, b) GROUP by a, window(b, '5 minutes') ORDER BY a, start | struct<a:string,start:timestamp,end:timestamp,cnt:bigint> |
| org.apache.spark.sql.catalyst.expressions.ToDegrees | degrees | SELECT degrees(3.141592653589793) | struct<DEGREES(3.141592653589793):double> |
+| org.apache.spark.sql.catalyst.expressions.ToNumber | to_number | SELECT to_number('454', '999') | struct<to_number(454, 999):decimal(3,0)> |
| org.apache.spark.sql.catalyst.expressions.ToRadians | radians | SELECT radians(180) | struct<RADIANS(180):double> |
| org.apache.spark.sql.catalyst.expressions.ToUTCTimestamp | to_utc_timestamp | SELECT to_utc_timestamp('2016-08-31', 'Asia/Seoul') | struct<to_utc_timestamp(2016-08-31, Asia/Seoul):timestamp> |
| org.apache.spark.sql.catalyst.expressions.ToUnixTimestamp | to_unix_timestamp | SELECT to_unix_timestamp('2016-04-08', 'yyyy-MM-dd') | struct<to_unix_timestamp(2016-04-08, yyyy-MM-dd):bigint> |
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/numeric.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/numeric.sql
index 53f2aa4..14a89d5 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/numeric.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/numeric.sql
@@ -895,22 +895,22 @@ DROP TABLE width_bucket_test;
-- TO_NUMBER()
--
-- SET lc_numeric = 'C';
--- SELECT '' AS to_number_1, to_number('-34,338,492', '99G999G999');
--- SELECT '' AS to_number_2, to_number('-34,338,492.654,878', '99G999G999D999G999');
+SELECT '' AS to_number_1, to_number('-34,338,492', '99G999G999');
+SELECT '' AS to_number_2, to_number('-34,338,492.654,878', '99G999G999D999G999');
-- SELECT '' AS to_number_3, to_number('<564646.654564>', '999999.999999PR');
--- SELECT '' AS to_number_4, to_number('0.00001-', '9.999999S');
+SELECT '' AS to_number_4, to_number('0.00001-', '9.999999S');
-- SELECT '' AS to_number_5, to_number('5.01-', 'FM9.999999S');
-- SELECT '' AS to_number_5, to_number('5.01-', 'FM9.999999MI');
-- SELECT '' AS to_number_7, to_number('5 4 4 4 4 8 . 7 8', '9 9 9 9 9 9 . 9 9');
-- SELECT '' AS to_number_8, to_number('.01', 'FM9.99');
--- SELECT '' AS to_number_9, to_number('.0', '99999999.99999999');
--- SELECT '' AS to_number_10, to_number('0', '99.99');
+SELECT '' AS to_number_9, to_number('.0', '99999999.99999999');
+SELECT '' AS to_number_10, to_number('0', '99.99');
-- SELECT '' AS to_number_11, to_number('.-01', 'S99.99');
--- SELECT '' AS to_number_12, to_number('.01-', '99.99S');
+SELECT '' AS to_number_12, to_number('.01-', '99.99S');
-- SELECT '' AS to_number_13, to_number(' . 0 1-', ' 9 9 . 9 9 S');
--- SELECT '' AS to_number_14, to_number('34,50','999,99');
--- SELECT '' AS to_number_15, to_number('123,000','999G');
--- SELECT '' AS to_number_16, to_number('123456','999G999');
+SELECT '' AS to_number_14, to_number('34,50','999,99');
+SELECT '' AS to_number_15, to_number('123,000','999G');
+SELECT '' AS to_number_16, to_number('123456','999G999');
-- SELECT '' AS to_number_17, to_number('$1234.56','L9,999.99');
-- SELECT '' AS to_number_18, to_number('$1234.56','L99,999.99');
-- SELECT '' AS to_number_19, to_number('$1,234.56','L99,999.99');
diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
index 4b5f120..94924a9 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
@@ -124,4 +124,14 @@ SELECT endswith('Spark SQL', 'QL');
SELECT endswith('Spark SQL', 'Spa');
SELECT endswith(null, 'Spark');
SELECT endswith('Spark', null);
-SELECT endswith(null, null);
\ No newline at end of file
+SELECT endswith(null, null);
+
+-- to_number
+select to_number('454', '000');
+select to_number('454.2', '000.0');
+select to_number('12,454', '00,000');
+select to_number('$78.12', '$00.00');
+select to_number('-454', '-000');
+select to_number('-454', 'S000');
+select to_number('12,454.8-', '00,000.9-');
+select to_number('00,454.8-', '00,000.9-');
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
index 6fb9a6d..99927c2 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 94
+-- Number of queries: 102
-- !query
@@ -760,3 +760,67 @@ SELECT endswith(null, null)
struct<endswith(NULL, NULL):boolean>
-- !query output
NULL
+
+
+-- !query
+select to_number('454', '000')
+-- !query schema
+struct<to_number(454, 000):decimal(3,0)>
+-- !query output
+454
+
+
+-- !query
+select to_number('454.2', '000.0')
+-- !query schema
+struct<to_number(454.2, 000.0):decimal(4,1)>
+-- !query output
+454.2
+
+
+-- !query
+select to_number('12,454', '00,000')
+-- !query schema
+struct<to_number(12,454, 00,000):decimal(5,0)>
+-- !query output
+12454
+
+
+-- !query
+select to_number('$78.12', '$00.00')
+-- !query schema
+struct<to_number($78.12, $00.00):decimal(4,2)>
+-- !query output
+78.12
+
+
+-- !query
+select to_number('-454', '-000')
+-- !query schema
+struct<to_number(-454, -000):decimal(3,0)>
+-- !query output
+-454
+
+
+-- !query
+select to_number('-454', 'S000')
+-- !query schema
+struct<to_number(-454, S000):decimal(3,0)>
+-- !query output
+-454
+
+
+-- !query
+select to_number('12,454.8-', '00,000.9-')
+-- !query schema
+struct<to_number(12,454.8-, 00,000.9-):decimal(6,1)>
+-- !query output
+-12454.8
+
+
+-- !query
+select to_number('00,454.8-', '00,000.9-')
+-- !query schema
+struct<to_number(00,454.8-, 00,000.9-):decimal(6,1)>
+-- !query output
+-454.8
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out
index bc13bb8..41fc990 100644
--- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 592
+-- Number of queries: 601
-- !query
@@ -4595,6 +4595,80 @@ struct<>
-- !query
+SELECT '' AS to_number_1, to_number('-34,338,492', '99G999G999')
+-- !query schema
+struct<to_number_1:string,to_number(-34,338,492, 99G999G999):decimal(8,0)>
+-- !query output
+ -34338492
+
+
+-- !query
+SELECT '' AS to_number_2, to_number('-34,338,492.654,878', '99G999G999D999G999')
+-- !query schema
+struct<>
+-- !query output
+java.lang.IllegalArgumentException
+The input string '-34,338,492.654,878' does not match the given number format: '99G999G999D999G999'
+
+
+-- !query
+SELECT '' AS to_number_4, to_number('0.00001-', '9.999999S')
+-- !query schema
+struct<to_number_4:string,to_number(0.00001-, 9.999999S):decimal(7,6)>
+-- !query output
+ -0.000010
+
+
+-- !query
+SELECT '' AS to_number_9, to_number('.0', '99999999.99999999')
+-- !query schema
+struct<to_number_9:string,to_number(.0, 99999999.99999999):decimal(16,8)>
+-- !query output
+ 0.00000000
+
+
+-- !query
+SELECT '' AS to_number_10, to_number('0', '99.99')
+-- !query schema
+struct<to_number_10:string,to_number(0, 99.99):decimal(4,2)>
+-- !query output
+ 0.00
+
+
+-- !query
+SELECT '' AS to_number_12, to_number('.01-', '99.99S')
+-- !query schema
+struct<to_number_12:string,to_number(.01-, 99.99S):decimal(4,2)>
+-- !query output
+ -0.01
+
+
+-- !query
+SELECT '' AS to_number_14, to_number('34,50','999,99')
+-- !query schema
+struct<to_number_14:string,to_number(34,50, 999,99):decimal(5,0)>
+-- !query output
+ 3450
+
+
+-- !query
+SELECT '' AS to_number_15, to_number('123,000','999G')
+-- !query schema
+struct<>
+-- !query output
+java.lang.IllegalArgumentException
+The input string '123,000' does not match the given number format: '999G'
+
+
+-- !query
+SELECT '' AS to_number_16, to_number('123456','999G999')
+-- !query schema
+struct<to_number_16:string,to_number(123456, 999G999):decimal(6,0)>
+-- !query output
+ 123456
+
+
+-- !query
CREATE TABLE num_input_test (n1 decimal(38, 18)) USING parquet
-- !query schema
struct<>
diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
index 2aa2e80..6baac61 100644
--- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 94
+-- Number of queries: 102
-- !query
@@ -756,3 +756,67 @@ SELECT endswith(null, null)
struct<endswith(NULL, NULL):boolean>
-- !query output
NULL
+
+
+-- !query
+select to_number('454', '000')
+-- !query schema
+struct<to_number(454, 000):decimal(3,0)>
+-- !query output
+454
+
+
+-- !query
+select to_number('454.2', '000.0')
+-- !query schema
+struct<to_number(454.2, 000.0):decimal(4,1)>
+-- !query output
+454.2
+
+
+-- !query
+select to_number('12,454', '00,000')
+-- !query schema
+struct<to_number(12,454, 00,000):decimal(5,0)>
+-- !query output
+12454
+
+
+-- !query
+select to_number('$78.12', '$00.00')
+-- !query schema
+struct<to_number($78.12, $00.00):decimal(4,2)>
+-- !query output
+78.12
+
+
+-- !query
+select to_number('-454', '-000')
+-- !query schema
+struct<to_number(-454, -000):decimal(3,0)>
+-- !query output
+-454
+
+
+-- !query
+select to_number('-454', 'S000')
+-- !query schema
+struct<to_number(-454, S000):decimal(3,0)>
+-- !query output
+-454
+
+
+-- !query
+select to_number('12,454.8-', '00,000.9-')
+-- !query schema
+struct<to_number(12,454.8-, 00,000.9-):decimal(6,1)>
+-- !query output
+-12454.8
+
+
+-- !query
+select to_number('00,454.8-', '00,000.9-')
+-- !query schema
+struct<to_number(00,454.8-, 00,000.9-):decimal(6,1)>
+-- !query output
+-454.8
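
The format-validation failures asserted in the `ToNumber` and `NumberFormatterSuite` tests above (empty format, repeated decimal point, misplaced or repeated `$` and sign characters) can be summarised in a small stand-alone sketch. This is not Spark's `NumberFormatter`: the object name `NumberFormatSketch`, the `check` signature, and the order in which the rules are evaluated are assumptions made for illustration. Only the message strings are taken from the tests.

```scala
// Hypothetical, simplified stand-in for the format checks exercised by the tests above.
// NOT Spark's implementation: the names and the rule ordering are assumptions.
object NumberFormatSketch {

  /** Returns Left(message) for a rejected format and Right(()) for an accepted one. */
  def check(format: String): Either[String, Unit] = {
    // Number of characters in `format` that belong to the given set.
    def count(chars: Char*): Int = format.count(c => chars.contains(c))

    // True when no character of the set occurs, or when its first occurrence
    // is the first or last character of `format`.
    def atEdge(chars: Char*): Boolean = {
      val idx = format.indexWhere(c => chars.contains(c))
      idx <= 0 || idx == format.length - 1
    }

    if (format.isEmpty) {
      Left("Number format cannot be empty")
    } else if (count('D', '.') > 1) {
      Left(s"At most one 'D' or '.' is allowed in the number format: '$format'")
    } else if (count('S', '-') > 1) {
      Left(s"At most one 'S' or '-' is allowed in the number format: '$format'")
    } else if (count('$') > 1) {
      Left(s"At most one '$$' is allowed in the number format: '$format'")
    } else if (!atEdge('S', '-')) {
      Left(s"'S' or '-' must be the first or last char in the number format: '$format'")
    } else if (!atEdge('$')) {
      Left(s"'$$' must be the first or last char in the number format: '$format'")
    } else {
      Right(())
    }
  }
}
```

Under these assumptions, `NumberFormatSketch.check("99G999D9S")` is accepted, while `check("9S99")` and `check("$99$.99")` are rejected with the same messages the tests look for. The sketch covers format validation only; matching an input string against a valid format, which produces the "does not match the given number format" error at evaluation time, is outside its scope.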