You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/03/04 13:03:27 UTC
[spark] branch master updated: [SPARK-38345][SQL] Introduce SQL function ARRAY_SIZE
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5039c0f [SPARK-38345][SQL] Introduce SQL function ARRAY_SIZE
5039c0f is described below
commit 5039c0f34f98c2c9937f9ad3576fb18d8e9cba34
Author: Xinrong Meng <xi...@databricks.com>
AuthorDate: Fri Mar 4 21:01:49 2022 +0800
[SPARK-38345][SQL] Introduce SQL function ARRAY_SIZE
### What changes were proposed in this pull request?
Introduce SQL function ARRAY_SIZE.
ARRAY_SIZE works the same as SIZE when the input is an array except for:
- ARRAY_SIZE raises an exception for non-array input.
- ARRAY_SIZE always returns null for null input.
### Why are the changes needed?
Counting elements within an array is a common use case. ARRAY_SIZE ensures the input is an array and then returns its size.
Other DBMSs like Snowflake support that as well: [Snowflake ARRAY_SIZE](https://docs.snowflake.com/en/sql-reference/functions/array_size.html). Implementing it improves compatibility with those DBMSs and makes migration easier.
### Does this PR introduce _any_ user-facing change?
Yes. `array_size` is available now.
```
scala> spark.sql("select array_size(array(2, 1))").show()
+-----------------------+
|array_size(array(2, 1))|
+-----------------------+
| 2|
+-----------------------+
scala> spark.sql("select array_size(map('a', 1, 'b', 2))").show()
org.apache.spark.sql.AnalysisException: cannot resolve 'array_size(map('a', 1, 'b', 2))' due to data type mismatch: argument 1 requires array type, however, 'map('a', 1, 'b', 2)' is of map<string,int> type.; line 1 pos 7;
'Project [unresolvedalias(array_size(map(a, 1, b, 2), None), None)]
```
### How was this patch tested?
Unit tests.
Closes #35671 from xinrong-databricks/array_size.
Authored-by: Xinrong Meng <xi...@databricks.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../sql/catalyst/analysis/FunctionRegistry.scala | 1 +
.../expressions/collectionOperations.scala | 27 +++++++++++++-
.../sql-functions/sql-expression-schema.md | 3 +-
.../src/test/resources/sql-tests/inputs/array.sql | 7 ++++
.../resources/sql-tests/results/ansi/array.sql.out | 43 +++++++++++++++++++++-
.../test/resources/sql-tests/results/array.sql.out | 43 +++++++++++++++++++++-
6 files changed, 120 insertions(+), 4 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index bc7eb09..e01457c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -640,6 +640,7 @@ object FunctionRegistry {
expression[ArrayIntersect]("array_intersect"),
expression[ArrayJoin]("array_join"),
expression[ArrayPosition]("array_position"),
+ expression[ArraySize]("array_size"),
expression[ArraySort]("array_sort"),
expression[ArrayExcept]("array_except"),
expression[ArrayUnion]("array_union"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index e53fc5e..363c531 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.trees.BinaryLike
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -133,6 +133,31 @@ object Size {
def apply(child: Expression): Size = new Size(child)
}
+
+/**
+ * Given an array, returns total number of elements in it.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Returns the size of an array. The function returns null for null input.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array('b', 'd', 'c', 'a'));
+ 4
+ """,
+ since = "3.3.0",
+ group = "collection_funcs")
+case class ArraySize(child: Expression)
+ extends RuntimeReplaceable with ImplicitCastInputTypes with UnaryLike[Expression] {
+
+ override lazy val replacement: Expression = Size(child, legacySizeOfNull = false)
+
+ override def prettyName: String = "array_size"
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
+
+ protected def withNewChildInternal(newChild: Expression): ArraySize = copy(child = newChild)
+}
+
/**
* Returns an unordered array containing the keys of the map.
*/
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 88ad1ae..052e88e 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -1,6 +1,6 @@
<!-- Automatically generated by ExpressionsSchemaSuite -->
## Summary
- - Number of queries: 383
+ - Number of queries: 384
- Number of expressions that missing example: 12
- Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint
## Schema of Built-in Functions
@@ -28,6 +28,7 @@
| org.apache.spark.sql.catalyst.expressions.ArrayPosition | array_position | SELECT array_position(array(3, 2, 1), 1) | struct<array_position(array(3, 2, 1), 1):bigint> |
| org.apache.spark.sql.catalyst.expressions.ArrayRemove | array_remove | SELECT array_remove(array(1, 2, 3, null, 3), 3) | struct<array_remove(array(1, 2, 3, NULL, 3), 3):array<int>> |
| org.apache.spark.sql.catalyst.expressions.ArrayRepeat | array_repeat | SELECT array_repeat('123', 2) | struct<array_repeat(123, 2):array<string>> |
+| org.apache.spark.sql.catalyst.expressions.ArraySize | array_size | SELECT array_size(array('b', 'd', 'c', 'a')) | struct<array_size(array(b, d, c, a)):int> |
| org.apache.spark.sql.catalyst.expressions.ArraySort | array_sort | SELECT array_sort(array(5, 6, 1), (left, right) -> case when left < right then -1 when left > right then 1 else 0 end) | struct<array_sort(array(5, 6, 1), lambdafunction(CASE WHEN (namedlambdavariable() < namedlambdavariable()) THEN -1 WHEN (namedlambdavariable() > namedlambdavariable()) THEN 1 ELSE 0 END, namedlambdavariable(), namedlambdavariable())):array<int>> |
| org.apache.spark.sql.catalyst.expressions.ArrayTransform | transform | SELECT transform(array(1, 2, 3), x -> x + 1) | struct<transform(array(1, 2, 3), lambdafunction((namedlambdavariable() + 1), namedlambdavariable())):array<int>> |
| org.apache.spark.sql.catalyst.expressions.ArrayUnion | array_union | SELECT array_union(array(1, 2, 3), array(1, 3, 5)) | struct<array_union(array(1, 2, 3), array(1, 3, 5)):array<int>> |
diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql
index 0223ce5..dfcf174 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/array.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql
@@ -106,3 +106,10 @@ select elt(2, '123', null);
select array(1, 2, 3)[5];
select array(1, 2, 3)[-1];
+
+-- array_size
+select array_size(array());
+select array_size(array(true));
+select array_size(array(2, 1));
+select array_size(NULL);
+select array_size(map('a', 1, 'b', 2));
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
index f2b3552..00ac2ee 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/array.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 33
+-- Number of queries: 38
-- !query
@@ -267,6 +267,47 @@ Invalid index: -1, numElements: 3. If necessary set spark.sql.ansi.strictIndexOp
-- !query
+select array_size(array())
+-- !query schema
+struct<array_size(array()):int>
+-- !query output
+0
+
+
+-- !query
+select array_size(array(true))
+-- !query schema
+struct<array_size(array(true)):int>
+-- !query output
+1
+
+
+-- !query
+select array_size(array(2, 1))
+-- !query schema
+struct<array_size(array(2, 1)):int>
+-- !query output
+2
+
+
+-- !query
+select array_size(NULL)
+-- !query schema
+struct<array_size(NULL):int>
+-- !query output
+NULL
+
+
+-- !query
+select array_size(map('a', 1, 'b', 2))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'array_size(map('a', 1, 'b', 2))' due to data type mismatch: argument 1 requires array type, however, 'map('a', 1, 'b', 2)' is of map<string,int> type.; line 1 pos 7
+
+
+-- !query
set spark.sql.ansi.strictIndexOperator=false
-- !query schema
struct<key:string,value:string>
diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out
index 9d42b8a..1ff2a17 100644
--- a/sql/core/src/test/resources/sql-tests/results/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 24
+-- Number of queries: 29
-- !query
@@ -257,3 +257,44 @@ select array(1, 2, 3)[-1]
struct<array(1, 2, 3)[-1]:int>
-- !query output
NULL
+
+
+-- !query
+select array_size(array())
+-- !query schema
+struct<array_size(array()):int>
+-- !query output
+0
+
+
+-- !query
+select array_size(array(true))
+-- !query schema
+struct<array_size(array(true)):int>
+-- !query output
+1
+
+
+-- !query
+select array_size(array(2, 1))
+-- !query schema
+struct<array_size(array(2, 1)):int>
+-- !query output
+2
+
+
+-- !query
+select array_size(NULL)
+-- !query schema
+struct<array_size(NULL):int>
+-- !query output
+NULL
+
+
+-- !query
+select array_size(map('a', 1, 'b', 2))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'array_size(map('a', 1, 'b', 2))' due to data type mismatch: argument 1 requires array type, however, 'map('a', 1, 'b', 2)' is of map<string,int> type.; line 1 pos 7
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org