You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by ueshin <gi...@git.apache.org> on 2018/08/06 03:32:19 UTC

[GitHub] spark pull request #21937: [WIP][SPARK-23914][SQL][follow-up] refactor Array...

Github user ueshin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21937#discussion_r207767113
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---
    @@ -3698,230 +3767,162 @@ object ArraySetLike {
       """,
       since = "2.4.0")
     case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
    -    with ComplexTypeMergingExpression {
    -  var hsInt: OpenHashSet[Int] = _
    -  var hsLong: OpenHashSet[Long] = _
    -
    -  def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
    -    val elem = array.getInt(idx)
    -    if (!hsInt.contains(elem)) {
    -      if (resultArray != null) {
    -        resultArray.setInt(pos, elem)
    -      }
    -      hsInt.add(elem)
    -      true
    -    } else {
    -      false
    -    }
    -  }
    -
    -  def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
    -    val elem = array.getLong(idx)
    -    if (!hsLong.contains(elem)) {
    -      if (resultArray != null) {
    -        resultArray.setLong(pos, elem)
    -      }
    -      hsLong.add(elem)
    -      true
    -    } else {
    -      false
    -    }
    -  }
    +  with ComplexTypeMergingExpression {
     
    -  def evalIntLongPrimitiveType(
    -      array1: ArrayData,
    -      array2: ArrayData,
    -      resultArray: ArrayData,
    -      isLongType: Boolean): Int = {
    -    // store elements into resultArray
    -    var nullElementSize = 0
    -    var pos = 0
    -    Seq(array1, array2).foreach { array =>
    -      var i = 0
    -      while (i < array.numElements()) {
    -        val size = if (!isLongType) hsInt.size else hsLong.size
    -        if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
    -          ArraySetLike.throwUnionLengthOverflowException(size)
    -        }
    -        if (array.isNullAt(i)) {
    -          if (nullElementSize == 0) {
    -            if (resultArray != null) {
    -              resultArray.setNullAt(pos)
    +  @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = {
    +    if (elementTypeSupportEquals) {
    +      (array1, array2) =>
    +        val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
    +        val hs = new OpenHashSet[Any]
    +        var foundNullElement = false
    +        Seq(array1, array2).foreach { array =>
    +          var i = 0
    +          while (i < array.numElements()) {
    +            if (array.isNullAt(i)) {
    +              if (!foundNullElement) {
    +                arrayBuffer += null
    +                foundNullElement = true
    +              }
    +            } else {
    +              val elem = array.get(i, elementType)
    +              if (!hs.contains(elem)) {
    +                if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
    +                  ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
    +                }
    +                arrayBuffer += elem
    +                hs.add(elem)
    +              }
                 }
    -            pos += 1
    -            nullElementSize = 1
    +            i += 1
               }
    -        } else {
    -          val assigned = if (!isLongType) {
    -            assignInt(array, i, resultArray, pos)
    +        }
    +        new GenericArrayData(arrayBuffer)
    +    } else {
    +      (array1, array2) =>
    +        val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
    +        var alreadyIncludeNull = false
    +        Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => {
    +          var found = false
    +          if (elem == null) {
    +            if (alreadyIncludeNull) {
    +              found = true
    +            } else {
    +              alreadyIncludeNull = true
    +            }
               } else {
    -            assignLong(array, i, resultArray, pos)
    +            // check elem is already stored in arrayBuffer or not?
    +            var j = 0
    +            while (!found && j < arrayBuffer.size) {
    +              val va = arrayBuffer(j)
    +              if (va != null && ordering.equiv(va, elem)) {
    +                found = true
    +              }
    +              j = j + 1
    +            }
               }
    -          if (assigned) {
    -            pos += 1
    +          if (!found) {
    +            if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
    +              ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length)
    +            }
    +            arrayBuffer += elem
               }
    -        }
    -        i += 1
    -      }
    +        }))
    +        new GenericArrayData(arrayBuffer)
         }
    -    pos
       }
     
       override def nullSafeEval(input1: Any, input2: Any): Any = {
         val array1 = input1.asInstanceOf[ArrayData]
         val array2 = input2.asInstanceOf[ArrayData]
     
    -    if (elementTypeSupportEquals) {
    -      elementType match {
    -        case IntegerType =>
    -          // avoid boxing of primitive int array elements
    -          // calculate result array size
    -          hsInt = new OpenHashSet[Int]
    -          val elements = evalIntLongPrimitiveType(array1, array2, null, false)
    -          hsInt = new OpenHashSet[Int]
    -          val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
    -            IntegerType.defaultSize, elements)) {
    -            new GenericArrayData(new Array[Any](elements))
    -          } else {
    -            UnsafeArrayData.forPrimitiveArray(
    -              Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize)
    -          }
    -          evalIntLongPrimitiveType(array1, array2, resultArray, false)
    -          resultArray
    -        case LongType =>
    -          // avoid boxing of primitive long array elements
    -          // calculate result array size
    -          hsLong = new OpenHashSet[Long]
    -          val elements = evalIntLongPrimitiveType(array1, array2, null, true)
    -          hsLong = new OpenHashSet[Long]
    -          val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
    -            LongType.defaultSize, elements)) {
    -            new GenericArrayData(new Array[Any](elements))
    -          } else {
    -            UnsafeArrayData.forPrimitiveArray(
    -              Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize)
    -          }
    -          evalIntLongPrimitiveType(array1, array2, resultArray, true)
    -          resultArray
    -        case _ =>
    -          val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
    -          val hs = new OpenHashSet[Any]
    -          var foundNullElement = false
    -          Seq(array1, array2).foreach { array =>
    -            var i = 0
    -            while (i < array.numElements()) {
    -              if (array.isNullAt(i)) {
    -                if (!foundNullElement) {
    -                  arrayBuffer += null
    -                  foundNullElement = true
    -                }
    -              } else {
    -                val elem = array.get(i, elementType)
    -                if (!hs.contains(elem)) {
    -                  if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
    -                    ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
    -                  }
    -                  arrayBuffer += elem
    -                  hs.add(elem)
    -                }
    -              }
    -              i += 1
    -            }
    -          }
    -          new GenericArrayData(arrayBuffer)
    -      }
    -    } else {
    -      ArrayUnion.unionOrdering(array1, array2, elementType, ordering)
    -    }
    +    evalUnion(array1, array2)
       }
     
       override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    +    val arrayData = classOf[ArrayData].getName
         val i = ctx.freshName("i")
    -    val pos = ctx.freshName("pos")
         val value = ctx.freshName("value")
         val size = ctx.freshName("size")
    -    val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) =
    -      if (elementTypeSupportEquals) {
    -        elementType match {
    -          case ByteType | ShortType | IntegerType | LongType =>
    -            val ptName = CodeGenerator.primitiveTypeName(elementType)
    -            val unsafeArray = ctx.freshName("unsafeArray")
    -            (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp",
    -              if (elementType == LongType) "Long" else "Int",
    -              s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType),
    -              if (elementType == LongType) "(long)" else "(int)",
    -              s"""
    -                 |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
    -                 |${ev.value} = $unsafeArray;
    -               """.stripMargin)
    -          case _ =>
    -            val genericArrayData = classOf[GenericArrayData].getName
    -            val et = ctx.addReferenceObj("elementType", elementType)
    -            ("", "Object",
    -              s"get($i, $et)", s"update($pos, $value)", "Object", "",
    -              s"${ev.value} = new $genericArrayData(new Object[$size]);")
    -        }
    -      } else {
    -        ("", "", "", "", "", "", "")
    -      }
    +    if (canUseSpecializedHashSet) {
    +      val jt = CodeGenerator.javaType(elementType)
    +      val ptName = CodeGenerator.primitiveTypeName(jt)
     
    -    nullSafeCodeGen(ctx, ev, (array1, array2) => {
    -      if (openHashElementType != "") {
    -        // Here, we ensure elementTypeSupportEquals is true
    +      nullSafeCodeGen(ctx, ev, (array1, array2) => {
             val foundNullElement = ctx.freshName("foundNullElement")
    -        val openHashSet = classOf[OpenHashSet[_]].getName
    -        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
    -        val hs = ctx.freshName("hs")
    -        val arrayData = classOf[ArrayData].getName
    -        val arrays = ctx.freshName("arrays")
    +        val nullElementIndex = ctx.freshName("nullElementIndex")
    +        val builder = ctx.freshName("builder")
             val array = ctx.freshName("array")
    +        val arrays = ctx.freshName("arrays")
             val arrayDataIdx = ctx.freshName("arrayDataIdx")
    +        val openHashSet = classOf[OpenHashSet[_]].getName
    +        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
    +        val hashSet = ctx.freshName("hashSet")
    +        val arrayBuilder = "scala.collection.mutable.ArrayBuilder"
    +        val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
    +        val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()"
    +
    +        def withArrayNullAssignment(body: String) =
    +          if (dataType.asInstanceOf[ArrayType].containsNull) {
    +            s"""
    +               |if ($array.isNullAt($i)) {
    +               |  if (!$foundNullElement) {
    +               |    $nullElementIndex = $size;
    +               |    $foundNullElement = true;
    +               |    $size++;
    +               |    $builder.$$plus$$eq($nullValueHolder);
    +               |  }
    +               |} else {
    +               |  $body
    +               |}
    +             """.stripMargin
    +          } else {
    +            body
    +          }
    +
    +        val processArray = withArrayNullAssignment(
    +          s"""
    +             |$jt $value = ${genGetValue(array, i)};
    +             |if (!$hashSet.contains($hsValueCast$value)) {
    +             |  if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
    +             |    break;
    +             |  }
    +             |  $hashSet.add$hsPostFix($hsValueCast$value);
    +             |  $builder.$$plus$$eq($value);
    +             |}
    +           """.stripMargin)
    +
    +        // Only need to track null element index when result array's element is nullable.
    +        val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
    +          s"""
    +             |boolean $foundNullElement = false;
    +             |int $nullElementIndex = -1;
    +           """.stripMargin
    +        } else {
    +          ""
    +        }
    +
             s"""
    -           |$openHashSet $hs = new $openHashSet$postFix($classTag);
    -           |boolean $foundNullElement = false;
    +           |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag);
    +           |$declareNullTrackVariables
    +           |int $size = 0;
    +           |$arrayBuilderClass $builder =
    +           |  ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag);
    --- End diff --
    
    nit: `new $arrayBuilderClass()` should work here, instead of casting the result of `$arrayBuilder.make(...)`?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org