You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to reviews@spark.apache.org by ueshin <gi...@git.apache.org> on 2018/08/06 17:03:22 UTC

[GitHub] spark pull request #21937: [SPARK-23914][SQL][follow-up] refactor ArrayUnion

GitHub user ueshin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21937#discussion_r207962845
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---
    @@ -3767,230 +3767,160 @@ object ArraySetLike {
       """,
       since = "2.4.0")
     case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
    -    with ComplexTypeMergingExpression {
    -  var hsInt: OpenHashSet[Int] = _
    -  var hsLong: OpenHashSet[Long] = _
    -
    -  def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
    -    val elem = array.getInt(idx)
    -    if (!hsInt.contains(elem)) {
    -      if (resultArray != null) {
    -        resultArray.setInt(pos, elem)
    -      }
    -      hsInt.add(elem)
    -      true
    -    } else {
    -      false
    -    }
    -  }
    -
    -  def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
    -    val elem = array.getLong(idx)
    -    if (!hsLong.contains(elem)) {
    -      if (resultArray != null) {
    -        resultArray.setLong(pos, elem)
    -      }
    -      hsLong.add(elem)
    -      true
    -    } else {
    -      false
    -    }
    -  }
    +  with ComplexTypeMergingExpression {
     
    -  def evalIntLongPrimitiveType(
    -      array1: ArrayData,
    -      array2: ArrayData,
    -      resultArray: ArrayData,
    -      isLongType: Boolean): Int = {
    -    // store elements into resultArray
    -    var nullElementSize = 0
    -    var pos = 0
    -    Seq(array1, array2).foreach { array =>
    -      var i = 0
    -      while (i < array.numElements()) {
    -        val size = if (!isLongType) hsInt.size else hsLong.size
    -        if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
    -          ArraySetLike.throwUnionLengthOverflowException(size)
    -        }
    -        if (array.isNullAt(i)) {
    -          if (nullElementSize == 0) {
    -            if (resultArray != null) {
    -              resultArray.setNullAt(pos)
    +  @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = {
    +    if (elementTypeSupportEquals) {
    +      (array1, array2) =>
    +        val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
    +        val hs = new OpenHashSet[Any]
    +        var foundNullElement = false
    +        Seq(array1, array2).foreach { array =>
    +          var i = 0
    +          while (i < array.numElements()) {
    +            if (array.isNullAt(i)) {
    +              if (!foundNullElement) {
    +                arrayBuffer += null
    +                foundNullElement = true
    +              }
    +            } else {
    +              val elem = array.get(i, elementType)
    +              if (!hs.contains(elem)) {
    +                if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
    +                  ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
    +                }
    +                arrayBuffer += elem
    +                hs.add(elem)
    +              }
                 }
    -            pos += 1
    -            nullElementSize = 1
    +            i += 1
               }
    -        } else {
    -          val assigned = if (!isLongType) {
    -            assignInt(array, i, resultArray, pos)
    +        }
    +        new GenericArrayData(arrayBuffer)
    +    } else {
    +      (array1, array2) =>
    +        val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
    +        var alreadyIncludeNull = false
    +        Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => {
    +          var found = false
    +          if (elem == null) {
    +            if (alreadyIncludeNull) {
    +              found = true
    +            } else {
    +              alreadyIncludeNull = true
    +            }
               } else {
    -            assignLong(array, i, resultArray, pos)
    +            // check elem is already stored in arrayBuffer or not?
    +            var j = 0
    +            while (!found && j < arrayBuffer.size) {
    +              val va = arrayBuffer(j)
    +              if (va != null && ordering.equiv(va, elem)) {
    +                found = true
    +              }
    +              j = j + 1
    +            }
               }
    -          if (assigned) {
    -            pos += 1
    +          if (!found) {
    +            if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
    +              ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length)
    +            }
    +            arrayBuffer += elem
               }
    -        }
    -        i += 1
    -      }
    +        }))
    +        new GenericArrayData(arrayBuffer)
         }
    -    pos
       }
     
       override def nullSafeEval(input1: Any, input2: Any): Any = {
         val array1 = input1.asInstanceOf[ArrayData]
         val array2 = input2.asInstanceOf[ArrayData]
     
    -    if (elementTypeSupportEquals) {
    -      elementType match {
    -        case IntegerType =>
    -          // avoid boxing of primitive int array elements
    -          // calculate result array size
    -          hsInt = new OpenHashSet[Int]
    -          val elements = evalIntLongPrimitiveType(array1, array2, null, false)
    -          hsInt = new OpenHashSet[Int]
    -          val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
    -            IntegerType.defaultSize, elements)) {
    -            new GenericArrayData(new Array[Any](elements))
    -          } else {
    -            UnsafeArrayData.forPrimitiveArray(
    -              Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize)
    -          }
    -          evalIntLongPrimitiveType(array1, array2, resultArray, false)
    -          resultArray
    -        case LongType =>
    -          // avoid boxing of primitive long array elements
    -          // calculate result array size
    -          hsLong = new OpenHashSet[Long]
    -          val elements = evalIntLongPrimitiveType(array1, array2, null, true)
    -          hsLong = new OpenHashSet[Long]
    -          val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
    -            LongType.defaultSize, elements)) {
    -            new GenericArrayData(new Array[Any](elements))
    -          } else {
    -            UnsafeArrayData.forPrimitiveArray(
    -              Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize)
    -          }
    -          evalIntLongPrimitiveType(array1, array2, resultArray, true)
    -          resultArray
    -        case _ =>
    -          val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
    -          val hs = new OpenHashSet[Any]
    -          var foundNullElement = false
    -          Seq(array1, array2).foreach { array =>
    -            var i = 0
    -            while (i < array.numElements()) {
    -              if (array.isNullAt(i)) {
    -                if (!foundNullElement) {
    -                  arrayBuffer += null
    -                  foundNullElement = true
    -                }
    -              } else {
    -                val elem = array.get(i, elementType)
    -                if (!hs.contains(elem)) {
    -                  if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
    -                    ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
    -                  }
    -                  arrayBuffer += elem
    -                  hs.add(elem)
    -                }
    -              }
    -              i += 1
    -            }
    -          }
    -          new GenericArrayData(arrayBuffer)
    -      }
    -    } else {
    -      ArrayUnion.unionOrdering(array1, array2, elementType, ordering)
    -    }
    +    evalUnion(array1, array2)
       }
     
       override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    +    val arrayData = classOf[ArrayData].getName
         val i = ctx.freshName("i")
    -    val pos = ctx.freshName("pos")
         val value = ctx.freshName("value")
         val size = ctx.freshName("size")
    -    val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) =
    -      if (elementTypeSupportEquals) {
    -        elementType match {
    -          case ByteType | ShortType | IntegerType | LongType =>
    -            val ptName = CodeGenerator.primitiveTypeName(elementType)
    -            val unsafeArray = ctx.freshName("unsafeArray")
    -            (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp",
    -              if (elementType == LongType) "Long" else "Int",
    -              s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType),
    -              if (elementType == LongType) "(long)" else "(int)",
    -              s"""
    -                 |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
    -                 |${ev.value} = $unsafeArray;
    -               """.stripMargin)
    -          case _ =>
    -            val genericArrayData = classOf[GenericArrayData].getName
    -            val et = ctx.addReferenceObj("elementType", elementType)
    -            ("", "Object",
    -              s"get($i, $et)", s"update($pos, $value)", "Object", "",
    -              s"${ev.value} = new $genericArrayData(new Object[$size]);")
    -        }
    -      } else {
    -        ("", "", "", "", "", "", "")
    -      }
    +    if (canUseSpecializedHashSet) {
    +      val jt = CodeGenerator.javaType(elementType)
    +      val ptName = CodeGenerator.primitiveTypeName(jt)
     
    -    nullSafeCodeGen(ctx, ev, (array1, array2) => {
    -      if (openHashElementType != "") {
    -        // Here, we ensure elementTypeSupportEquals is true
    +      nullSafeCodeGen(ctx, ev, (array1, array2) => {
             val foundNullElement = ctx.freshName("foundNullElement")
    -        val openHashSet = classOf[OpenHashSet[_]].getName
    -        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
    -        val hs = ctx.freshName("hs")
    -        val arrayData = classOf[ArrayData].getName
    -        val arrays = ctx.freshName("arrays")
    +        val nullElementIndex = ctx.freshName("nullElementIndex")
    +        val builder = ctx.freshName("builder")
             val array = ctx.freshName("array")
    +        val arrays = ctx.freshName("arrays")
             val arrayDataIdx = ctx.freshName("arrayDataIdx")
    +        val openHashSet = classOf[OpenHashSet[_]].getName
    +        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
    +        val hashSet = ctx.freshName("hashSet")
    +        val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
    +        val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
    --- End diff --
    
    Is this no longer needed?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org