You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/03/04 13:03:27 UTC

[spark] branch master updated: [SPARK-38345][SQL] Introduce SQL function ARRAY_SIZE

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5039c0f  [SPARK-38345][SQL] Introduce SQL function ARRAY_SIZE
5039c0f is described below

commit 5039c0f34f98c2c9937f9ad3576fb18d8e9cba34
Author: Xinrong Meng <xi...@databricks.com>
AuthorDate: Fri Mar 4 21:01:49 2022 +0800

    [SPARK-38345][SQL] Introduce SQL function ARRAY_SIZE
    
    ### What changes were proposed in this pull request?
    Introduce the SQL function ARRAY_SIZE.
    
    ARRAY_SIZE behaves the same as SIZE for array input, except that:
    - ARRAY_SIZE raises an exception for non-array input.
    - ARRAY_SIZE always returns null for null input.
    
    ### Why are the changes needed?
    Counting the elements of an array is a common use case. ARRAY_SIZE ensures that the input is an array and then returns its size.
    
    Other DBMSs, such as Snowflake, support this function as well: [Snowflake ARRAY_SIZE](https://docs.snowflake.com/en/sql-reference/functions/array_size.html). Implementing it improves compatibility with those DBMSs and makes migration easier.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `array_size` is now available.
    
    ```
    scala> spark.sql("select array_size(array(2, 1))").show()
    +-----------------------+
    |array_size(array(2, 1))|
    +-----------------------+
    |                      2|
    +-----------------------+
    
    scala> spark.sql("select array_size(map('a', 1, 'b', 2))").show()
    org.apache.spark.sql.AnalysisException: cannot resolve 'array_size(map('a', 1, 'b', 2))' due to data type mismatch: argument 1 requires array type, however, 'map('a', 1, 'b', 2)' is of map<string,int> type.; line 1 pos 7;
    'Project [unresolvedalias(array_size(map(a, 1, b, 2), None), None)]
    ```
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #35671 from xinrong-databricks/array_size.
    
    Authored-by: Xinrong Meng <xi...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/catalyst/analysis/FunctionRegistry.scala   |  1 +
 .../expressions/collectionOperations.scala         | 27 +++++++++++++-
 .../sql-functions/sql-expression-schema.md         |  3 +-
 .../src/test/resources/sql-tests/inputs/array.sql  |  7 ++++
 .../resources/sql-tests/results/ansi/array.sql.out | 43 +++++++++++++++++++++-
 .../test/resources/sql-tests/results/array.sql.out | 43 +++++++++++++++++++++-
 6 files changed, 120 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index bc7eb09..e01457c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -640,6 +640,7 @@ object FunctionRegistry {
     expression[ArrayIntersect]("array_intersect"),
     expression[ArrayJoin]("array_join"),
     expression[ArrayPosition]("array_position"),
+    expression[ArraySize]("array_size"),
     expression[ArraySort]("array_sort"),
     expression[ArrayExcept]("array_except"),
     expression[ArrayUnion]("array_union"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index e53fc5e..363c531 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un
 import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.trees.BinaryLike
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
 import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -133,6 +133,31 @@ object Size {
   def apply(child: Expression): Size = new Size(child)
 }
 
+
+/** Returns the number of elements in the given array, or null if the array itself is null.
+ * Input that cannot be implicitly cast to an array fails analysis (see `inputTypes`).
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Returns the size of an array. The function returns null for null input.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array('b', 'd', 'c', 'a'));
+       4
+  """,
+  since = "3.3.0",
+  group = "collection_funcs")
+case class ArraySize(child: Expression)
+  extends RuntimeReplaceable with ImplicitCastInputTypes with UnaryLike[Expression] {
+
+  override lazy val replacement: Expression = Size(child, legacySizeOfNull = false)
+
+  override def prettyName: String = "array_size"
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
+
+  protected def withNewChildInternal(newChild: Expression): ArraySize = copy(child = newChild)
+}
+
 /**
  * Returns an unordered array containing the keys of the map.
  */
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 88ad1ae..052e88e 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -1,6 +1,6 @@
 <!-- Automatically generated by ExpressionsSchemaSuite -->
 ## Summary
-  - Number of queries: 383
+  - Number of queries: 384
   - Number of expressions that missing example: 12
   - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint
 ## Schema of Built-in Functions
@@ -28,6 +28,7 @@
 | org.apache.spark.sql.catalyst.expressions.ArrayPosition | array_position | SELECT array_position(array(3, 2, 1), 1) | struct<array_position(array(3, 2, 1), 1):bigint> |
 | org.apache.spark.sql.catalyst.expressions.ArrayRemove | array_remove | SELECT array_remove(array(1, 2, 3, null, 3), 3) | struct<array_remove(array(1, 2, 3, NULL, 3), 3):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.ArrayRepeat | array_repeat | SELECT array_repeat('123', 2) | struct<array_repeat(123, 2):array<string>> |
+| org.apache.spark.sql.catalyst.expressions.ArraySize | array_size | SELECT array_size(array('b', 'd', 'c', 'a')) | struct<array_size(array(b, d, c, a)):int> |
 | org.apache.spark.sql.catalyst.expressions.ArraySort | array_sort | SELECT array_sort(array(5, 6, 1), (left, right) -> case when left < right then -1 when left > right then 1 else 0 end) | struct<array_sort(array(5, 6, 1), lambdafunction(CASE WHEN (namedlambdavariable() < namedlambdavariable()) THEN -1 WHEN (namedlambdavariable() > namedlambdavariable()) THEN 1 ELSE 0 END, namedlambdavariable(), namedlambdavariable())):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.ArrayTransform | transform | SELECT transform(array(1, 2, 3), x -> x + 1) | struct<transform(array(1, 2, 3), lambdafunction((namedlambdavariable() + 1), namedlambdavariable())):array<int>> |
 | org.apache.spark.sql.catalyst.expressions.ArrayUnion | array_union | SELECT array_union(array(1, 2, 3), array(1, 3, 5)) | struct<array_union(array(1, 2, 3), array(1, 3, 5)):array<int>> |
diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql
index 0223ce5..dfcf174 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/array.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql
@@ -106,3 +106,10 @@ select elt(2, '123', null);
 
 select array(1, 2, 3)[5];
 select array(1, 2, 3)[-1];
+
+-- array_size
+select array_size(array());
+select array_size(array(true));
+select array_size(array(2, 1));
+select array_size(NULL);
+select array_size(map('a', 1, 'b', 2));
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
index f2b3552..00ac2ee 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 33
+-- Number of queries: 38
 
 
 -- !query
@@ -267,6 +267,47 @@ Invalid index: -1, numElements: 3. If necessary set spark.sql.ansi.strictIndexOp
 
 
 -- !query
+select array_size(array())
+-- !query schema
+struct<array_size(array()):int>
+-- !query output
+0
+
+
+-- !query
+select array_size(array(true))
+-- !query schema
+struct<array_size(array(true)):int>
+-- !query output
+1
+
+
+-- !query
+select array_size(array(2, 1))
+-- !query schema
+struct<array_size(array(2, 1)):int>
+-- !query output
+2
+
+
+-- !query
+select array_size(NULL)
+-- !query schema
+struct<array_size(NULL):int>
+-- !query output
+NULL
+
+
+-- !query
+select array_size(map('a', 1, 'b', 2))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'array_size(map('a', 1, 'b', 2))' due to data type mismatch: argument 1 requires array type, however, 'map('a', 1, 'b', 2)' is of map<string,int> type.; line 1 pos 7
+
+
+-- !query
 set spark.sql.ansi.strictIndexOperator=false
 -- !query schema
 struct<key:string,value:string>
diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out
index 9d42b8a..1ff2a17 100644
--- a/sql/core/src/test/resources/sql-tests/results/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 24
+-- Number of queries: 29
 
 
 -- !query
@@ -257,3 +257,44 @@ select array(1, 2, 3)[-1]
 struct<array(1, 2, 3)[-1]:int>
 -- !query output
 NULL
+
+
+-- !query
+select array_size(array())
+-- !query schema
+struct<array_size(array()):int>
+-- !query output
+0
+
+
+-- !query
+select array_size(array(true))
+-- !query schema
+struct<array_size(array(true)):int>
+-- !query output
+1
+
+
+-- !query
+select array_size(array(2, 1))
+-- !query schema
+struct<array_size(array(2, 1)):int>
+-- !query output
+2
+
+
+-- !query
+select array_size(NULL)
+-- !query schema
+struct<array_size(NULL):int>
+-- !query output
+NULL
+
+
+-- !query
+select array_size(map('a', 1, 'b', 2))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'array_size(map('a', 1, 'b', 2))' due to data type mismatch: argument 1 requires array type, however, 'map('a', 1, 'b', 2)' is of map<string,int> type.; line 1 pos 7

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org