You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/11/07 06:45:00 UTC
[2/2] spark git commit: [SPARK-18296][SQL] Use consistent naming for
expression test suites
[SPARK-18296][SQL] Use consistent naming for expression test suites
## What changes were proposed in this pull request?
We have an undocumented naming convention: expression unit tests are named *ExpressionsSuite, and end-to-end tests are named *FunctionsSuite. It'd be great to make all test suites consistent with this naming convention.
## How was this patch tested?
This is a test-only naming change.
Author: Reynold Xin <rx...@databricks.com>
Closes #15793 from rxin/SPARK-18296.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9db06c44
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9db06c44
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9db06c44
Branch: refs/heads/master
Commit: 9db06c442cf85e41d51c7b167817f4e7971bf0da
Parents: 07ac3f0
Author: Reynold Xin <rx...@databricks.com>
Authored: Sun Nov 6 22:44:55 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sun Nov 6 22:44:55 2016 -0800
----------------------------------------------------------------------
.../expressions/BitwiseExpressionsSuite.scala | 134 +++++
.../expressions/BitwiseFunctionsSuite.scala | 134 -----
.../CollectionExpressionsSuite.scala | 108 ++++
.../expressions/CollectionFunctionsSuite.scala | 109 ----
.../expressions/MathExpressionsSuite.scala | 582 +++++++++++++++++++
.../expressions/MathFunctionsSuite.scala | 582 -------------------
.../expressions/MiscExpressionsSuite.scala | 42 ++
.../expressions/MiscFunctionsSuite.scala | 42 --
.../expressions/NullExpressionsSuite.scala | 136 +++++
.../expressions/NullFunctionsSuite.scala | 136 -----
.../apache/spark/sql/MathExpressionsSuite.scala | 424 --------------
.../apache/spark/sql/MathFunctionsSuite.scala | 424 ++++++++++++++
12 files changed, 1426 insertions(+), 1427 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/9db06c44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala
new file mode 100644
index 0000000..4188dad
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+
+
+class BitwiseExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+ // Unit tests for BitwiseNot/And/Or/Xor: literal inputs, null inputs, codegen consistency.
+ import IntegralLiteralTestUtils._
+
+ test("BitwiseNOT") {
+ def check(input: Any, expected: Any): Unit = {
+ val expr = BitwiseNot(Literal(input))
+ assert(expr.dataType === Literal(input).dataType)
+ checkEvaluation(expr, expected)
+ }
+
+ // Scala widens Byte/Short bitwise results to Int, so re-narrow the expected values.
+ check(1.toByte, (~1.toByte).toByte)
+ check(1000.toShort, (~1000.toShort).toShort)
+ check(1000000, ~1000000)
+ check(123456789123L, ~123456789123L)
+
+ checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null)
+ checkEvaluation(BitwiseNot(positiveShortLit), (~positiveShort).toShort)
+ checkEvaluation(BitwiseNot(negativeShortLit), (~negativeShort).toShort)
+ checkEvaluation(BitwiseNot(positiveIntLit), ~positiveInt)
+ checkEvaluation(BitwiseNot(negativeIntLit), ~negativeInt)
+ checkEvaluation(BitwiseNot(positiveLongLit), ~positiveLong)
+ checkEvaluation(BitwiseNot(negativeLongLit), ~negativeLong)
+
+ DataTypeTestUtils.integralType.foreach { dt =>
+ checkConsistencyBetweenInterpretedAndCodegen(BitwiseNot, dt)
+ }
+ }
+
+ test("BitwiseAnd") {
+ def check(input1: Any, input2: Any, expected: Any): Unit = {
+ val expr = BitwiseAnd(Literal(input1), Literal(input2))
+ assert(expr.dataType === Literal(input1).dataType)
+ checkEvaluation(expr, expected)
+ }
+
+ // Scala widens Byte/Short bitwise results to Int, so re-narrow the expected values.
+ check(1.toByte, 2.toByte, (1.toByte & 2.toByte).toByte)
+ check(1000.toShort, 2.toShort, (1000.toShort & 2.toShort).toShort)
+ check(1000000, 4, 1000000 & 4)
+ check(123456789123L, 5L, 123456789123L & 5L)
+
+ val nullLit = Literal.create(null, IntegerType)
+ checkEvaluation(BitwiseAnd(nullLit, Literal(1)), null)
+ checkEvaluation(BitwiseAnd(Literal(1), nullLit), null)
+ checkEvaluation(BitwiseAnd(nullLit, nullLit), null)
+ checkEvaluation(BitwiseAnd(positiveShortLit, negativeShortLit),
+ (positiveShort & negativeShort).toShort)
+ checkEvaluation(BitwiseAnd(positiveIntLit, negativeIntLit), positiveInt & negativeInt)
+ checkEvaluation(BitwiseAnd(positiveLongLit, negativeLongLit), positiveLong & negativeLong)
+
+ DataTypeTestUtils.integralType.foreach { dt =>
+ checkConsistencyBetweenInterpretedAndCodegen(BitwiseAnd, dt, dt)
+ }
+ }
+
+ test("BitwiseOr") {
+ def check(input1: Any, input2: Any, expected: Any): Unit = {
+ val expr = BitwiseOr(Literal(input1), Literal(input2))
+ assert(expr.dataType === Literal(input1).dataType)
+ checkEvaluation(expr, expected)
+ }
+
+ // Scala widens Byte/Short bitwise results to Int, so re-narrow the expected values.
+ check(1.toByte, 2.toByte, (1.toByte | 2.toByte).toByte)
+ check(1000.toShort, 2.toShort, (1000.toShort | 2.toShort).toShort)
+ check(1000000, 4, 1000000 | 4)
+ check(123456789123L, 5L, 123456789123L | 5L)
+
+ val nullLit = Literal.create(null, IntegerType)
+ checkEvaluation(BitwiseOr(nullLit, Literal(1)), null)
+ checkEvaluation(BitwiseOr(Literal(1), nullLit), null)
+ checkEvaluation(BitwiseOr(nullLit, nullLit), null)
+ checkEvaluation(BitwiseOr(positiveShortLit, negativeShortLit),
+ (positiveShort | negativeShort).toShort)
+ checkEvaluation(BitwiseOr(positiveIntLit, negativeIntLit), positiveInt | negativeInt)
+ checkEvaluation(BitwiseOr(positiveLongLit, negativeLongLit), positiveLong | negativeLong)
+
+ DataTypeTestUtils.integralType.foreach { dt =>
+ checkConsistencyBetweenInterpretedAndCodegen(BitwiseOr, dt, dt)
+ }
+ }
+
+ test("BitwiseXor") {
+ def check(input1: Any, input2: Any, expected: Any): Unit = {
+ val expr = BitwiseXor(Literal(input1), Literal(input2))
+ assert(expr.dataType === Literal(input1).dataType)
+ checkEvaluation(expr, expected)
+ }
+
+ // Scala widens Byte/Short bitwise results to Int, so re-narrow the expected values.
+ check(1.toByte, 2.toByte, (1.toByte ^ 2.toByte).toByte)
+ check(1000.toShort, 2.toShort, (1000.toShort ^ 2.toShort).toShort)
+ check(1000000, 4, 1000000 ^ 4)
+ check(123456789123L, 5L, 123456789123L ^ 5L)
+
+ val nullLit = Literal.create(null, IntegerType)
+ checkEvaluation(BitwiseXor(nullLit, Literal(1)), null)
+ checkEvaluation(BitwiseXor(Literal(1), nullLit), null)
+ checkEvaluation(BitwiseXor(nullLit, nullLit), null)
+ checkEvaluation(BitwiseXor(positiveShortLit, negativeShortLit),
+ (positiveShort ^ negativeShort).toShort)
+ checkEvaluation(BitwiseXor(positiveIntLit, negativeIntLit), positiveInt ^ negativeInt)
+ checkEvaluation(BitwiseXor(positiveLongLit, negativeLongLit), positiveLong ^ negativeLong)
+
+ DataTypeTestUtils.integralType.foreach { dt =>
+ checkConsistencyBetweenInterpretedAndCodegen(BitwiseXor, dt, dt)
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9db06c44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala
deleted file mode 100644
index 3a310c0..0000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala
+++ /dev/null
@@ -1,134 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types._
-
-
-class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
-
- import IntegralLiteralTestUtils._
-
- test("BitwiseNOT") {
- def check(input: Any, expected: Any): Unit = {
- val expr = BitwiseNot(Literal(input))
- assert(expr.dataType === Literal(input).dataType)
- checkEvaluation(expr, expected)
- }
-
- // Need the extra toByte even though IntelliJ thought it's not needed.
- check(1.toByte, (~1.toByte).toByte)
- check(1000.toShort, (~1000.toShort).toShort)
- check(1000000, ~1000000)
- check(123456789123L, ~123456789123L)
-
- checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null)
- checkEvaluation(BitwiseNot(positiveShortLit), (~positiveShort).toShort)
- checkEvaluation(BitwiseNot(negativeShortLit), (~negativeShort).toShort)
- checkEvaluation(BitwiseNot(positiveIntLit), ~positiveInt)
- checkEvaluation(BitwiseNot(negativeIntLit), ~negativeInt)
- checkEvaluation(BitwiseNot(positiveLongLit), ~positiveLong)
- checkEvaluation(BitwiseNot(negativeLongLit), ~negativeLong)
-
- DataTypeTestUtils.integralType.foreach { dt =>
- checkConsistencyBetweenInterpretedAndCodegen(BitwiseNot, dt)
- }
- }
-
- test("BitwiseAnd") {
- def check(input1: Any, input2: Any, expected: Any): Unit = {
- val expr = BitwiseAnd(Literal(input1), Literal(input2))
- assert(expr.dataType === Literal(input1).dataType)
- checkEvaluation(expr, expected)
- }
-
- // Need the extra toByte even though IntelliJ thought it's not needed.
- check(1.toByte, 2.toByte, (1.toByte & 2.toByte).toByte)
- check(1000.toShort, 2.toShort, (1000.toShort & 2.toShort).toShort)
- check(1000000, 4, 1000000 & 4)
- check(123456789123L, 5L, 123456789123L & 5L)
-
- val nullLit = Literal.create(null, IntegerType)
- checkEvaluation(BitwiseAnd(nullLit, Literal(1)), null)
- checkEvaluation(BitwiseAnd(Literal(1), nullLit), null)
- checkEvaluation(BitwiseAnd(nullLit, nullLit), null)
- checkEvaluation(BitwiseAnd(positiveShortLit, negativeShortLit),
- (positiveShort & negativeShort).toShort)
- checkEvaluation(BitwiseAnd(positiveIntLit, negativeIntLit), positiveInt & negativeInt)
- checkEvaluation(BitwiseAnd(positiveLongLit, negativeLongLit), positiveLong & negativeLong)
-
- DataTypeTestUtils.integralType.foreach { dt =>
- checkConsistencyBetweenInterpretedAndCodegen(BitwiseAnd, dt, dt)
- }
- }
-
- test("BitwiseOr") {
- def check(input1: Any, input2: Any, expected: Any): Unit = {
- val expr = BitwiseOr(Literal(input1), Literal(input2))
- assert(expr.dataType === Literal(input1).dataType)
- checkEvaluation(expr, expected)
- }
-
- // Need the extra toByte even though IntelliJ thought it's not needed.
- check(1.toByte, 2.toByte, (1.toByte | 2.toByte).toByte)
- check(1000.toShort, 2.toShort, (1000.toShort | 2.toShort).toShort)
- check(1000000, 4, 1000000 | 4)
- check(123456789123L, 5L, 123456789123L | 5L)
-
- val nullLit = Literal.create(null, IntegerType)
- checkEvaluation(BitwiseOr(nullLit, Literal(1)), null)
- checkEvaluation(BitwiseOr(Literal(1), nullLit), null)
- checkEvaluation(BitwiseOr(nullLit, nullLit), null)
- checkEvaluation(BitwiseOr(positiveShortLit, negativeShortLit),
- (positiveShort | negativeShort).toShort)
- checkEvaluation(BitwiseOr(positiveIntLit, negativeIntLit), positiveInt | negativeInt)
- checkEvaluation(BitwiseOr(positiveLongLit, negativeLongLit), positiveLong | negativeLong)
-
- DataTypeTestUtils.integralType.foreach { dt =>
- checkConsistencyBetweenInterpretedAndCodegen(BitwiseOr, dt, dt)
- }
- }
-
- test("BitwiseXor") {
- def check(input1: Any, input2: Any, expected: Any): Unit = {
- val expr = BitwiseXor(Literal(input1), Literal(input2))
- assert(expr.dataType === Literal(input1).dataType)
- checkEvaluation(expr, expected)
- }
-
- // Need the extra toByte even though IntelliJ thought it's not needed.
- check(1.toByte, 2.toByte, (1.toByte ^ 2.toByte).toByte)
- check(1000.toShort, 2.toShort, (1000.toShort ^ 2.toShort).toShort)
- check(1000000, 4, 1000000 ^ 4)
- check(123456789123L, 5L, 123456789123L ^ 5L)
-
- val nullLit = Literal.create(null, IntegerType)
- checkEvaluation(BitwiseXor(nullLit, Literal(1)), null)
- checkEvaluation(BitwiseXor(Literal(1), nullLit), null)
- checkEvaluation(BitwiseXor(nullLit, nullLit), null)
- checkEvaluation(BitwiseXor(positiveShortLit, negativeShortLit),
- (positiveShort ^ negativeShort).toShort)
- checkEvaluation(BitwiseXor(positiveIntLit, negativeIntLit), positiveInt ^ negativeInt)
- checkEvaluation(BitwiseXor(positiveLongLit, negativeLongLit), positiveLong ^ negativeLong)
-
- DataTypeTestUtils.integralType.foreach { dt =>
- checkConsistencyBetweenInterpretedAndCodegen(BitwiseXor, dt, dt)
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/9db06c44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
new file mode 100644
index 0000000..020687e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+
+class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+ // Unit tests for the Size, MapKeys/MapValues, SortArray and ArrayContains expressions.
+ test("Array and Map Size") {
+ val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
+ val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType))
+
+ checkEvaluation(Size(a0), 3)
+ checkEvaluation(Size(a1), 0)
+ checkEvaluation(Size(a2), 2)
+
+ val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType))
+ val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+ val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType))
+
+ checkEvaluation(Size(m0), 2)
+ checkEvaluation(Size(m1), 0)
+ checkEvaluation(Size(m2), 1)
+
+ checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), -1)
+ checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), -1)
+ }
+
+ test("MapKeys/MapValues") {
+ val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType))
+ val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+ val m2 = Literal.create(null, MapType(StringType, StringType))
+
+ checkEvaluation(MapKeys(m0), Seq("a", "b"))
+ checkEvaluation(MapValues(m0), Seq("1", "2"))
+ checkEvaluation(MapKeys(m1), Seq())
+ checkEvaluation(MapValues(m1), Seq())
+ checkEvaluation(MapKeys(m2), null)
+ checkEvaluation(MapValues(m2), null)
+ }
+
+ test("Sort Array") {
+ val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
+ val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))
+ val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))
+ val a4 = Literal.create(Seq(null, null), ArrayType(NullType))
+
+ checkEvaluation(new SortArray(a0), Seq(1, 2, 3))
+ checkEvaluation(new SortArray(a1), Seq[Integer]())
+ checkEvaluation(new SortArray(a2), Seq("a", "b"))
+ checkEvaluation(new SortArray(a3), Seq(null, "a", "b"))
+ checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3))
+ checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]())
+ checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b"))
+ checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b"))
+ checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1))
+ checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]())
+ checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a"))
+ checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null))
+
+ checkEvaluation(new SortArray(Literal.create(null, ArrayType(StringType))), null)
+ checkEvaluation(new SortArray(a4), Seq(null, null))
+
+ val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+ val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS)
+
+ checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2)))
+ }
+
+ test("Array contains") {
+ val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
+ val a2 = Literal.create(Seq(null), ArrayType(LongType))
+ val a3 = Literal.create(null, ArrayType(StringType))
+
+ checkEvaluation(ArrayContains(a0, Literal(1)), true)
+ checkEvaluation(ArrayContains(a0, Literal(0)), false)
+ checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null)
+
+ checkEvaluation(ArrayContains(a1, Literal("")), true)
+ checkEvaluation(ArrayContains(a1, Literal("a")), null)
+ checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), null)
+
+ checkEvaluation(ArrayContains(a2, Literal(1L)), null)
+ checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), null)
+
+ checkEvaluation(ArrayContains(a3, Literal("")), null)
+ checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9db06c44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
deleted file mode 100644
index c76dad2..0000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
+++ /dev/null
@@ -1,109 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types._
-
-
-class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
-
- test("Array and Map Size") {
- val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
- val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
- val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType))
-
- checkEvaluation(Size(a0), 3)
- checkEvaluation(Size(a1), 0)
- checkEvaluation(Size(a2), 2)
-
- val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType))
- val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
- val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType))
-
- checkEvaluation(Size(m0), 2)
- checkEvaluation(Size(m1), 0)
- checkEvaluation(Size(m2), 1)
-
- checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), -1)
- checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), -1)
- }
-
- test("MapKeys/MapValues") {
- val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType))
- val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
- val m2 = Literal.create(null, MapType(StringType, StringType))
-
- checkEvaluation(MapKeys(m0), Seq("a", "b"))
- checkEvaluation(MapValues(m0), Seq("1", "2"))
- checkEvaluation(MapKeys(m1), Seq())
- checkEvaluation(MapValues(m1), Seq())
- checkEvaluation(MapKeys(m2), null)
- checkEvaluation(MapValues(m2), null)
- }
-
- test("Sort Array") {
- val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
- val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
- val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))
- val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))
- val a4 = Literal.create(Seq(null, null), ArrayType(NullType))
-
- checkEvaluation(new SortArray(a0), Seq(1, 2, 3))
- checkEvaluation(new SortArray(a1), Seq[Integer]())
- checkEvaluation(new SortArray(a2), Seq("a", "b"))
- checkEvaluation(new SortArray(a3), Seq(null, "a", "b"))
- checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3))
- checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]())
- checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b"))
- checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b"))
- checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1))
- checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]())
- checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a"))
- checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null))
-
- checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
- checkEvaluation(new SortArray(a4), Seq(null, null))
-
- val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
- val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS)
-
- checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2)))
- }
-
- test("Array contains") {
- val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
- val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
- val a2 = Literal.create(Seq(null), ArrayType(LongType))
- val a3 = Literal.create(null, ArrayType(StringType))
-
- checkEvaluation(ArrayContains(a0, Literal(1)), true)
- checkEvaluation(ArrayContains(a0, Literal(0)), false)
- checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null)
-
- checkEvaluation(ArrayContains(a1, Literal("")), true)
- checkEvaluation(ArrayContains(a1, Literal("a")), null)
- checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), null)
-
- checkEvaluation(ArrayContains(a2, Literal(1L)), null)
- checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), null)
-
- checkEvaluation(ArrayContains(a3, Literal("")), null)
- checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/9db06c44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
new file mode 100644
index 0000000..6b5bfac
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -0,0 +1,582 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.nio.charset.StandardCharsets
+
+import com.google.common.math.LongMath
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
+import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
+import org.apache.spark.sql.types._
+
+class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+ import IntegralLiteralTestUtils._
+
+ /**
+ * Checks that a leaf math expression evaluates to `c`, on an empty row and on a null row.
+ *
+ * @param e Factory producing the leaf expression under test
+ * @param c Expected value, e.g. a constant from scala.math
+ * @tparam T Generic type for primitives
+ */
+ private def testLeaf[T](
+ e: () => Expression,
+ c: T): Unit = {
+ checkEvaluation(e(), c, EmptyRow)
+ checkEvaluation(e(), c, create_row(null))
+ }
+
+ /**
+ * Tests a unary math expression on `domain`, then checks a null `evalType` input gives null.
+ *
+ * @param c Constructor wrapping a child expression in the expression under test
+ * @param f Reference function for expected results (ignored if expectNull or expectNaN)
+ * @param domain The set of values to run the function with
+ * @param expectNull If true, every domain value must evaluate to null (checked before NaN)
+ * @param expectNaN If true, every domain value must evaluate to NaN
+ * @tparam T Generic type for primitives
+ * @tparam U Generic type for the output of the given function `f`
+ */
+ private def testUnary[T, U](
+ c: Expression => Expression,
+ f: T => U,
+ domain: Iterable[T] = (-20 to 20).map(_ * 0.1),
+ expectNull: Boolean = false,
+ expectNaN: Boolean = false,
+ evalType: DataType = DoubleType): Unit = {
+ if (expectNull) {
+ domain.foreach { value =>
+ checkEvaluation(c(Literal(value)), null, EmptyRow)
+ }
+ } else if (expectNaN) {
+ domain.foreach { value =>
+ checkNaN(c(Literal(value)), EmptyRow)
+ }
+ } else {
+ domain.foreach { value =>
+ checkEvaluation(c(Literal(value)), f(value), EmptyRow)
+ }
+ }
+ checkEvaluation(c(Literal.create(null, evalType)), null, create_row(null))
+ }
+
+ /**
+ * Tests a binary math expression on `domain` pairs and on null operands.
+ *
+ * @param c Constructor building the binary expression under test from two children
+ * @param f Reference function for expected results, applied to `v + 0.0` (maps -0.0 to 0.0)
+ * @param domain Operand pairs; checked in both argument orders when comparing against `f`
+ * @param expectNull Whether the given values should return null or not
+ * @param expectNaN Whether the given values should eval to NaN or not
+ */
+ private def testBinary(
+ c: (Expression, Expression) => Expression,
+ f: (Double, Double) => Double,
+ domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)),
+ expectNull: Boolean = false, expectNaN: Boolean = false): Unit = {
+ if (expectNull) {
+ domain.foreach { case (v1, v2) =>
+ checkEvaluation(c(Literal(v1), Literal(v2)), null, create_row(null))
+ }
+ } else if (expectNaN) {
+ domain.foreach { case (v1, v2) =>
+ checkNaN(c(Literal(v1), Literal(v2)), EmptyRow)
+ }
+ } else {
+ domain.foreach { case (v1, v2) =>
+ checkEvaluation(c(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow)
+ checkEvaluation(c(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow)
+ }
+ }
+ checkEvaluation(c(Literal.create(null, DoubleType), Literal(1.0)), null, create_row(null))
+ checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null))
+ }
+
+ private def checkNaN(
+ expression: Expression, inputRow: InternalRow = EmptyRow): Unit = {
+ checkNaNWithoutCodegen(expression, inputRow) // interpreted evaluation
+ checkNaNWithGeneratedProjection(expression, inputRow) // generated mutable projection
+ checkNaNWithOptimization(expression, inputRow) // after SimpleTestOptimizer
+ }
+
+ private def checkNaNWithoutCodegen( // interpreted evaluation must produce NaN
+ expression: Expression,
+ inputRow: InternalRow = EmptyRow): Unit = {
+ val actual = try evaluate(expression, inputRow) catch {
+ case e: Exception => fail(s"Exception evaluating $expression", e)
+ }
+ if (!actual.asInstanceOf[Double].isNaN) { // a null result unboxes to 0.0, so it fails
+ fail(s"Incorrect evaluation (codegen off): $expression, " +
+ s"actual: $actual, " +
+ s"expected: NaN")
+ }
+ }
+
+ private def checkNaNWithGeneratedProjection(
+ expression: Expression,
+ inputRow: InternalRow = EmptyRow): Unit = {
+ // Evaluate through a GenerateMutableProjection-built projection and read back column 0.
+ val plan = generateProject(
+ GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
+ expression)
+ // NOTE(review): the "Optimized(...)" alias name looks copied from checkNaNWithOptimization.
+ val actual = plan(inputRow).get(0, expression.dataType)
+ if (!actual.asInstanceOf[Double].isNaN) {
+ fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN")
+ }
+ }
+
+ private def checkNaNWithOptimization( // NaN must survive SimpleTestOptimizer
+ expression: Expression,
+ inputRow: InternalRow = EmptyRow): Unit = {
+ val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
+ val optimizedPlan = SimpleTestOptimizer.execute(plan)
+ checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow)
+ }
+
+  test("conv") {
+    val nullString = Literal.create(null, StringType)
+    val nullInt = Literal.create(null, IntegerType)
+    // (number, fromBase, toBase, expected); a null expected value means the result is null.
+    val cases: Seq[(Expression, Expression, Expression, Any)] = Seq(
+      (Literal("3"), Literal(10), Literal(2), "11"),
+      (Literal("-15"), Literal(10), Literal(-16), "-F"),
+      (Literal("-15"), Literal(10), Literal(16), "FFFFFFFFFFFFFFF1"),
+      (Literal("big"), Literal(36), Literal(16), "3A48"),
+      (nullString, Literal(36), Literal(16), null),
+      (Literal("3"), nullInt, Literal(16), null),
+      (Literal("3"), Literal(16), nullInt, null),
+      (Literal("1234"), Literal(10), Literal(37), null),
+      (Literal(""), Literal(10), Literal(16), null),
+      (Literal("9223372036854775807"), Literal(36), Literal(16), "FFFFFFFFFFFFFFFF"),
+      // If there is an invalid digit in the number, the longest valid prefix should be converted.
+      (Literal("11abc"), Literal(10), Literal(16), "B")
+    )
+    cases.foreach { case (number, fromBase, toBase, expected) =>
+      checkEvaluation(Conv(number, fromBase, toBase), expected)
+    }
+  }
+
+  test("e") {
+    // EulerNumber is built with no children and must evaluate to scala.math.E.
+    testLeaf[Double](EulerNumber, math.E)
+  }
+
+  test("pi") {
+    // Pi is built with no children and must evaluate to scala.math.Pi.
+    testLeaf[Double](Pi, math.Pi)
+  }
+
+  test("sin") {
+    // Compared against scala.math.sin over testUnary's default domain.
+    testUnary(Sin, (x: Double) => math.sin(x))
+    checkConsistencyBetweenInterpretedAndCodegen(Sin, DoubleType)
+  }
+
+  test("asin") {
+    val asin = (x: Double) => math.asin(x)
+    // asin is defined on [-1, 1]; inputs above 1 must evaluate to NaN.
+    testUnary(Asin, asin, (-10 to 10).map(_ * 0.1))
+    testUnary(Asin, asin, (11 to 20).map(_ * 0.1), expectNaN = true)
+    checkConsistencyBetweenInterpretedAndCodegen(Asin, DoubleType)
+  }
+
+  test("sinh") {
+    // Compared against scala.math.sinh over testUnary's default domain.
+    testUnary(Sinh, (x: Double) => math.sinh(x))
+    checkConsistencyBetweenInterpretedAndCodegen(Sinh, DoubleType)
+  }
+
+  test("cos") {
+    // Compared against scala.math.cos over testUnary's default domain.
+    testUnary(Cos, (x: Double) => math.cos(x))
+    checkConsistencyBetweenInterpretedAndCodegen(Cos, DoubleType)
+  }
+
+  test("acos") {
+    val acos = (x: Double) => math.acos(x)
+    // acos is defined on [-1, 1]; inputs above 1 must evaluate to NaN.
+    testUnary(Acos, acos, (-10 to 10).map(_ * 0.1))
+    testUnary(Acos, acos, (11 to 20).map(_ * 0.1), expectNaN = true)
+    checkConsistencyBetweenInterpretedAndCodegen(Acos, DoubleType)
+  }
+
+  test("cosh") {
+    // Compared against scala.math.cosh over testUnary's default domain.
+    testUnary(Cosh, (x: Double) => math.cosh(x))
+    checkConsistencyBetweenInterpretedAndCodegen(Cosh, DoubleType)
+  }
+
+  test("tan") {
+    // Compared against scala.math.tan over testUnary's default domain.
+    testUnary(Tan, (x: Double) => math.tan(x))
+    checkConsistencyBetweenInterpretedAndCodegen(Tan, DoubleType)
+  }
+
+  test("atan") {
+    // Compared against scala.math.atan over testUnary's default domain.
+    testUnary(Atan, (x: Double) => math.atan(x))
+    checkConsistencyBetweenInterpretedAndCodegen(Atan, DoubleType)
+  }
+
+  test("tanh") {
+    // Compared against scala.math.tanh over testUnary's default domain.
+    testUnary(Tanh, (x: Double) => math.tanh(x))
+    checkConsistencyBetweenInterpretedAndCodegen(Tanh, DoubleType)
+  }
+
+  test("toDegrees") {
+    testUnary(ToDegrees, math.toDegrees)
+    // Verify ToDegrees itself (previously a copy-paste checked Acos here, so the
+    // interpreted/codegen consistency of ToDegrees was never exercised).
+    checkConsistencyBetweenInterpretedAndCodegen(ToDegrees, DoubleType)
+  }
+
+  test("toRadians") {
+    // Compared against scala.math.toRadians over testUnary's default domain.
+    testUnary(ToRadians, (x: Double) => math.toRadians(x))
+    checkConsistencyBetweenInterpretedAndCodegen(ToRadians, DoubleType)
+  }
+
+  test("cbrt") {
+    // Compared against scala.math.cbrt over testUnary's default domain.
+    testUnary(Cbrt, (x: Double) => math.cbrt(x))
+    checkConsistencyBetweenInterpretedAndCodegen(Cbrt, DoubleType)
+  }
+
+  test("ceil") {
+    // Double input: the expected value is math.ceil converted to a Long.
+    testUnary(Ceil, (d: Double) => math.ceil(d).toLong)
+    checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType)
+
+    // Decimal input, then codegen consistency across several precision/scale combinations.
+    val decimals = (-20 to 20).map(x => Decimal(x * 0.1))
+    testUnary(Ceil, (d: Decimal) => d.ceil, decimals)
+    Seq(DecimalType(25, 3), DecimalType(25, 0), DecimalType(5, 0)).foreach { decimalType =>
+      checkConsistencyBetweenInterpretedAndCodegen(Ceil, decimalType)
+    }
+  }
+
+  test("floor") {
+    // Double input: the expected value is math.floor converted to a Long.
+    testUnary(Floor, (d: Double) => math.floor(d).toLong)
+    checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType)
+
+    // Decimal input, then codegen consistency across several precision/scale combinations.
+    val decimals = (-20 to 20).map(x => Decimal(x * 0.1))
+    testUnary(Floor, (d: Decimal) => d.floor, decimals)
+    Seq(DecimalType(25, 3), DecimalType(25, 0), DecimalType(5, 0)).foreach { decimalType =>
+      checkConsistencyBetweenInterpretedAndCodegen(Floor, decimalType)
+    }
+  }
+
+  test("factorial") {
+    // Cross-check against Guava for every input whose factorial fits in a Long.
+    (0 to 20).foreach { value =>
+      checkEvaluation(Factorial(Literal(value)), LongMath.factorial(value), EmptyRow)
+    }
+    // Factorial of a null input must be null. The null literal has to be wrapped in
+    // Factorial; evaluating the bare literal would not exercise Factorial at all.
+    checkEvaluation(Factorial(Literal.create(null, IntegerType)), null, create_row(null))
+    // 20! is the largest factorial representable as a Long; 21! yields null.
+    checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow)
+    checkEvaluation(Factorial(Literal(21)), null, EmptyRow)
+    checkConsistencyBetweenInterpretedAndCodegen(Factorial.apply _, IntegerType)
+  }
+
+  test("rint") {
+    // Compared against scala.math.rint over testUnary's default domain.
+    testUnary(Rint, (x: Double) => math.rint(x))
+    checkConsistencyBetweenInterpretedAndCodegen(Rint, DoubleType)
+  }
+
+  test("exp") {
+    // Compared against scala.math.exp over testUnary's default domain.
+    testUnary(Exp, (x: Double) => math.exp(x))
+    checkConsistencyBetweenInterpretedAndCodegen(Exp, DoubleType)
+  }
+
+  test("expm1") {
+    // Compared against scala.math.expm1 over testUnary's default domain.
+    testUnary(Expm1, (x: Double) => math.expm1(x))
+    checkConsistencyBetweenInterpretedAndCodegen(Expm1, DoubleType)
+  }
+
+  test("signum") {
+    // math.signum is overloaded; the typed lambda selects the Double overload.
+    testUnary(Signum, (x: Double) => math.signum(x))
+    checkConsistencyBetweenInterpretedAndCodegen(Signum, DoubleType)
+  }
+
+  test("log") {
+    val log = (x: Double) => math.log(x)
+    val positive = (1 to 20).map(_ * 0.1)
+    // Zero and negative inputs evaluate to null rather than -Infinity/NaN.
+    val nonPositive = (-5 to 0).map(_ * 0.1)
+    testUnary(Log, log, positive)
+    testUnary(Log, log, nonPositive, expectNull = true)
+    checkConsistencyBetweenInterpretedAndCodegen(Log, DoubleType)
+  }
+
+  test("log10") {
+    val log10 = (x: Double) => math.log10(x)
+    val positive = (1 to 20).map(_ * 0.1)
+    // Zero and negative inputs evaluate to null rather than -Infinity/NaN.
+    val nonPositive = (-5 to 0).map(_ * 0.1)
+    testUnary(Log10, log10, positive)
+    testUnary(Log10, log10, nonPositive, expectNull = true)
+    checkConsistencyBetweenInterpretedAndCodegen(Log10, DoubleType)
+  }
+
+  test("log1p") {
+    val log1p = (x: Double) => math.log1p(x)
+    val valid = (0 to 20).map(_ * 0.1)
+    // Inputs at or below -1 evaluate to null.
+    val invalid = (-10 to -1).map(_ * 1.0)
+    testUnary(Log1p, log1p, valid)
+    testUnary(Log1p, log1p, invalid, expectNull = true)
+    checkConsistencyBetweenInterpretedAndCodegen(Log1p, DoubleType)
+  }
+
+  test("bin") {
+    testUnary(Bin, (n: Long) => java.lang.Long.toBinaryString(n), (-20 to 20).map(_.toLong),
+      evalType = LongType)
+
+    // Bound references into a row whose first column is null and the rest are longs.
+    val row = create_row(null, 12L, 123L, 1234L, -123L)
+    checkEvaluation(Bin('a.long.at(0)), null, row)
+    Seq(12L, 123L, 1234L, -123L).zipWithIndex.foreach { case (value, i) =>
+      checkEvaluation(Bin('a.long.at(i + 1)), java.lang.Long.toBinaryString(value), row)
+    }
+
+    checkEvaluation(Bin(positiveLongLit), java.lang.Long.toBinaryString(positiveLong))
+    checkEvaluation(Bin(negativeLongLit), java.lang.Long.toBinaryString(negativeLong))
+
+    checkConsistencyBetweenInterpretedAndCodegen(Bin, LongType)
+  }
+
+  test("log2") {
+    // Change of base: log2(x) = ln(x) / ln(2).
+    val log2 = (x: Double) => math.log(x) / math.log(2)
+    testUnary(Log2, log2, (1 to 20).map(_ * 0.1))
+    // Zero and negative inputs evaluate to null.
+    testUnary(Log2, log2, (-5 to 0).map(_ * 1.0), expectNull = true)
+    checkConsistencyBetweenInterpretedAndCodegen(Log2, DoubleType)
+  }
+
+  test("sqrt") {
+    val sqrt = (x: Double) => math.sqrt(x)
+    val nonNegative = (0 to 20).map(_ * 0.1)
+    // Negative inputs evaluate to NaN (not null).
+    val negative = (-5 to -1).map(_ * 1.0)
+    testUnary(Sqrt, sqrt, nonNegative)
+    testUnary(Sqrt, sqrt, negative, expectNaN = true)
+
+    checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
+    Seq(-1.0, -1.5).foreach(v => checkNaN(Sqrt(Literal(v)), EmptyRow))
+    checkConsistencyBetweenInterpretedAndCodegen(Sqrt, DoubleType)
+  }
+
+  test("pow") {
+    val pow = (base: Double, exponent: Double) => math.pow(base, exponent)
+    testBinary(Pow, pow, (-5 to 5).map { v => (v.toDouble, v.toDouble) })
+    // A negative base raised to a non-integral exponent evaluates to NaN.
+    testBinary(Pow, pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true)
+    checkConsistencyBetweenInterpretedAndCodegen(Pow, DoubleType, DoubleType)
+  }
+
+  test("shift left") {
+    val nullInt = Literal.create(null, IntegerType)
+    checkEvaluation(ShiftLeft(nullInt, Literal(1)), null)
+    checkEvaluation(ShiftLeft(Literal(21), nullInt), null)
+    checkEvaluation(ShiftLeft(nullInt, nullInt), null)
+    checkEvaluation(ShiftLeft(Literal(21), Literal(1)), 42)
+
+    checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong)
+    checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong)
+
+    // Every positive/negative value paired with every positive/negative shift count, for
+    // Int and then Long values. The expected value is the JVM's own `<<` result.
+    val ints = Seq(positiveIntLit -> positiveInt, negativeIntLit -> negativeInt)
+    val longs = Seq(positiveLongLit -> positiveLong, negativeLongLit -> negativeLong)
+    for ((valueLit, value) <- ints; (shiftLit, shift) <- ints) {
+      checkEvaluation(ShiftLeft(valueLit, shiftLit), value << shift)
+    }
+    for ((valueLit, value) <- longs; (shiftLit, shift) <- ints) {
+      checkEvaluation(ShiftLeft(valueLit, shiftLit), value << shift)
+    }
+
+    checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, IntegerType, IntegerType)
+    checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, LongType, IntegerType)
+  }
+
+  test("shift right") {
+    val nullInt = Literal.create(null, IntegerType)
+    checkEvaluation(ShiftRight(nullInt, Literal(1)), null)
+    checkEvaluation(ShiftRight(Literal(42), nullInt), null)
+    checkEvaluation(ShiftRight(nullInt, nullInt), null)
+    checkEvaluation(ShiftRight(Literal(42), Literal(1)), 21)
+
+    checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong)
+    checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
+
+    // Every positive/negative value paired with every positive/negative shift count, for
+    // Int and then Long values. The expected value is the JVM's own `>>` result.
+    val ints = Seq(positiveIntLit -> positiveInt, negativeIntLit -> negativeInt)
+    val longs = Seq(positiveLongLit -> positiveLong, negativeLongLit -> negativeLong)
+    for ((valueLit, value) <- ints; (shiftLit, shift) <- ints) {
+      checkEvaluation(ShiftRight(valueLit, shiftLit), value >> shift)
+    }
+    for ((valueLit, value) <- longs; (shiftLit, shift) <- ints) {
+      checkEvaluation(ShiftRight(valueLit, shiftLit), value >> shift)
+    }
+
+    checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, IntegerType, IntegerType)
+    checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, LongType, IntegerType)
+  }
+
+  test("shift right unsigned") {
+    // Null in either or both operands yields null. All three cases must use
+    // ShiftRightUnsigned; the both-null case previously built a ShiftRight by mistake.
+    checkEvaluation(ShiftRightUnsigned(Literal.create(null, IntegerType), Literal(1)), null)
+    checkEvaluation(ShiftRightUnsigned(Literal(42), Literal.create(null, IntegerType)), null)
+    checkEvaluation(
+      ShiftRightUnsigned(Literal.create(null, IntegerType), Literal.create(null, IntegerType)),
+      null)
+    checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21)
+
+    checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong)
+    // The unsigned shift fills the sign bit with zero, so a negative long becomes a large
+    // positive one.
+    checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L)
+
+    checkEvaluation(ShiftRightUnsigned(positiveIntLit, positiveIntLit),
+      positiveInt >>> positiveInt)
+    checkEvaluation(ShiftRightUnsigned(positiveIntLit, negativeIntLit),
+      positiveInt >>> negativeInt)
+    checkEvaluation(ShiftRightUnsigned(negativeIntLit, positiveIntLit),
+      negativeInt >>> positiveInt)
+    checkEvaluation(ShiftRightUnsigned(negativeIntLit, negativeIntLit),
+      negativeInt >>> negativeInt)
+    checkEvaluation(ShiftRightUnsigned(positiveLongLit, positiveIntLit),
+      positiveLong >>> positiveInt)
+    checkEvaluation(ShiftRightUnsigned(positiveLongLit, negativeIntLit),
+      positiveLong >>> negativeInt)
+    checkEvaluation(ShiftRightUnsigned(negativeLongLit, positiveIntLit),
+      negativeLong >>> positiveInt)
+    checkEvaluation(ShiftRightUnsigned(negativeLongLit, negativeIntLit),
+      negativeLong >>> negativeInt)
+
+    checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, IntegerType, IntegerType)
+    checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, LongType, IntegerType)
+  }
+
+  test("hex") {
+    // Long input: negative values are rendered as their 64-bit two's-complement form.
+    checkEvaluation(Hex(Literal.create(null, LongType)), null)
+    Seq(
+      28L -> "1C",
+      -28L -> "FFFFFFFFFFFFFFE4",
+      100800200404L -> "177828FED4",
+      -100800200404L -> "FFFFFFE887D7012C"
+    ).foreach { case (number, hex) =>
+      checkEvaluation(Hex(Literal(number)), hex)
+    }
+
+    // Binary input: each byte becomes two upper-case hex digits.
+    checkEvaluation(Hex(Literal.create(null, BinaryType)), null)
+    checkEvaluation(Hex(Literal("helloHex".getBytes(StandardCharsets.UTF_8))), "68656C6C6F486578")
+    // scalastyle:off
+    // Turn off scala style for non-ascii chars
+    checkEvaluation(
+      Hex(Literal("\u4e09\u91cd\u7684".getBytes(StandardCharsets.UTF_8))), "E4B889E9878DE79A84")
+    // scalastyle:on
+    Seq(LongType, BinaryType, StringType).foreach { inputType =>
+      checkConsistencyBetweenInterpretedAndCodegen(Hex.apply _, inputType)
+    }
+  }
+
+  test("unhex") {
+    checkEvaluation(Unhex(Literal.create(null, StringType)), null)
+    checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes(StandardCharsets.UTF_8))
+    Seq(
+      "" -> new Array[Byte](0),
+      "F" -> Array[Byte](15),
+      "ff" -> Array[Byte](-1)
+    ).foreach { case (hex, bytes) =>
+      checkEvaluation(Unhex(Literal(hex)), bytes)
+    }
+    // Input containing non-hex characters evaluates to null.
+    checkEvaluation(Unhex(Literal("GG")), null)
+    // scalastyle:off
+    // Turn off scala style for non-ascii chars
+    checkEvaluation(
+      Unhex(Literal("E4B889E9878DE79A84")), "\u4e09\u91cd\u7684".getBytes(StandardCharsets.UTF_8))
+    checkEvaluation(Unhex(Literal("\u4e09\u91cd\u7684")), null)
+    // scalastyle:on
+    checkConsistencyBetweenInterpretedAndCodegen(Unhex, StringType)
+  }
+
+  test("hypot") {
+    // Compared against scala.math.hypot over testBinary's default domain.
+    testBinary(Hypot, (x: Double, y: Double) => math.hypot(x, y))
+    checkConsistencyBetweenInterpretedAndCodegen(Hypot, DoubleType, DoubleType)
+  }
+
+  test("atan2") {
+    // Compared against scala.math.atan2 over testBinary's default domain.
+    testBinary(Atan2, (y: Double, x: Double) => math.atan2(y, x))
+    checkConsistencyBetweenInterpretedAndCodegen(Atan2, DoubleType, DoubleType)
+  }
+
+  test("binary log") {
+    // Logarithm(base, value) is compared against the change-of-base formula.
+    val log = (base: Double, value: Double) => math.log(value) / math.log(base)
+    val domain = (1 to 20).map(v => (v * 0.1, v * 0.2))
+
+    for ((a, b) <- domain) {
+      checkEvaluation(Logarithm(Literal(a), Literal(b)), log(a + 0.0, b + 0.0), EmptyRow)
+      checkEvaluation(Logarithm(Literal(b), Literal(a)), log(b + 0.0, a + 0.0), EmptyRow)
+      // The single-argument constructor is expected to use base e.
+      checkEvaluation(new Logarithm(Literal(a)), log(math.E, a + 0.0), EmptyRow)
+    }
+
+    // A null or negative base or value must yield null.
+    val nullDouble = Literal.create(null, DoubleType)
+    Seq(
+      nullDouble -> Literal(1.0),
+      Literal(1.0) -> nullDouble,
+      Literal(-1.0) -> Literal(1.0),
+      Literal(1.0) -> Literal(-1.0)
+    ).foreach { case (base, value) =>
+      checkEvaluation(Logarithm(base, value), null, create_row(null))
+    }
+    checkConsistencyBetweenInterpretedAndCodegen(Logarithm, DoubleType, DoubleType)
+  }
+
+  test("round/bround") {
+    val scales = -6 to 6
+    val doublePi: Double = math.Pi
+    val shortPi: Short = 31415
+    val intPi: Int = 314159265
+    val longPi: Long = 31415926535897932L
+    val bdPi: BigDecimal = BigDecimal(31415927L, 7)
+
+    // Expected results, aligned element-by-element with `scales` (scale -6 first).
+    val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142,
+      3.1416, 3.14159, 3.141593)
+    val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++
+      Seq.fill[Short](7)(31415)
+    val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300,
+      314159270) ++ Seq.fill(7)(314159265)
+    val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L,
+      31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++
+      Seq.fill(7)(31415926535897932L)
+    // Only the Int case needs separate BRound expectations: 314159265 at scale -1 is a tie,
+    // which Round resolves to 314159270 and BRound to 314159260.
+    val intResultsB: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300,
+      314159260) ++ Seq.fill(7)(314159265)
+
+    for ((scale, i) <- scales.zipWithIndex) {
+      checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow)
+      checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow)
+      checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow)
+      checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
+      checkEvaluation(BRound(doublePi, scale), doubleResults(i), EmptyRow)
+      checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow)
+      checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow)
+      checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow)
+    }
+
+    // bdPi has scale 7; rounding it to scales 0 through 7 (the index of each expected value).
+    val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
+      BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
+      BigDecimal(3.141593), BigDecimal(3.1415927))
+    bdResults.zipWithIndex.foreach { case (expected, scale) =>
+      checkEvaluation(Round(bdPi, scale), expected, EmptyRow)
+      checkEvaluation(BRound(bdPi, scale), expected, EmptyRow)
+    }
+    // round_scale > current_scale would result in precision increase
+    // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
+    (8 to 10).foreach { scale =>
+      checkEvaluation(Round(bdPi, scale), null, EmptyRow)
+      checkEvaluation(BRound(bdPi, scale), null, EmptyRow)
+    }
+
+    // A null child yields null for every numeric type, with or without a null scale.
+    DataTypeTestUtils.numericTypes.foreach { dataType =>
+      val nullChild = Literal.create(null, dataType)
+      val nullScale = Literal.create(null, IntegerType)
+      checkEvaluation(Round(nullChild, Literal(2)), null)
+      checkEvaluation(Round(nullChild, nullScale), null)
+      checkEvaluation(BRound(nullChild, Literal(2)), null)
+      checkEvaluation(BRound(nullChild, nullScale), null)
+    }
+
+    // Ties: Round rounds half away from zero, BRound rounds half to even.
+    Seq((2.5, 0, 3.0), (3.5, 0, 4.0), (-2.5, 0, -3.0), (-3.5, 0, -4.0), (-0.35, 1, -0.4))
+      .foreach { case (value, scale, expected) =>
+        checkEvaluation(Round(value, scale), expected)
+      }
+    checkEvaluation(Round(-35, -1), -40)
+    Seq((2.5, 0, 2.0), (3.5, 0, 4.0), (-2.5, 0, -2.0), (-3.5, 0, -4.0), (-0.35, 1, -0.4))
+      .foreach { case (value, scale, expected) =>
+        checkEvaluation(BRound(value, scale), expected)
+      }
+    checkEvaluation(BRound(-35, -1), -40)
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9db06c44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
deleted file mode 100644
index f88c9e8..0000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ /dev/null
@@ -1,582 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import java.nio.charset.StandardCharsets
-
-import com.google.common.math.LongMath
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
-import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
-import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
-import org.apache.spark.sql.types._
-
-class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
-
- import IntegralLiteralTestUtils._
-
- /**
- * Used for testing leaf math expressions.
- *
- * @param e expression
- * @param c The constants in scala.math
- * @tparam T Generic type for primitives
- */
- private def testLeaf[T](
- e: () => Expression,
- c: T): Unit = {
- checkEvaluation(e(), c, EmptyRow)
- checkEvaluation(e(), c, create_row(null))
- }
-
- /**
- * Used for testing unary math expressions.
- *
- * @param c expression
- * @param f The functions in scala.math or elsewhere used to generate expected results
- * @param domain The set of values to run the function with
- * @param expectNull Whether the given values should return null or not
- * @param expectNaN Whether the given values should eval to NaN or not
- * @tparam T Generic type for primitives
- * @tparam U Generic type for the output of the given function `f`
- */
- private def testUnary[T, U](
- c: Expression => Expression,
- f: T => U,
- domain: Iterable[T] = (-20 to 20).map(_ * 0.1),
- expectNull: Boolean = false,
- expectNaN: Boolean = false,
- evalType: DataType = DoubleType): Unit = {
- if (expectNull) {
- domain.foreach { value =>
- checkEvaluation(c(Literal(value)), null, EmptyRow)
- }
- } else if (expectNaN) {
- domain.foreach { value =>
- checkNaN(c(Literal(value)), EmptyRow)
- }
- } else {
- domain.foreach { value =>
- checkEvaluation(c(Literal(value)), f(value), EmptyRow)
- }
- }
- checkEvaluation(c(Literal.create(null, evalType)), null, create_row(null))
- }
-
- /**
- * Used for testing binary math expressions.
- *
- * @param c The DataFrame function
- * @param f The functions in scala.math
- * @param domain The set of values to run the function with
- * @param expectNull Whether the given values should return null or not
- * @param expectNaN Whether the given values should eval to NaN or not
- */
- private def testBinary(
- c: (Expression, Expression) => Expression,
- f: (Double, Double) => Double,
- domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)),
- expectNull: Boolean = false, expectNaN: Boolean = false): Unit = {
- if (expectNull) {
- domain.foreach { case (v1, v2) =>
- checkEvaluation(c(Literal(v1), Literal(v2)), null, create_row(null))
- }
- } else if (expectNaN) {
- domain.foreach { case (v1, v2) =>
- checkNaN(c(Literal(v1), Literal(v2)), EmptyRow)
- }
- } else {
- domain.foreach { case (v1, v2) =>
- checkEvaluation(c(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow)
- checkEvaluation(c(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow)
- }
- }
- checkEvaluation(c(Literal.create(null, DoubleType), Literal(1.0)), null, create_row(null))
- checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null))
- }
-
- private def checkNaN(
- expression: Expression, inputRow: InternalRow = EmptyRow): Unit = {
- checkNaNWithoutCodegen(expression, inputRow)
- checkNaNWithGeneratedProjection(expression, inputRow)
- checkNaNWithOptimization(expression, inputRow)
- }
-
- private def checkNaNWithoutCodegen(
- expression: Expression,
- inputRow: InternalRow = EmptyRow): Unit = {
- val actual = try evaluate(expression, inputRow) catch {
- case e: Exception => fail(s"Exception evaluating $expression", e)
- }
- if (!actual.asInstanceOf[Double].isNaN) {
- fail(s"Incorrect evaluation (codegen off): $expression, " +
- s"actual: $actual, " +
- s"expected: NaN")
- }
- }
-
- private def checkNaNWithGeneratedProjection(
- expression: Expression,
- inputRow: InternalRow = EmptyRow): Unit = {
-
- val plan = generateProject(
- GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
- expression)
-
- val actual = plan(inputRow).get(0, expression.dataType)
- if (!actual.asInstanceOf[Double].isNaN) {
- fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN")
- }
- }
-
- private def checkNaNWithOptimization(
- expression: Expression,
- inputRow: InternalRow = EmptyRow): Unit = {
- val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
- val optimizedPlan = SimpleTestOptimizer.execute(plan)
- checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow)
- }
-
- test("conv") {
- checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11")
- checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F")
- checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1")
- checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48")
- checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null)
- checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null)
- checkEvaluation(Conv(Literal("3"), Literal(16), Literal.create(null, IntegerType)), null)
- checkEvaluation(
- Conv(Literal("1234"), Literal(10), Literal(37)), null)
- checkEvaluation(
- Conv(Literal(""), Literal(10), Literal(16)), null)
- checkEvaluation(
- Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF")
- // If there is an invalid digit in the number, the longest valid prefix should be converted.
- checkEvaluation(
- Conv(Literal("11abc"), Literal(10), Literal(16)), "B")
- }
-
- test("e") {
- testLeaf(EulerNumber, math.E)
- }
-
- test("pi") {
- testLeaf(Pi, math.Pi)
- }
-
- test("sin") {
- testUnary(Sin, math.sin)
- checkConsistencyBetweenInterpretedAndCodegen(Sin, DoubleType)
- }
-
- test("asin") {
- testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1))
- testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true)
- checkConsistencyBetweenInterpretedAndCodegen(Asin, DoubleType)
- }
-
- test("sinh") {
- testUnary(Sinh, math.sinh)
- checkConsistencyBetweenInterpretedAndCodegen(Sinh, DoubleType)
- }
-
- test("cos") {
- testUnary(Cos, math.cos)
- checkConsistencyBetweenInterpretedAndCodegen(Cos, DoubleType)
- }
-
- test("acos") {
- testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1))
- testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true)
- checkConsistencyBetweenInterpretedAndCodegen(Acos, DoubleType)
- }
-
- test("cosh") {
- testUnary(Cosh, math.cosh)
- checkConsistencyBetweenInterpretedAndCodegen(Cosh, DoubleType)
- }
-
- test("tan") {
- testUnary(Tan, math.tan)
- checkConsistencyBetweenInterpretedAndCodegen(Tan, DoubleType)
- }
-
- test("atan") {
- testUnary(Atan, math.atan)
- checkConsistencyBetweenInterpretedAndCodegen(Atan, DoubleType)
- }
-
- test("tanh") {
- testUnary(Tanh, math.tanh)
- checkConsistencyBetweenInterpretedAndCodegen(Tanh, DoubleType)
- }
-
- test("toDegrees") {
- testUnary(ToDegrees, math.toDegrees)
- checkConsistencyBetweenInterpretedAndCodegen(Acos, DoubleType)
- }
-
- test("toRadians") {
- testUnary(ToRadians, math.toRadians)
- checkConsistencyBetweenInterpretedAndCodegen(ToRadians, DoubleType)
- }
-
- test("cbrt") {
- testUnary(Cbrt, math.cbrt)
- checkConsistencyBetweenInterpretedAndCodegen(Cbrt, DoubleType)
- }
-
- test("ceil") {
- testUnary(Ceil, (d: Double) => math.ceil(d).toLong)
- checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType)
-
- testUnary(Ceil, (d: Decimal) => d.ceil, (-20 to 20).map(x => Decimal(x * 0.1)))
- checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3))
- checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0))
- checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0))
- }
-
- test("floor") {
- testUnary(Floor, (d: Double) => math.floor(d).toLong)
- checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType)
-
- testUnary(Floor, (d: Decimal) => d.floor, (-20 to 20).map(x => Decimal(x * 0.1)))
- checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3))
- checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0))
- checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0))
- }
-
- test("factorial") {
- (0 to 20).foreach { value =>
- checkEvaluation(Factorial(Literal(value)), LongMath.factorial(value), EmptyRow)
- }
- checkEvaluation(Literal.create(null, IntegerType), null, create_row(null))
- checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow)
- checkEvaluation(Factorial(Literal(21)), null, EmptyRow)
- checkConsistencyBetweenInterpretedAndCodegen(Factorial.apply _, IntegerType)
- }
-
- test("rint") {
- testUnary(Rint, math.rint)
- checkConsistencyBetweenInterpretedAndCodegen(Rint, DoubleType)
- }
-
- test("exp") {
- testUnary(Exp, math.exp)
- checkConsistencyBetweenInterpretedAndCodegen(Exp, DoubleType)
- }
-
- test("expm1") {
- testUnary(Expm1, math.expm1)
- checkConsistencyBetweenInterpretedAndCodegen(Expm1, DoubleType)
- }
-
- test("signum") {
- testUnary[Double, Double](Signum, math.signum)
- checkConsistencyBetweenInterpretedAndCodegen(Signum, DoubleType)
- }
-
- test("log") {
- testUnary(Log, math.log, (1 to 20).map(_ * 0.1))
- testUnary(Log, math.log, (-5 to 0).map(_ * 0.1), expectNull = true)
- checkConsistencyBetweenInterpretedAndCodegen(Log, DoubleType)
- }
-
- test("log10") {
- testUnary(Log10, math.log10, (1 to 20).map(_ * 0.1))
- testUnary(Log10, math.log10, (-5 to 0).map(_ * 0.1), expectNull = true)
- checkConsistencyBetweenInterpretedAndCodegen(Log10, DoubleType)
- }
-
- test("log1p") {
- testUnary(Log1p, math.log1p, (0 to 20).map(_ * 0.1))
- testUnary(Log1p, math.log1p, (-10 to -1).map(_ * 1.0), expectNull = true)
- checkConsistencyBetweenInterpretedAndCodegen(Log1p, DoubleType)
- }
-
- test("bin") {
- testUnary(Bin, java.lang.Long.toBinaryString, (-20 to 20).map(_.toLong), evalType = LongType)
-
- val row = create_row(null, 12L, 123L, 1234L, -123L)
- val l1 = 'a.long.at(0)
- val l2 = 'a.long.at(1)
- val l3 = 'a.long.at(2)
- val l4 = 'a.long.at(3)
- val l5 = 'a.long.at(4)
-
- checkEvaluation(Bin(l1), null, row)
- checkEvaluation(Bin(l2), java.lang.Long.toBinaryString(12), row)
- checkEvaluation(Bin(l3), java.lang.Long.toBinaryString(123), row)
- checkEvaluation(Bin(l4), java.lang.Long.toBinaryString(1234), row)
- checkEvaluation(Bin(l5), java.lang.Long.toBinaryString(-123), row)
-
- checkEvaluation(Bin(positiveLongLit), java.lang.Long.toBinaryString(positiveLong))
- checkEvaluation(Bin(negativeLongLit), java.lang.Long.toBinaryString(negativeLong))
-
- checkConsistencyBetweenInterpretedAndCodegen(Bin, LongType)
- }
-
- test("log2") {
- def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2)
- testUnary(Log2, f, (1 to 20).map(_ * 0.1))
- testUnary(Log2, f, (-5 to 0).map(_ * 1.0), expectNull = true)
- checkConsistencyBetweenInterpretedAndCodegen(Log2, DoubleType)
- }
-
- test("sqrt") {
- testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1))
- testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNaN = true)
-
- checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null))
- checkNaN(Sqrt(Literal(-1.0)), EmptyRow)
- checkNaN(Sqrt(Literal(-1.5)), EmptyRow)
- checkConsistencyBetweenInterpretedAndCodegen(Sqrt, DoubleType)
- }
-
- test("pow") {
- testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0)))
- testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true)
- checkConsistencyBetweenInterpretedAndCodegen(Pow, DoubleType, DoubleType)
- }
-
- test("shift left") {
- checkEvaluation(ShiftLeft(Literal.create(null, IntegerType), Literal(1)), null)
- checkEvaluation(ShiftLeft(Literal(21), Literal.create(null, IntegerType)), null)
- checkEvaluation(
- ShiftLeft(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
- checkEvaluation(ShiftLeft(Literal(21), Literal(1)), 42)
-
- checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong)
- checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong)
-
- checkEvaluation(ShiftLeft(positiveIntLit, positiveIntLit), positiveInt << positiveInt)
- checkEvaluation(ShiftLeft(positiveIntLit, negativeIntLit), positiveInt << negativeInt)
- checkEvaluation(ShiftLeft(negativeIntLit, positiveIntLit), negativeInt << positiveInt)
- checkEvaluation(ShiftLeft(negativeIntLit, negativeIntLit), negativeInt << negativeInt)
- checkEvaluation(ShiftLeft(positiveLongLit, positiveIntLit), positiveLong << positiveInt)
- checkEvaluation(ShiftLeft(positiveLongLit, negativeIntLit), positiveLong << negativeInt)
- checkEvaluation(ShiftLeft(negativeLongLit, positiveIntLit), negativeLong << positiveInt)
- checkEvaluation(ShiftLeft(negativeLongLit, negativeIntLit), negativeLong << negativeInt)
-
- checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, IntegerType, IntegerType)
- checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, LongType, IntegerType)
- }
-
- test("shift right") {
- checkEvaluation(ShiftRight(Literal.create(null, IntegerType), Literal(1)), null)
- checkEvaluation(ShiftRight(Literal(42), Literal.create(null, IntegerType)), null)
- checkEvaluation(
- ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
- checkEvaluation(ShiftRight(Literal(42), Literal(1)), 21)
-
- checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong)
- checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
-
- checkEvaluation(ShiftRight(positiveIntLit, positiveIntLit), positiveInt >> positiveInt)
- checkEvaluation(ShiftRight(positiveIntLit, negativeIntLit), positiveInt >> negativeInt)
- checkEvaluation(ShiftRight(negativeIntLit, positiveIntLit), negativeInt >> positiveInt)
- checkEvaluation(ShiftRight(negativeIntLit, negativeIntLit), negativeInt >> negativeInt)
- checkEvaluation(ShiftRight(positiveLongLit, positiveIntLit), positiveLong >> positiveInt)
- checkEvaluation(ShiftRight(positiveLongLit, negativeIntLit), positiveLong >> negativeInt)
- checkEvaluation(ShiftRight(negativeLongLit, positiveIntLit), negativeLong >> positiveInt)
- checkEvaluation(ShiftRight(negativeLongLit, negativeIntLit), negativeLong >> negativeInt)
-
- checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, IntegerType, IntegerType)
- checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, LongType, IntegerType)
- }
-
- test("shift right unsigned") {
- checkEvaluation(ShiftRightUnsigned(Literal.create(null, IntegerType), Literal(1)), null)
- checkEvaluation(ShiftRightUnsigned(Literal(42), Literal.create(null, IntegerType)), null)
- checkEvaluation(
- ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
- checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21)
-
- checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong)
- checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L)
-
- checkEvaluation(ShiftRightUnsigned(positiveIntLit, positiveIntLit),
- positiveInt >>> positiveInt)
- checkEvaluation(ShiftRightUnsigned(positiveIntLit, negativeIntLit),
- positiveInt >>> negativeInt)
- checkEvaluation(ShiftRightUnsigned(negativeIntLit, positiveIntLit),
- negativeInt >>> positiveInt)
- checkEvaluation(ShiftRightUnsigned(negativeIntLit, negativeIntLit),
- negativeInt >>> negativeInt)
- checkEvaluation(ShiftRightUnsigned(positiveLongLit, positiveIntLit),
- positiveLong >>> positiveInt)
- checkEvaluation(ShiftRightUnsigned(positiveLongLit, negativeIntLit),
- positiveLong >>> negativeInt)
- checkEvaluation(ShiftRightUnsigned(negativeLongLit, positiveIntLit),
- negativeLong >>> positiveInt)
- checkEvaluation(ShiftRightUnsigned(negativeLongLit, negativeIntLit),
- negativeLong >>> negativeInt)
-
- checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, IntegerType, IntegerType)
- checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, LongType, IntegerType)
- }
-
- test("hex") {
- checkEvaluation(Hex(Literal.create(null, LongType)), null)
- checkEvaluation(Hex(Literal(28L)), "1C")
- checkEvaluation(Hex(Literal(-28L)), "FFFFFFFFFFFFFFE4")
- checkEvaluation(Hex(Literal(100800200404L)), "177828FED4")
- checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C")
- checkEvaluation(Hex(Literal.create(null, BinaryType)), null)
- checkEvaluation(Hex(Literal("helloHex".getBytes(StandardCharsets.UTF_8))), "68656C6C6F486578")
- // scalastyle:off
- // Turn off scala style for non-ascii chars
- checkEvaluation(Hex(Literal("\u4e09\u91cd\u7684".getBytes(StandardCharsets.UTF_8))), "E4B889E9878DE79A84")
- // scalastyle:on
- Seq(LongType, BinaryType, StringType).foreach { dt =>
- checkConsistencyBetweenInterpretedAndCodegen(Hex.apply _, dt)
- }
- }
-
- test("unhex") {
- checkEvaluation(Unhex(Literal.create(null, StringType)), null)
- checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes(StandardCharsets.UTF_8))
- checkEvaluation(Unhex(Literal("")), new Array[Byte](0))
- checkEvaluation(Unhex(Literal("F")), Array[Byte](15))
- checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1))
- checkEvaluation(Unhex(Literal("GG")), null)
- // scalastyle:off
- // Turn off scala style for non-ascii chars
- checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "\u4e09\u91cd\u7684".getBytes(StandardCharsets.UTF_8))
- checkEvaluation(Unhex(Literal("\u4e09\u91cd\u7684")), null)
- // scalastyle:on
- checkConsistencyBetweenInterpretedAndCodegen(Unhex, StringType)
- }
-
- test("hypot") {
- testBinary(Hypot, math.hypot)
- checkConsistencyBetweenInterpretedAndCodegen(Hypot, DoubleType, DoubleType)
- }
-
- test("atan2") {
- testBinary(Atan2, math.atan2)
- checkConsistencyBetweenInterpretedAndCodegen(Atan2, DoubleType, DoubleType)
- }
-
- test("binary log") {
- val f = (c1: Double, c2: Double) => math.log(c2) / math.log(c1)
- val domain = (1 to 20).map(v => (v * 0.1, v * 0.2))
-
- domain.foreach { case (v1, v2) =>
- checkEvaluation(Logarithm(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow)
- checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow)
- checkEvaluation(new Logarithm(Literal(v1)), f(math.E, v1 + 0.0), EmptyRow)
- }
-
- // null input should yield null output
- checkEvaluation(
- Logarithm(Literal.create(null, DoubleType), Literal(1.0)),
- null,
- create_row(null))
- checkEvaluation(
- Logarithm(Literal(1.0), Literal.create(null, DoubleType)),
- null,
- create_row(null))
-
- // negative input should yield null output
- checkEvaluation(
- Logarithm(Literal(-1.0), Literal(1.0)),
- null,
- create_row(null))
- checkEvaluation(
- Logarithm(Literal(1.0), Literal(-1.0)),
- null,
- create_row(null))
- checkConsistencyBetweenInterpretedAndCodegen(Logarithm, DoubleType, DoubleType)
- }
-
- test("round/bround") {
- val scales = -6 to 6
- val doublePi: Double = math.Pi
- val shortPi: Short = 31415
- val intPi: Int = 314159265
- val longPi: Long = 31415926535897932L
- val bdPi: BigDecimal = BigDecimal(31415927L, 7)
-
- val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142,
- 3.1416, 3.14159, 3.141593)
-
- val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++
- Seq.fill[Short](7)(31415)
-
- val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300,
- 314159270) ++ Seq.fill(7)(314159265)
-
- val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L,
- 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++
- Seq.fill(7)(31415926535897932L)
-
- val intResultsB: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300,
- 314159260) ++ Seq.fill(7)(314159265)
-
- scales.zipWithIndex.foreach { case (scale, i) =>
- checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow)
- checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow)
- checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow)
- checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
- checkEvaluation(BRound(doublePi, scale), doubleResults(i), EmptyRow)
- checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow)
- checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow)
- checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow)
- }
-
- val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
- BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
- BigDecimal(3.141593), BigDecimal(3.1415927))
- // round_scale > current_scale would result in precision increase
- // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
- (0 to 7).foreach { i =>
- checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow)
- checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow)
- }
- (8 to 10).foreach { scale =>
- checkEvaluation(Round(bdPi, scale), null, EmptyRow)
- checkEvaluation(BRound(bdPi, scale), null, EmptyRow)
- }
-
- DataTypeTestUtils.numericTypes.foreach { dataType =>
- checkEvaluation(Round(Literal.create(null, dataType), Literal(2)), null)
- checkEvaluation(Round(Literal.create(null, dataType),
- Literal.create(null, IntegerType)), null)
- checkEvaluation(BRound(Literal.create(null, dataType), Literal(2)), null)
- checkEvaluation(BRound(Literal.create(null, dataType),
- Literal.create(null, IntegerType)), null)
- }
-
- checkEvaluation(Round(2.5, 0), 3.0)
- checkEvaluation(Round(3.5, 0), 4.0)
- checkEvaluation(Round(-2.5, 0), -3.0)
- checkEvaluation(Round(-3.5, 0), -4.0)
- checkEvaluation(Round(-0.35, 1), -0.4)
- checkEvaluation(Round(-35, -1), -40)
- checkEvaluation(BRound(2.5, 0), 2.0)
- checkEvaluation(BRound(3.5, 0), 4.0)
- checkEvaluation(BRound(-2.5, 0), -2.0)
- checkEvaluation(BRound(-3.5, 0), -4.0)
- checkEvaluation(BRound(-0.35, 1), -0.4)
- checkEvaluation(BRound(-35, -1), -40)
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/9db06c44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
new file mode 100644
index 0000000..a26d070
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+
+class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+ test("assert_true") {
+ intercept[RuntimeException] {
+ checkEvaluation(AssertTrue(Literal.create(false, BooleanType)), null)
+ }
+ intercept[RuntimeException] {
+ checkEvaluation(AssertTrue(Cast(Literal(0), BooleanType)), null)
+ }
+ intercept[RuntimeException] {
+ checkEvaluation(AssertTrue(Literal.create(null, NullType)), null)
+ }
+ intercept[RuntimeException] {
+ checkEvaluation(AssertTrue(Literal.create(null, BooleanType)), null)
+ }
+ checkEvaluation(AssertTrue(Literal.create(true, BooleanType)), null)
+ checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null)
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/9db06c44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
deleted file mode 100644
index ed82efe..0000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types._
-
-class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
-
- test("assert_true") {
- intercept[RuntimeException] {
- checkEvaluation(AssertTrue(Literal.create(false, BooleanType)), null)
- }
- intercept[RuntimeException] {
- checkEvaluation(AssertTrue(Cast(Literal(0), BooleanType)), null)
- }
- intercept[RuntimeException] {
- checkEvaluation(AssertTrue(Literal.create(null, NullType)), null)
- }
- intercept[RuntimeException] {
- checkEvaluation(AssertTrue(Literal.create(null, BooleanType)), null)
- }
- checkEvaluation(AssertTrue(Literal.create(true, BooleanType)), null)
- checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null)
- }
-
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/9db06c44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
new file mode 100644
index 0000000..5064a1f
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
+import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.types._
+
+class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+ def testAllTypes(testFunc: (Any, DataType) => Unit): Unit = {
+ testFunc(false, BooleanType)
+ testFunc(1.toByte, ByteType)
+ testFunc(1.toShort, ShortType)
+ testFunc(1, IntegerType)
+ testFunc(1L, LongType)
+ testFunc(1.0F, FloatType)
+ testFunc(1.0, DoubleType)
+ testFunc(Decimal(1.5), DecimalType(2, 1))
+ testFunc(new java.sql.Date(10), DateType)
+ testFunc(new java.sql.Timestamp(10), TimestampType)
+ testFunc("abcd", StringType)
+ }
+
+ test("isnull and isnotnull") {
+ testAllTypes { (value: Any, tpe: DataType) =>
+ checkEvaluation(IsNull(Literal.create(value, tpe)), false)
+ checkEvaluation(IsNotNull(Literal.create(value, tpe)), true)
+ checkEvaluation(IsNull(Literal.create(null, tpe)), true)
+ checkEvaluation(IsNotNull(Literal.create(null, tpe)), false)
+ }
+ }
+
+ test("AssertNotNUll") {
+ val ex = intercept[RuntimeException] {
+ evaluate(AssertNotNull(Literal(null), Seq.empty[String]))
+ }.getMessage
+ assert(ex.contains("Null value appeared in non-nullable field"))
+ }
+
+ test("IsNaN") {
+ checkEvaluation(IsNaN(Literal(Double.NaN)), true)
+ checkEvaluation(IsNaN(Literal(Float.NaN)), true)
+ checkEvaluation(IsNaN(Literal(math.log(-3))), true)
+ checkEvaluation(IsNaN(Literal.create(null, DoubleType)), false)
+ checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false)
+ checkEvaluation(IsNaN(Literal(Float.MaxValue)), false)
+ checkEvaluation(IsNaN(Literal(5.5f)), false)
+ }
+
+ test("nanvl") {
+ checkEvaluation(NaNvl(Literal(5.0), Literal.create(null, DoubleType)), 5.0)
+ checkEvaluation(NaNvl(Literal.create(null, DoubleType), Literal(5.0)), null)
+ checkEvaluation(NaNvl(Literal.create(null, DoubleType), Literal(Double.NaN)), null)
+ checkEvaluation(NaNvl(Literal(Double.NaN), Literal(5.0)), 5.0)
+ checkEvaluation(NaNvl(Literal(Double.NaN), Literal.create(null, DoubleType)), null)
+ assert(NaNvl(Literal(Double.NaN), Literal(Double.NaN)).
+ eval(EmptyRow).asInstanceOf[Double].isNaN)
+ }
+
+ test("coalesce") {
+ testAllTypes { (value: Any, tpe: DataType) =>
+ val lit = Literal.create(value, tpe)
+ val nullLit = Literal.create(null, tpe)
+ checkEvaluation(Coalesce(Seq(nullLit)), null)
+ checkEvaluation(Coalesce(Seq(lit)), value)
+ checkEvaluation(Coalesce(Seq(nullLit, lit)), value)
+ checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value)
+ checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value)
+ }
+ }
+
+ test("SPARK-16602 Nvl should support numeric-string cases") {
+ def analyze(expr: Expression): Expression = {
+ val relation = LocalRelation()
+ SimpleAnalyzer.execute(Project(Seq(Alias(expr, "c")()), relation)).expressions.head
+ }
+
+ val intLit = Literal.create(1, IntegerType)
+ val doubleLit = Literal.create(2.2, DoubleType)
+ val stringLit = Literal.create("c", StringType)
+ val nullLit = Literal.create(null, NullType)
+
+ assert(analyze(new Nvl(intLit, doubleLit)).dataType == DoubleType)
+ assert(analyze(new Nvl(intLit, stringLit)).dataType == StringType)
+ assert(analyze(new Nvl(stringLit, doubleLit)).dataType == StringType)
+
+ assert(analyze(new Nvl(nullLit, intLit)).dataType == IntegerType)
+ assert(analyze(new Nvl(doubleLit, nullLit)).dataType == DoubleType)
+ assert(analyze(new Nvl(nullLit, stringLit)).dataType == StringType)
+ }
+
+ test("AtLeastNNonNulls") {
+ val mix = Seq(Literal("x"),
+ Literal.create(null, StringType),
+ Literal.create(null, DoubleType),
+ Literal(Double.NaN),
+ Literal(5f))
+
+ val nanOnly = Seq(Literal("x"),
+ Literal(10.0),
+ Literal(Float.NaN),
+ Literal(math.log(-2)),
+ Literal(Double.MaxValue))
+
+ val nullOnly = Seq(Literal("x"),
+ Literal.create(null, DoubleType),
+ Literal.create(null, DecimalType.USER_DEFAULT),
+ Literal(Float.MaxValue),
+ Literal(false))
+
+ checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow)
+ checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow)
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org