You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/03/24 06:31:08 UTC
[spark] branch branch-3.3 updated: [SPARK-38063][SQL] Support split_part Function
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
new 072968d [SPARK-38063][SQL] Support split_part Function
072968d is described below
commit 072968d730863e89635c903999a397fc0233ea87
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Thu Mar 24 14:28:32 2022 +0800
[SPARK-38063][SQL] Support split_part Function
### What changes were proposed in this pull request?
`split_part()` is a function commonly supported by other systems such as Postgres and some other databases.
The Spark equivalent is `element_at(split(arg, delim), part)`
### Why are the changes needed?
Adding new SQL function.
### Does this PR introduce _any_ user-facing change?
Yes. This PR adds a new function so there is no previous behavior. The following demonstrates more about the new function:
syntax: `split_part(str, delimiter, partNum)`
This function splits `str` by `delimiter` and returns the requested part of the split (1-based). If any input is null, returns null. If the index is out of range of the split parts, returns an empty string. If the index is 0, throws an ArrayIndexOutOfBoundsException.
`str` and `delimiter` are both of `string` type. `partNum` is of `integer` type.
Examples:
```
> SELECT _FUNC_('11.12.13', '.', 3);
13
> SELECT _FUNC_(NULL, '.', 3);
NULL
```
### How was this patch tested?
Unit Test
Closes #35352 from amaliujia/splitpart.
Authored-by: Rui Wang <ru...@databricks.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
(cherry picked from commit 3858bf0fbd02e3d8fd18e967f3841c50b9294414)
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../org/apache/spark/unsafe/types/UTF8String.java | 21 +++++-
.../sql/catalyst/analysis/FunctionRegistry.scala | 1 +
.../expressions/collectionOperations.scala | 22 ++++--
.../catalyst/expressions/stringExpressions.scala | 75 ++++++++++++++++++-
.../sql-functions/sql-expression-schema.md | 3 +-
.../sql-tests/inputs/string-functions.sql | 12 ++++
.../results/ansi/string-functions.sql.out | 83 +++++++++++++++++++++-
.../sql-tests/results/string-functions.sql.out | 83 +++++++++++++++++++++-
8 files changed, 291 insertions(+), 9 deletions(-)
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 98c61cf..0f9d653 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Map;
+import java.util.regex.Pattern;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.KryoSerializable;
@@ -999,13 +1000,31 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
}
public UTF8String[] split(UTF8String pattern, int limit) {
+ return split(pattern.toString(), limit);
+ }
+
+ public UTF8String[] splitSQL(UTF8String delimiter, int limit) {
+ // if delimiter is empty string, skip the regex based splitting directly as regex
+ // treats empty string as matching anything, thus use the input directly.
+ if (delimiter.numBytes() == 0) {
+ return new UTF8String[]{this};
+ } else {
+ // we do not treat delimiter as a regex but consider the whole string of delimiter
+ // as the separator to split string. Java String's split, however, only accept
+ // regex as the pattern to split, thus we can quote the delimiter to escape special
+ // characters in the string.
+ return split(Pattern.quote(delimiter.toString()), limit);
+ }
+ }
+
+ private UTF8String[] split(String delimiter, int limit) {
// Java String's split method supports "ignore empty string" behavior when the limit is 0
// whereas other languages do not. To avoid this java specific behavior, we fall back to
// -1 when the limit is 0.
if (limit == 0) {
limit = -1;
}
- String[] splits = toString().split(pattern.toString(), limit);
+ String[] splits = toString().split(delimiter, limit);
UTF8String[] res = new UTF8String[splits.length];
for (int i = 0; i < res.length; i++) {
res[i] = fromString(splits[i]);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a37d4b2..a06112a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -549,6 +549,7 @@ object FunctionRegistry {
expression[SoundEx]("soundex"),
expression[StringSpace]("space"),
expression[StringSplit]("split"),
+ expression[SplitPart]("split_part"),
expression[Substring]("substr", true),
expression[Substring]("substring"),
expression[Left]("left"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 363c531..ca00839 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2095,10 +2095,12 @@ case class ArrayPosition(left: Expression, right: Expression)
case class ElementAt(
left: Expression,
right: Expression,
+ // The value to return if index is out of bound
+ defaultValueOutOfBound: Option[Literal] = None,
failOnError: Boolean = SQLConf.get.ansiEnabled)
extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant {
- def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled)
+ def this(left: Expression, right: Expression) = this(left, right, None, SQLConf.get.ansiEnabled)
@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
@@ -2179,7 +2181,10 @@ case class ElementAt(
if (failOnError) {
throw QueryExecutionErrors.invalidElementAtIndexError(index, array.numElements())
} else {
- null
+ defaultValueOutOfBound match {
+ case Some(value) => value.eval()
+ case None => null
+ }
}
} else {
val idx = if (index == 0) {
@@ -2218,7 +2223,16 @@ case class ElementAt(
val indexOutOfBoundBranch = if (failOnError) {
s"throw QueryExecutionErrors.invalidElementAtIndexError($index, $eval1.numElements());"
} else {
- s"${ev.isNull} = true;"
+ defaultValueOutOfBound match {
+ case Some(value) =>
+ val defaultValueEval = value.genCode(ctx)
+ s"""
+ ${defaultValueEval.code}
${ev.isNull} = ${defaultValueEval.isNull};
${ev.value} = ${defaultValueEval.value};
+ """.stripMargin
+ case None => s"${ev.isNull} = true;"
+ }
}
s"""
@@ -2278,7 +2292,7 @@ case class ElementAt(
case class TryElementAt(left: Expression, right: Expression, replacement: Expression)
extends RuntimeReplaceable with InheritAnalysisRules {
def this(left: Expression, right: Expression) =
- this(left, right, ElementAt(left, right, failOnError = false))
+ this(left, right, ElementAt(left, right, None, failOnError = false))
override def prettyName: String = "try_element_at"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index fc73216..a08ab84 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.{StringType, _}
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -2943,3 +2943,76 @@ case class Sentences(
copy(str = newFirst, language = newSecond, country = newThird)
}
+
+/**
+ * Splits a given string by a specified delimiter and return splits into a
+ * GenericArrayData. This expression is different from `split` function as
+ * `split` takes regex expression as the pattern to split strings while this
+ * expression take delimiter (a string without carrying special meaning on its
+ * characters, thus is not treated as regex) to split strings.
+ */
+case class StringSplitSQL(
+ str: Expression,
+ delimiter: Expression) extends BinaryExpression with NullIntolerant {
+ override def dataType: DataType = ArrayType(StringType, containsNull = false)
+ override def left: Expression = str
+ override def right: Expression = delimiter
+
+ override def nullSafeEval(string: Any, delimiter: Any): Any = {
+ val strings = string.asInstanceOf[UTF8String].splitSQL(
+ delimiter.asInstanceOf[UTF8String], -1);
+ new GenericArrayData(strings.asInstanceOf[Array[Any]])
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val arrayClass = classOf[GenericArrayData].getName
+ nullSafeCodeGen(ctx, ev, (str, delimiter) => {
+ // Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
+ s"${ev.value} = new $arrayClass($str.splitSQL($delimiter,-1));"
+ })
+ }
+
+ override def withNewChildrenInternal(
+ newFirst: Expression, newSecond: Expression): StringSplitSQL =
+ copy(str = newFirst, delimiter = newSecond)
+}
+
+/**
+ * Splits a given string by a specified delimiter and returns the requested part.
+ * If any input is null, returns null.
+ * If index is out of range of split parts, return empty string.
+ * If index is 0, throws an ArrayIndexOutOfBoundsException.
+ */
+@ExpressionDescription(
+ usage =
+ """
+ _FUNC_(str, delimiter, partNum) - Splits `str` by delimiter and return
+ requested part of the split (1-based). If any input is null, returns null.
+ if `partNum` is out of range of split parts, returns empty string. If `partNum` is 0,
+ throws an error. If `partNum` is negative, the parts are counted backward from the
+ end of the string. If the `delimiter` is an empty string, the `str` is not split.
+ """,
+ examples =
+ """
+ Examples:
+ > SELECT _FUNC_('11.12.13', '.', 3);
+ 13
+ """,
+ since = "3.3.0",
+ group = "string_funcs")
+case class SplitPart (
+ str: Expression,
+ delimiter: Expression,
+ partNum: Expression)
+ extends RuntimeReplaceable with ImplicitCastInputTypes {
+ override lazy val replacement: Expression =
+ ElementAt(StringSplitSQL(str, delimiter), partNum, Some(Literal.create("", StringType)),
+ false)
+ override def nodeName: String = "split_part"
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
+ def children: Seq[Expression] = Seq(str, delimiter, partNum)
+ protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
+ copy(str = newChildren.apply(0), delimiter = newChildren.apply(1),
+ partNum = newChildren.apply(2))
+ }
+}
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 1afba46..166c761 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -1,6 +1,6 @@
<!-- Automatically generated by ExpressionsSchemaSuite -->
## Summary
- - Number of queries: 383
+ - Number of queries: 384
- Number of expressions that missing example: 12
- Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint
## Schema of Built-in Functions
@@ -275,6 +275,7 @@
| org.apache.spark.sql.catalyst.expressions.SoundEx | soundex | SELECT soundex('Miller') | struct<soundex(Miller):string> |
| org.apache.spark.sql.catalyst.expressions.SparkPartitionID | spark_partition_id | SELECT spark_partition_id() | struct<SPARK_PARTITION_ID():int> |
| org.apache.spark.sql.catalyst.expressions.SparkVersion | version | SELECT version() | struct<version():string> |
+| org.apache.spark.sql.catalyst.expressions.SplitPart | split_part | SELECT split_part('11.12.13', '.', 3) | struct<split_part(11.12.13, ., 3):string> |
| org.apache.spark.sql.catalyst.expressions.Sqrt | sqrt | SELECT sqrt(4) | struct<SQRT(4):double> |
| org.apache.spark.sql.catalyst.expressions.Stack | stack | SELECT stack(2, 1, 2, 3) | struct<col0:int,col1:int> |
| org.apache.spark.sql.catalyst.expressions.StartsWithExpressionBuilder$ | startswith | SELECT startswith('Spark SQL', 'Spark') | struct<startswith(Spark SQL, Spark):boolean> |
diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
index e7c01a6..7d22e79 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
@@ -27,6 +27,18 @@ select right("abcd", -2), right("abcd", 0), right("abcd", 'a');
SELECT split('aa1cc2ee3', '[1-9]+');
SELECT split('aa1cc2ee3', '[1-9]+', 2);
+-- split_part function
+SELECT split_part('11.12.13', '.', 2);
+SELECT split_part('11.12.13', '.', -1);
+SELECT split_part('11.12.13', '.', -3);
+SELECT split_part('11.12.13', '', 1);
+SELECT split_part('11ab12ab13', 'ab', 1);
+SELECT split_part('11.12.13', '.', 0);
+SELECT split_part('11.12.13', '.', 4);
+SELECT split_part('11.12.13', '.', 5);
+SELECT split_part('11.12.13', '.', -5);
+SELECT split_part(null, '.', 1);
+
-- substring function
SELECT substr('Spark SQL', 5);
SELECT substr('Spark SQL', -3);
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
index b182b5c..01213bd 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 131
+-- Number of queries: 141
-- !query
@@ -127,6 +127,87 @@ struct<split(aa1cc2ee3, [1-9]+, 2):array<string>>
-- !query
+SELECT split_part('11.12.13', '.', 2)
+-- !query schema
+struct<split_part(11.12.13, ., 2):string>
+-- !query output
+12
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -1)
+-- !query schema
+struct<split_part(11.12.13, ., -1):string>
+-- !query output
+13
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -3)
+-- !query schema
+struct<split_part(11.12.13, ., -3):string>
+-- !query output
+11
+
+
+-- !query
+SELECT split_part('11.12.13', '', 1)
+-- !query schema
+struct<split_part(11.12.13, , 1):string>
+-- !query output
+11.12.13
+
+
+-- !query
+SELECT split_part('11ab12ab13', 'ab', 1)
+-- !query schema
+struct<split_part(11ab12ab13, ab, 1):string>
+-- !query output
+11
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 0)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+SQL array indices start at 1
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 4)
+-- !query schema
+struct<split_part(11.12.13, ., 4):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 5)
+-- !query schema
+struct<split_part(11.12.13, ., 5):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -5)
+-- !query schema
+struct<split_part(11.12.13, ., -5):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part(null, '.', 1)
+-- !query schema
+struct<split_part(NULL, ., 1):string>
+-- !query output
+NULL
+
+
+-- !query
SELECT substr('Spark SQL', 5)
-- !query schema
struct<substr(Spark SQL, 5, 2147483647):string>
diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
index 4307df7..3a7f197 100644
--- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 131
+-- Number of queries: 141
-- !query
@@ -125,6 +125,87 @@ struct<split(aa1cc2ee3, [1-9]+, 2):array<string>>
-- !query
+SELECT split_part('11.12.13', '.', 2)
+-- !query schema
+struct<split_part(11.12.13, ., 2):string>
+-- !query output
+12
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -1)
+-- !query schema
+struct<split_part(11.12.13, ., -1):string>
+-- !query output
+13
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -3)
+-- !query schema
+struct<split_part(11.12.13, ., -3):string>
+-- !query output
+11
+
+
+-- !query
+SELECT split_part('11.12.13', '', 1)
+-- !query schema
+struct<split_part(11.12.13, , 1):string>
+-- !query output
+11.12.13
+
+
+-- !query
+SELECT split_part('11ab12ab13', 'ab', 1)
+-- !query schema
+struct<split_part(11ab12ab13, ab, 1):string>
+-- !query output
+11
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 0)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+SQL array indices start at 1
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 4)
+-- !query schema
+struct<split_part(11.12.13, ., 4):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 5)
+-- !query schema
+struct<split_part(11.12.13, ., 5):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -5)
+-- !query schema
+struct<split_part(11.12.13, ., -5):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part(null, '.', 1)
+-- !query schema
+struct<split_part(NULL, ., 1):string>
+-- !query output
+NULL
+
+
+-- !query
SELECT substr('Spark SQL', 5)
-- !query schema
struct<substr(Spark SQL, 5, 2147483647):string>
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org