
[spark] branch branch-3.3 updated: [SPARK-38063][SQL] Support split_part Function

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 072968d  [SPARK-38063][SQL] Support split_part Function
072968d is described below

commit 072968d730863e89635c903999a397fc0233ea87
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Thu Mar 24 14:28:32 2022 +0800

    [SPARK-38063][SQL] Support split_part Function
    
    ### What changes were proposed in this pull request?
    
    `split_part()` is a function commonly supported by other systems, such as Postgres.
    
    The Spark equivalent is `element_at(split(arg, delim), part)`
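
    For comparison, a minimal sketch of the workaround next to the new function (illustrative only, assuming a local `SparkSession` named `spark`; not part of this patch):

    ```scala
    // Illustrative only: the pre-existing workaround vs. the new function.
    // Note that `split` takes a regex, so the dot must be bracketed (or escaped) there.
    spark.sql("SELECT element_at(split('11.12.13', '[.]'), 3)").show() // 13
    spark.sql("SELECT split_part('11.12.13', '.', 3)").show()          // 13
    ```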
    
    ### Why are the changes needed?
    
    Adds a new SQL function.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. This PR adds a new function, so there is no previous behavior to change. The following describes the new function:
    
    Syntax: `split_part(str, delimiter, partNum)`
    
    This function splits `str` by `delimiter` and returns the requested part of the split (1-based). If any input is null, it returns null. If the index is out of range of the split parts, it returns an empty string. If the index is 0, it throws an ArrayIndexOutOfBoundsException.
    
    `str` and `delimiter` are of `string` type; `partNum` is of `integer` type.
    
    Examples:
    ```
          > SELECT _FUNC_('11.12.13', '.', 3);
           13
          > SELECT _FUNC_(NULL, '.', 3);
           NULL
    ```
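
    A few more edge cases, again sketched against an assumed local `spark` session (the authoritative results are in the golden files further down):

    ```scala
    // Illustrative sketch of the documented edge cases (assumes a SparkSession `spark`).
    spark.sql("SELECT split_part('11.12.13', '.', -1)").show() // 13 (counted from the end)
    spark.sql("SELECT split_part('11.12.13', '.', 4)").show()  // "" (index out of range)
    spark.sql("SELECT split_part('11.12.13', '', 1)").show()   // 11.12.13 (empty delimiter: no split)
    ```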
    
    ### How was this patch tested?
    
    Unit tests and golden-file query tests; for instance, checks along these lines (illustrative sketch, not the actual suite code) pin down the behavior:
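
    ```scala
    // Hypothetical assertions mirroring the golden-file results; not the actual tests.
    // Assumes a local SparkSession named `spark`.
    assert(spark.sql("SELECT split_part('11.12.13', '.', 2)").head.getString(0) == "12")
    assert(spark.sql("SELECT split_part(null, '.', 1)").head.get(0) == null)
    ```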
    
    Closes #35352 from amaliujia/splitpart.
    
    Authored-by: Rui Wang <ru...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 3858bf0fbd02e3d8fd18e967f3841c50b9294414)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../org/apache/spark/unsafe/types/UTF8String.java  | 21 +++++-
 .../sql/catalyst/analysis/FunctionRegistry.scala   |  1 +
 .../expressions/collectionOperations.scala         | 22 ++++--
 .../catalyst/expressions/stringExpressions.scala   | 75 ++++++++++++++++++-
 .../sql-functions/sql-expression-schema.md         |  3 +-
 .../sql-tests/inputs/string-functions.sql          | 12 ++++
 .../results/ansi/string-functions.sql.out          | 83 +++++++++++++++++++++-
 .../sql-tests/results/string-functions.sql.out     | 83 +++++++++++++++++++++-
 8 files changed, 291 insertions(+), 9 deletions(-)

diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 98c61cf..0f9d653 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.Map;
+import java.util.regex.Pattern;
 
 import com.esotericsoftware.kryo.Kryo;
 import com.esotericsoftware.kryo.KryoSerializable;
@@ -999,13 +1000,31 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
   }
 
   public UTF8String[] split(UTF8String pattern, int limit) {
+    return split(pattern.toString(), limit);
+  }
+
+  public UTF8String[] splitSQL(UTF8String delimiter, int limit) {
+    // If the delimiter is an empty string, skip the regex-based splitting entirely:
+    // a regex treats an empty string as matching everywhere, so use the input as-is.
+    if (delimiter.numBytes() == 0) {
+      return new UTF8String[]{this};
+    } else {
+      // We do not treat the delimiter as a regex; the whole delimiter string is the
+      // separator to split on. Java String's split, however, only accepts a regex
+      // as the pattern, so we quote the delimiter to escape any special characters
+      // it contains.
+      return split(Pattern.quote(delimiter.toString()), limit);
+    }
+  }
+
+  private UTF8String[] split(String delimiter, int limit) {
     // Java String's split method supports "ignore empty string" behavior when the limit is 0
     // whereas other languages do not. To avoid this java specific behavior, we fall back to
     // -1 when the limit is 0.
     if (limit == 0) {
       limit = -1;
     }
-    String[] splits = toString().split(pattern.toString(), limit);
+    String[] splits = toString().split(delimiter, limit);
     UTF8String[] res = new UTF8String[splits.length];
     for (int i = 0; i < res.length; i++) {
       res[i] = fromString(splits[i]);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a37d4b2..a06112a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -549,6 +549,7 @@ object FunctionRegistry {
     expression[SoundEx]("soundex"),
     expression[StringSpace]("space"),
     expression[StringSplit]("split"),
+    expression[SplitPart]("split_part"),
     expression[Substring]("substr", true),
     expression[Substring]("substring"),
     expression[Left]("left"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 363c531..ca00839 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2095,10 +2095,12 @@ case class ArrayPosition(left: Expression, right: Expression)
 case class ElementAt(
     left: Expression,
     right: Expression,
+    // The value to return if index is out of bound
+    defaultValueOutOfBound: Option[Literal] = None,
     failOnError: Boolean = SQLConf.get.ansiEnabled)
   extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant {
 
-  def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled)
+  def this(left: Expression, right: Expression) = this(left, right, None, SQLConf.get.ansiEnabled)
 
   @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
 
@@ -2179,7 +2181,10 @@ case class ElementAt(
           if (failOnError) {
             throw QueryExecutionErrors.invalidElementAtIndexError(index, array.numElements())
           } else {
-            null
+            defaultValueOutOfBound match {
+              case Some(value) => value.eval()
+              case None => null
+            }
           }
         } else {
           val idx = if (index == 0) {
@@ -2218,7 +2223,16 @@ case class ElementAt(
           val indexOutOfBoundBranch = if (failOnError) {
             s"throw QueryExecutionErrors.invalidElementAtIndexError($index, $eval1.numElements());"
           } else {
-            s"${ev.isNull} = true;"
+            defaultValueOutOfBound match {
+              case Some(value) =>
+                val defaultValueEval = value.genCode(ctx)
+                s"""
+                  ${defaultValueEval.code}
+                  ${ev.isNull} = ${defaultValueEval.isNull};
+                  ${ev.value} = ${defaultValueEval.value};
+                """.stripMargin
+              case None => s"${ev.isNull} = true;"
+            }
           }
 
           s"""
@@ -2278,7 +2292,7 @@ case class ElementAt(
 case class TryElementAt(left: Expression, right: Expression, replacement: Expression)
   extends RuntimeReplaceable with InheritAnalysisRules {
   def this(left: Expression, right: Expression) =
-    this(left, right, ElementAt(left, right, failOnError = false))
+    this(left, right, ElementAt(left, right, None, failOnError = false))
 
   override def prettyName: String = "try_element_at"
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index fc73216..a08ab84 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.{StringType, _}
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -2943,3 +2943,76 @@ case class Sentences(
     copy(str = newFirst, language = newSecond, country = newThird)
 
 }
+
+/**
+ * Splits a given string by a specified delimiter and returns the splits in a
+ * GenericArrayData. This expression differs from the `split` function: `split`
+ * takes a regex as the pattern to split on, whereas this expression takes a
+ * plain delimiter string whose characters carry no special meaning, so it is
+ * not treated as a regex.
+ */
+case class StringSplitSQL(
+    str: Expression,
+    delimiter: Expression) extends BinaryExpression with NullIntolerant {
+  override def dataType: DataType = ArrayType(StringType, containsNull = false)
+  override def left: Expression = str
+  override def right: Expression = delimiter
+
+  override def nullSafeEval(string: Any, delimiter: Any): Any = {
+    val strings = string.asInstanceOf[UTF8String].splitSQL(
+      delimiter.asInstanceOf[UTF8String], -1)
+    new GenericArrayData(strings.asInstanceOf[Array[Any]])
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val arrayClass = classOf[GenericArrayData].getName
+    nullSafeCodeGen(ctx, ev, (str, delimiter) => {
+      // Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
+      s"${ev.value} = new $arrayClass($str.splitSQL($delimiter, -1));"
+    })
+  }
+
+  override def withNewChildrenInternal(
+      newFirst: Expression, newSecond: Expression): StringSplitSQL =
+    copy(str = newFirst, delimiter = newSecond)
+}
+
+/**
+ * Splits a given string by a specified delimiter and returns the requested part.
+ * If any input is null, returns null.
+ * If the index is out of range of the split parts, returns an empty string.
+ * If the index is 0, throws an ArrayIndexOutOfBoundsException.
+ */
+@ExpressionDescription(
+  usage =
+    """
+    _FUNC_(str, delimiter, partNum) - Splits `str` by `delimiter` and returns the
+      requested part of the split (1-based). If any input is null, returns null.
+      If `partNum` is out of range of the split parts, returns an empty string. If `partNum` is 0,
+      throws an error. If `partNum` is negative, the parts are counted backward from the
+      end of the string. If `delimiter` is an empty string, `str` is not split.
+  """,
+  examples =
+    """
+    Examples:
+      > SELECT _FUNC_('11.12.13', '.', 3);
+       13
+  """,
+  since = "3.3.0",
+  group = "string_funcs")
+case class SplitPart(
+    str: Expression,
+    delimiter: Expression,
+    partNum: Expression)
+  extends RuntimeReplaceable with ImplicitCastInputTypes {
+  override lazy val replacement: Expression =
+    ElementAt(StringSplitSQL(str, delimiter), partNum, Some(Literal.create("", StringType)),
+      false)
+  override def nodeName: String = "split_part"
+  override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
+  def children: Seq[Expression] = Seq(str, delimiter, partNum)
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
+    copy(str = newChildren.apply(0), delimiter = newChildren.apply(1),
+      partNum = newChildren.apply(2))
+  }
+}
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 1afba46..166c761 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -1,6 +1,6 @@
 <!-- Automatically generated by ExpressionsSchemaSuite -->
 ## Summary
-  - Number of queries: 383
+  - Number of queries: 384
   - Number of expressions that missing example: 12
   - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint
 ## Schema of Built-in Functions
@@ -275,6 +275,7 @@
 | org.apache.spark.sql.catalyst.expressions.SoundEx | soundex | SELECT soundex('Miller') | struct<soundex(Miller):string> |
 | org.apache.spark.sql.catalyst.expressions.SparkPartitionID | spark_partition_id | SELECT spark_partition_id() | struct<SPARK_PARTITION_ID():int> |
 | org.apache.spark.sql.catalyst.expressions.SparkVersion | version | SELECT version() | struct<version():string> |
+| org.apache.spark.sql.catalyst.expressions.SplitPart | split_part | SELECT split_part('11.12.13', '.', 3) | struct<split_part(11.12.13, ., 3):string> |
 | org.apache.spark.sql.catalyst.expressions.Sqrt | sqrt | SELECT sqrt(4) | struct<SQRT(4):double> |
 | org.apache.spark.sql.catalyst.expressions.Stack | stack | SELECT stack(2, 1, 2, 3) | struct<col0:int,col1:int> |
 | org.apache.spark.sql.catalyst.expressions.StartsWithExpressionBuilder$ | startswith | SELECT startswith('Spark SQL', 'Spark') | struct<startswith(Spark SQL, Spark):boolean> |
diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
index e7c01a6..7d22e79 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
@@ -27,6 +27,18 @@ select right("abcd", -2), right("abcd", 0), right("abcd", 'a');
 SELECT split('aa1cc2ee3', '[1-9]+');
 SELECT split('aa1cc2ee3', '[1-9]+', 2);
 
+-- split_part function
+SELECT split_part('11.12.13', '.', 2);
+SELECT split_part('11.12.13', '.', -1);
+SELECT split_part('11.12.13', '.', -3);
+SELECT split_part('11.12.13', '', 1);
+SELECT split_part('11ab12ab13', 'ab', 1);
+SELECT split_part('11.12.13', '.', 0);
+SELECT split_part('11.12.13', '.', 4);
+SELECT split_part('11.12.13', '.', 5);
+SELECT split_part('11.12.13', '.', -5);
+SELECT split_part(null, '.', 1);
+
 -- substring function
 SELECT substr('Spark SQL', 5);
 SELECT substr('Spark SQL', -3);
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
index b182b5c..01213bd 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 131
+-- Number of queries: 141
 
 
 -- !query
@@ -127,6 +127,87 @@ struct<split(aa1cc2ee3, [1-9]+, 2):array<string>>
 
 
 -- !query
+SELECT split_part('11.12.13', '.', 2)
+-- !query schema
+struct<split_part(11.12.13, ., 2):string>
+-- !query output
+12
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -1)
+-- !query schema
+struct<split_part(11.12.13, ., -1):string>
+-- !query output
+13
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -3)
+-- !query schema
+struct<split_part(11.12.13, ., -3):string>
+-- !query output
+11
+
+
+-- !query
+SELECT split_part('11.12.13', '', 1)
+-- !query schema
+struct<split_part(11.12.13, , 1):string>
+-- !query output
+11.12.13
+
+
+-- !query
+SELECT split_part('11ab12ab13', 'ab', 1)
+-- !query schema
+struct<split_part(11ab12ab13, ab, 1):string>
+-- !query output
+11
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 0)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+SQL array indices start at 1
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 4)
+-- !query schema
+struct<split_part(11.12.13, ., 4):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 5)
+-- !query schema
+struct<split_part(11.12.13, ., 5):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -5)
+-- !query schema
+struct<split_part(11.12.13, ., -5):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part(null, '.', 1)
+-- !query schema
+struct<split_part(NULL, ., 1):string>
+-- !query output
+NULL
+
+
+-- !query
 SELECT substr('Spark SQL', 5)
 -- !query schema
 struct<substr(Spark SQL, 5, 2147483647):string>
diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
index 4307df7..3a7f197 100644
--- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 131
+-- Number of queries: 141
 
 
 -- !query
@@ -125,6 +125,87 @@ struct<split(aa1cc2ee3, [1-9]+, 2):array<string>>
 
 
 -- !query
+SELECT split_part('11.12.13', '.', 2)
+-- !query schema
+struct<split_part(11.12.13, ., 2):string>
+-- !query output
+12
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -1)
+-- !query schema
+struct<split_part(11.12.13, ., -1):string>
+-- !query output
+13
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -3)
+-- !query schema
+struct<split_part(11.12.13, ., -3):string>
+-- !query output
+11
+
+
+-- !query
+SELECT split_part('11.12.13', '', 1)
+-- !query schema
+struct<split_part(11.12.13, , 1):string>
+-- !query output
+11.12.13
+
+
+-- !query
+SELECT split_part('11ab12ab13', 'ab', 1)
+-- !query schema
+struct<split_part(11ab12ab13, ab, 1):string>
+-- !query output
+11
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 0)
+-- !query schema
+struct<>
+-- !query output
+java.lang.ArrayIndexOutOfBoundsException
+SQL array indices start at 1
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 4)
+-- !query schema
+struct<split_part(11.12.13, ., 4):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part('11.12.13', '.', 5)
+-- !query schema
+struct<split_part(11.12.13, ., 5):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part('11.12.13', '.', -5)
+-- !query schema
+struct<split_part(11.12.13, ., -5):string>
+-- !query output
+
+
+
+-- !query
+SELECT split_part(null, '.', 1)
+-- !query schema
+struct<split_part(NULL, ., 1):string>
+-- !query output
+NULL
+
+
+-- !query
 SELECT substr('Spark SQL', 5)
 -- !query schema
 struct<substr(Spark SQL, 5, 2147483647):string>
