You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2020/11/21 03:22:39 UTC

[GitHub] [spark] AngersZhuuuu commented on a change in pull request #30243: [SPARK-33335][SQL] Support `array_contains_array` func

AngersZhuuuu commented on a change in pull request #30243:
URL: https://github.com/apache/spark/pull/30243#discussion_r528062328



##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
##########
@@ -3957,3 +3957,194 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
 
   override def prettyName: String = "array_except"
 }
+
+/**
+ * Checks if the array (left) has the array (right)
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(array1, array2) - Returns true if the array1 contains the array2.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), array(2));
+       true
+  """,
+  group = "array_funcs",
+  since = "3.1.0")
+case class ArrayContainsArray(left: Expression, right: Expression)
+  extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with NullIntolerant {
+
+  override def dataType: DataType = BooleanType
+
+  override def et: DataType = elementType
+
+  override def dt: DataType = dataType
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val typeCheckResult = super.checkInputDataTypes()
+    if (typeCheckResult.isSuccess) {
+      TypeUtils.checkForOrderingExpr(et, s"function $prettyName")
+    } else {
+      typeCheckResult
+    }
+  }
+
+  @transient lazy val evalContains: (ArrayData, ArrayData) => Boolean = {
+    if (TypeUtils.typeWithProperEquals(elementType)) {
+      (array1, array2) =>
+        if (array2.numElements() == 0) {
+          true
+        } else if (array1.numElements() == 0) {
+          false
+        } else {
+          val hs = new OpenHashSet[Any]
+          var result = true
+          var foundNullElement = false
+          var i = 0
+          while (i < array1.numElements()) {
+            if (array1.isNullAt(i) && !foundNullElement) {
+              foundNullElement = true
+            } else {
+              val elem = array1.get(i, elementType)
+              hs.add(elem)
+            }
+            i += 1
+          }
+          i = 0
+          while (i < array2.numElements() && result) {
+            if (array2.isNullAt(i)) {
+              if (!foundNullElement) {
+                result = false
+              }
+            } else {
+              val elem = array2.get(i, elementType)
+              if (!hs.contains(elem)) {
+                result = false
+              }
+            }
+            i += 1
+          }
+          result
+        }
+    } else {
+      (array1, array2) =>
+        if (array2.numElements() == 0) {
+          true
+        } else if (array1.numElements() == 0) {
+          false
+        } else {
+          var alreadySeenNull = false
+          var i = 0
+          var elementFound = true
+          while (elementFound && i < array2.numElements()) {
+            var found = false
+            val elem2 = array2.get(i, elementType)
+            if (array2.isNullAt(i)) {
+              if (!alreadySeenNull) {
+                var j = 0
+                while (!found && j < array1.numElements()) {
+                  found = array1.isNullAt(j)
+                  j += 1
+                }
+                // array1 is scanned only once for null element
+                alreadySeenNull = true
+              }
+            } else {
+              var j = 0
+              while (!found && j < array2.numElements()) {
+                if (!array1.isNullAt(j)) {
+                  val elem1 = array1.get(j, elementType)
+                  if (ordering.equiv(elem2, elem1)) {
+                    found = true
+                  }
+                }
+                j += 1
+              }
+            }
+            if (!found) {
+              elementFound = false
+            }
+            i += 1
+          }
+          elementFound
+        }
+    }
+  }
+
+  override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val array1 = input1.asInstanceOf[ArrayData]
+    val array2 = input2.asInstanceOf[ArrayData]
+
+    evalContains(array1, array2)
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val i = ctx.freshName("i")
+    val value = ctx.freshName("value")
+    if (canUseSpecializedHashSet) {
+      val jt = CodeGenerator.javaType(elementType)
+
+      nullSafeCodeGen(ctx, ev, (array1, array2) => {
+        val result = ctx.freshName("result")
+        val foundNullElement = ctx.freshName("foundNullElement")
+        val openHashSet = classOf[OpenHashSet[_]].getName
+        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
+        val hashSet = ctx.freshName("hashSet")
+
+        def withArray1NullCheck(body: String): String =
+          s"""
+             |if ($array1.isNullAt($i) && !$foundNullElement) {

Review comment:
       > If there is more than one null element in `array1`, does this code work?
   
   Updated the UT and it works. Also changed it as below, like https://github.com/apache/spark/pull/30243#discussion_r528061382
   ```
                |if ($array1.isNullAt($i)) {
                |  if (!$foundNullElement) {
                |    $foundNullElement = true;
                |  }
   ```

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
##########
@@ -3957,3 +3957,194 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
 
   override def prettyName: String = "array_except"
 }
+
+/**
+ * Checks if the array (left) has the array (right)
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(array1, array2) - Returns true if the array1 contains the array2.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), array(2));
+       true
+  """,
+  group = "array_funcs",
+  since = "3.1.0")
+case class ArrayContainsArray(left: Expression, right: Expression)
+  extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with NullIntolerant {
+
+  override def dataType: DataType = BooleanType
+
+  override def et: DataType = elementType
+
+  override def dt: DataType = dataType
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val typeCheckResult = super.checkInputDataTypes()
+    if (typeCheckResult.isSuccess) {
+      TypeUtils.checkForOrderingExpr(et, s"function $prettyName")
+    } else {
+      typeCheckResult
+    }
+  }
+
+  @transient lazy val evalContains: (ArrayData, ArrayData) => Boolean = {
+    if (TypeUtils.typeWithProperEquals(elementType)) {
+      (array1, array2) =>
+        if (array2.numElements() == 0) {
+          true
+        } else if (array1.numElements() == 0) {
+          false
+        } else {
+          val hs = new OpenHashSet[Any]
+          var result = true
+          var foundNullElement = false
+          var i = 0
+          while (i < array1.numElements()) {
+            if (array1.isNullAt(i) && !foundNullElement) {
+              foundNullElement = true
+            } else {
+              val elem = array1.get(i, elementType)
+              hs.add(elem)
+            }
+            i += 1
+          }
+          i = 0
+          while (i < array2.numElements() && result) {
+            if (array2.isNullAt(i)) {
+              if (!foundNullElement) {
+                result = false
+              }
+            } else {
+              val elem = array2.get(i, elementType)
+              if (!hs.contains(elem)) {
+                result = false
+              }
+            }
+            i += 1
+          }
+          result
+        }
+    } else {
+      (array1, array2) =>
+        if (array2.numElements() == 0) {
+          true
+        } else if (array1.numElements() == 0) {
+          false
+        } else {
+          var alreadySeenNull = false
+          var i = 0
+          var elementFound = true
+          while (elementFound && i < array2.numElements()) {
+            var found = false
+            val elem2 = array2.get(i, elementType)
+            if (array2.isNullAt(i)) {
+              if (!alreadySeenNull) {
+                var j = 0
+                while (!found && j < array1.numElements()) {
+                  found = array1.isNullAt(j)
+                  j += 1
+                }
+                // array1 is scanned only once for null element
+                alreadySeenNull = true
+              }
+            } else {
+              var j = 0
+              while (!found && j < array2.numElements()) {
+                if (!array1.isNullAt(j)) {
+                  val elem1 = array1.get(j, elementType)
+                  if (ordering.equiv(elem2, elem1)) {
+                    found = true
+                  }
+                }
+                j += 1
+              }
+            }
+            if (!found) {
+              elementFound = false
+            }
+            i += 1
+          }
+          elementFound
+        }
+    }
+  }
+
+  override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val array1 = input1.asInstanceOf[ArrayData]
+    val array2 = input2.asInstanceOf[ArrayData]
+
+    evalContains(array1, array2)
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val i = ctx.freshName("i")
+    val value = ctx.freshName("value")
+    if (canUseSpecializedHashSet) {
+      val jt = CodeGenerator.javaType(elementType)
+
+      nullSafeCodeGen(ctx, ev, (array1, array2) => {
+        val result = ctx.freshName("result")
+        val foundNullElement = ctx.freshName("foundNullElement")
+        val openHashSet = classOf[OpenHashSet[_]].getName
+        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
+        val hashSet = ctx.freshName("hashSet")
+
+        def withArray1NullCheck(body: String): String =
+          s"""
+             |if ($array1.isNullAt($i) && !$foundNullElement) {
+             |  $foundNullElement = true;
+             |} else {
+             |  $body
+             |}
+               """.stripMargin
+
+        val writeArray1ToHashSet = withArray1NullCheck(
+          s"""
+             |$jt $value = ${genGetValue(array1, i)};
+             |$hashSet.add$hsPostFix($hsValueCast$value);
+           """.stripMargin)
+
+        val processArray2 =
+          s"""
+             |if ($array2.isNullAt($i)) {
+             |  if (!$foundNullElement) {
+             |    $result = false;
+             |  }
+             |} else {
+             |  $jt $value = ${genGetValue(array2, i)};
+             |  if (!$hashSet.contains($hsValueCast$value)) {
+             |   $result = false;

Review comment:
       > nit: indentation
   
   Done

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
##########
@@ -3957,3 +3957,194 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
 
   override def prettyName: String = "array_except"
 }
+
+/**
+ * Checks if the array (left) has the array (right)
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(array1, array2) - Returns true if the array1 contains the array2.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), array(2));
+       true
+  """,
+  group = "array_funcs",
+  since = "3.1.0")
+case class ArrayContainsArray(left: Expression, right: Expression)
+  extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with NullIntolerant {
+
+  override def dataType: DataType = BooleanType
+
+  override def et: DataType = elementType
+
+  override def dt: DataType = dataType
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val typeCheckResult = super.checkInputDataTypes()
+    if (typeCheckResult.isSuccess) {
+      TypeUtils.checkForOrderingExpr(et, s"function $prettyName")
+    } else {
+      typeCheckResult
+    }
+  }
+
+  @transient lazy val evalContains: (ArrayData, ArrayData) => Boolean = {
+    if (TypeUtils.typeWithProperEquals(elementType)) {
+      (array1, array2) =>
+        if (array2.numElements() == 0) {
+          true
+        } else if (array1.numElements() == 0) {
+          false
+        } else {
+          val hs = new OpenHashSet[Any]
+          var result = true
+          var foundNullElement = false
+          var i = 0
+          while (i < array1.numElements()) {
+            if (array1.isNullAt(i) && !foundNullElement) {
+              foundNullElement = true
+            } else {
+              val elem = array1.get(i, elementType)
+              hs.add(elem)
+            }
+            i += 1
+          }
+          i = 0
+          while (i < array2.numElements() && result) {
+            if (array2.isNullAt(i)) {
+              if (!foundNullElement) {
+                result = false
+              }
+            } else {
+              val elem = array2.get(i, elementType)
+              if (!hs.contains(elem)) {
+                result = false
+              }
+            }
+            i += 1
+          }
+          result
+        }
+    } else {
+      (array1, array2) =>
+        if (array2.numElements() == 0) {
+          true
+        } else if (array1.numElements() == 0) {
+          false
+        } else {
+          var alreadySeenNull = false
+          var i = 0
+          var elementFound = true
+          while (elementFound && i < array2.numElements()) {
+            var found = false
+            val elem2 = array2.get(i, elementType)
+            if (array2.isNullAt(i)) {
+              if (!alreadySeenNull) {
+                var j = 0
+                while (!found && j < array1.numElements()) {
+                  found = array1.isNullAt(j)
+                  j += 1
+                }
+                // array1 is scanned only once for null element
+                alreadySeenNull = true
+              }
+            } else {
+              var j = 0
+              while (!found && j < array2.numElements()) {
+                if (!array1.isNullAt(j)) {
+                  val elem1 = array1.get(j, elementType)
+                  if (ordering.equiv(elem2, elem1)) {
+                    found = true
+                  }
+                }
+                j += 1
+              }
+            }
+            if (!found) {
+              elementFound = false
+            }
+            i += 1
+          }
+          elementFound
+        }
+    }
+  }
+
+  override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val array1 = input1.asInstanceOf[ArrayData]
+    val array2 = input2.asInstanceOf[ArrayData]
+
+    evalContains(array1, array2)
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val i = ctx.freshName("i")
+    val value = ctx.freshName("value")
+    if (canUseSpecializedHashSet) {
+      val jt = CodeGenerator.javaType(elementType)
+
+      nullSafeCodeGen(ctx, ev, (array1, array2) => {
+        val result = ctx.freshName("result")
+        val foundNullElement = ctx.freshName("foundNullElement")
+        val openHashSet = classOf[OpenHashSet[_]].getName
+        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
+        val hashSet = ctx.freshName("hashSet")
+
+        def withArray1NullCheck(body: String): String =
+          s"""
+             |if ($array1.isNullAt($i) && !$foundNullElement) {
+             |  $foundNullElement = true;
+             |} else {
+             |  $body
+             |}
+               """.stripMargin

Review comment:
       > nit: indentation
   
   Done




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org