You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by kiszk <gi...@git.apache.org> on 2018/08/02 17:02:08 UTC
[GitHub] spark pull request #21966: [SPARK-23915][SQL][followup] Add array_except fun...
Github user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/21966#discussion_r207302021
--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---
@@ -4077,81 +4078,84 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val arrayData = classOf[ArrayData].getName
val i = ctx.freshName("i")
- val pos = ctx.freshName("pos")
val value = ctx.freshName("value")
- val hsValue = ctx.freshName("hsValue")
val size = ctx.freshName("size")
- if (elementTypeSupportEquals) {
- val ptName = CodeGenerator.primitiveTypeName(elementType)
- val unsafeArray = ctx.freshName("unsafeArray")
- val (postFix, openHashElementType, hsJavaTypeName, genHsValue,
- getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) =
- elementType match {
- case ByteType | ShortType | IntegerType =>
- ("$mcI$sp", "Int", "int", s"(int) $value",
- s"get$ptName($i)", s"set$ptName($pos, $value)",
- CodeGenerator.javaType(elementType), ptName,
- s"""
- |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
- |${ev.value} = $unsafeArray;
- """.stripMargin)
- case LongType | FloatType | DoubleType =>
- val signature = elementType match {
- case LongType => "$mcJ$sp"
- case FloatType => "$mcF$sp"
- case DoubleType => "$mcD$sp"
- }
- (signature, CodeGenerator.boxedType(elementType),
- CodeGenerator.javaType(elementType), value,
- s"get$ptName($i)", s"set$ptName($pos, $value)",
- CodeGenerator.javaType(elementType), ptName,
- s"""
- |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
- |${ev.value} = $unsafeArray;
- """.stripMargin)
- case _ =>
- val genericArrayData = classOf[GenericArrayData].getName
- val et = ctx.addReferenceObj("elementType", elementType)
- ("", "Object", "Object", value,
- s"get($i, $et)", s"update($pos, $value)", "Object", "Ref",
- s"${ev.value} = new $genericArrayData(new Object[$size]);")
- }
+ val canUseSpecializedHashSet = elementType match {
+ case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true
+ case _ => false
+ }
+ if (canUseSpecializedHashSet) {
+ val jt = CodeGenerator.javaType(elementType)
+ val ptName = CodeGenerator.primitiveTypeName(jt)
+
+ def genGetValue(array: String): String =
+ CodeGenerator.getValue(array, elementType, i)
+
+ val (hsPostFix, hsTypeName) = elementType match {
+ // we cast byte/short to int when writing to the hash set.
+ case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int")
+ case LongType => ("$mcJ$sp", ptName)
+ case FloatType => ("$mcF$sp", ptName)
+ case DoubleType => ("$mcD$sp", ptName)
+ }
+
+ // we cast byte/short to int when writing to the hash set.
+ val hsValueCast = elementType match {
+ case ByteType | ShortType => "(int) "
+ case _ => ""
+ }
nullSafeCodeGen(ctx, ev, (array1, array2) => {
val notFoundNullElement = ctx.freshName("notFoundNullElement")
val nullElementIndex = ctx.freshName("nullElementIndex")
val builder = ctx.freshName("builder")
- val array = ctx.freshName("array")
val openHashSet = classOf[OpenHashSet[_]].getName
- val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
- val hs = ctx.freshName("hs")
+ val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
+ val hashSet = ctx.freshName("hashSet")
val genericArrayData = classOf[GenericArrayData].getName
val arrayBuilder = "scala.collection.mutable.ArrayBuilder"
- val arrayBuilderClass = s"$arrayBuilder$$of$primitiveTypeName"
- val arrayBuilderClassTag = if (primitiveTypeName != "Ref") {
- s"scala.reflect.ClassTag$$.MODULE$$.$primitiveTypeName()"
- } else {
- s"scala.reflect.ClassTag$$.MODULE$$.AnyRef()"
- }
+ val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
+ val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()"
- def withArray2NullCheck(body: String) =
- if (right.dataType.asInstanceOf[ArrayType].containsNull) {
- s"""
- |if ($array2.isNullAt($i)) {
- | $notFoundNullElement = false;
- |} else {
- | $body
- |}
+ def withArray2NullCheck(body: String): String =
+ if (left.dataType.asInstanceOf[ArrayType].containsNull) {
--- End diff --
Would it be better to use the following structure, so that the `else` clause is shared?
```
if (right.dataType.asInstanceOf[ArrayType].containsNull) {
  if (left.dataType.asInstanceOf[ArrayType].containsNull) {
    ...
  } else {
    ...
  }
} else {
  body
}
```
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org