Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2020/03/31 17:03:40 UTC

[GitHub] [spark] dbtsai commented on a change in pull request #27066: [SPARK-31317][SQL] Add withField method to Column class

dbtsai commented on a change in pull request #27066: [SPARK-31317][SQL] Add withField method to Column class
URL: https://github.com/apache/spark/pull/27066#discussion_r401072794
 
 

 ##########
 File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
 ##########
 @@ -514,3 +514,181 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
 
   override def prettyName: String = "str_to_map"
 }
+
+/**
+ * Adds/replaces fields in a struct.
+ * Returns null if struct is null.
+ * If multiple fields already exist with one of the given fieldNames, they will all be replaced.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(struct, name1, val1, name2, val2, ...) - Adds/replaces fields in struct by name.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(NAMED_STRUCT("a", 1), "b", 2, "c", 3);
+       {"a":1,"b":2,"c":3}
+  """)
+// scalastyle:on line.size.limit
+case class AddFields(children: Seq[Expression]) extends Expression {
+
+  private lazy val struct: Expression = children.head
+  private lazy val (nameExprs, valExprs) = children.drop(1).grouped(2).map {
+    case Seq(name, value) => (name, value)
+  }.toList.unzip
+  private lazy val fieldNames = nameExprs.map(_.eval().asInstanceOf[UTF8String].toString)
+  private lazy val pairs = fieldNames.zip(valExprs)
+
+  override def nullable: Boolean = struct.nullable
+
+  private lazy val ogStructType: StructType =
+    struct.dataType.asInstanceOf[StructType]
+
+  override lazy val dataType: StructType = {
+    val existingFields = ogStructType.fields.map { x => (x.name, x) }
+    val addOrReplaceFields = pairs.map { case (fieldName, field) =>
+      (fieldName, StructField(fieldName, field.dataType, field.nullable))
+    }
+    val newFields = loop(existingFields, addOrReplaceFields).map(_._2)
+    StructType(newFields)
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.size % 2 == 0) {
+      return TypeCheckResult.TypeCheckFailure(s"$prettyName expects an odd number of arguments.")
+    }
+
+    val typeName = struct.dataType.typeName
+    val expectedStructType = StructType(Nil).typeName
+    if (typeName != expectedStructType) {
+      return TypeCheckResult.TypeCheckFailure(
+        s"Only $expectedStructType is allowed to appear at the first position, got: $typeName.")
+    }
+
+    if (nameExprs.exists(e => e == null || !(e.foldable && e.dataType == StringType))) {
+      return TypeCheckResult.TypeCheckFailure(
+        s"Only non-null foldable ${StringType.catalogString} expressions are allowed to appear " +
+          s"at even positions.")
+    }
+
+    if (valExprs.contains(null)) {
+      return TypeCheckResult.TypeCheckFailure(
+        s"Only non-null expressions are allowed to appear at odd positions after first position.")
+    }
+
+    TypeCheckResult.TypeCheckSuccess
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val structValue = struct.eval(input)
+    if (structValue == null) {
+      null
+    } else {
+      val existingValues: Seq[(FieldName, Any)] =
+        ogStructType.fieldNames.zip(structValue.asInstanceOf[InternalRow].toSeq(ogStructType))
+      val addOrReplaceValues: Seq[(FieldName, Any)] =
+        pairs.map { case (fieldName, expression) => (fieldName, expression.eval(input)) }
+      val newValues = loop(existingValues, addOrReplaceValues).map(_._2)
+      InternalRow.fromSeq(newValues)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val structGen = struct.genCode(ctx)
+    val addOrReplaceFieldsGens = valExprs.map(_.genCode(ctx))
+    val resultCode: String = {
+      val structVar = structGen.value
+      type NullCheck = String
+      type NonNullValue = String
+      val existingFieldsCode: Seq[(FieldName, (NullCheck, NonNullValue))] =
+        ogStructType.fields.zipWithIndex.map {
+          case (structField, i) =>
+            val nullCheck = s"$structVar.isNullAt($i)"
+            val nonNullValue = CodeGenerator.getValue(structVar, structField.dataType, i.toString)
+            (structField.name, (nullCheck, nonNullValue))
+        }
+      val addOrReplaceFieldsCode: Seq[(FieldName, (NullCheck, NonNullValue))] =
+        fieldNames.zip(addOrReplaceFieldsGens).map {
+          case (fieldName, fieldExprCode) =>
+            val nullCheck = fieldExprCode.isNull.code
+            val nonNullValue = fieldExprCode.value.code
+            (fieldName, (nullCheck, nonNullValue))
+        }
+      val newFieldsCode = loop(existingFieldsCode, addOrReplaceFieldsCode)
+      val rowClass = classOf[GenericInternalRow].getName
+      val rowValuesVar = ctx.freshName("rowValues")
+      val populateRowValuesVar = newFieldsCode.zipWithIndex.map {
+        case ((_, (nullCheck, nonNullValue)), i) =>
+          s"""
+             |if ($nullCheck) {
+             | $rowValuesVar[$i] = null;
+             |} else {
+             | $rowValuesVar[$i] = $nonNullValue;
+             |}""".stripMargin
+      }.mkString("\n|")
+
+      s"""
+         |Object[] $rowValuesVar = new Object[${dataType.length}];
+         |
+         |${addOrReplaceFieldsGens.map(_.code).mkString("\n")}
+         |$populateRowValuesVar
+         |
+         |${ev.value} = new $rowClass($rowValuesVar);
+          """.stripMargin
+    }
+
+    if (nullable) {
+      val nullSafeEval =
+        structGen.code + ctx.nullSafeExec(struct.nullable, structGen.isNull) {
+          s"""
+             |${ev.isNull} = false; // resultCode could change nullability.
+             |$resultCode
+             |""".stripMargin
+        }
+
+      ev.copy(code =
+        code"""
+          boolean ${ev.isNull} = true;
+          ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+          $nullSafeEval
+          """)
+    } else {
+      ev.copy(code =
+        code"""
+          ${structGen.code}
+          ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+          $resultCode
+          """, isNull = FalseLiteral)
+    }
+  }
+
+  override def prettyName: String = "add_fields"
+
+  private type FieldName = String
+
+  /**
+   * Recursively loop through addOrReplaceFields, adding or replacing fields by FieldName.
+   */
+  @scala.annotation.tailrec
+  private def loop[V](existingFields: Seq[(String, V)],
+                      addOrReplaceFields: Seq[(String, V)]): Seq[(String, V)] = {
+    if (addOrReplaceFields.nonEmpty) {
+      val existingFieldNames = existingFields.map(_._1)
+      val newField@(newFieldName, _) = addOrReplaceFields.head
 
 Review comment:
  Per Spark's coding style, this should be `val newField @ (newFieldName, _) = addOrReplaceFields.head` (with spaces around `@`).
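
  For reference, a minimal self-contained sketch (not taken from the PR; the names are illustrative) of the add-or-replace-by-name behaviour that `loop` implements, written with the suggested pattern-binder spacing:

  ```scala
  // Illustrative sketch only -- not part of the PR.
  object AddOrReplaceSketch {
    @scala.annotation.tailrec
    def addOrReplace[V](
        existingFields: Seq[(String, V)],
        addOrReplaceFields: Seq[(String, V)]): Seq[(String, V)] = {
      if (addOrReplaceFields.isEmpty) {
        existingFields
      } else {
        // Bind both the whole pair and its name in one pattern (spaces around `@`).
        val newField @ (newFieldName, _) = addOrReplaceFields.head
        val merged = if (existingFields.exists(_._1 == newFieldName)) {
          // Replace every existing field that shares the name, as the class doc describes.
          existingFields.map {
            case (name, _) if name == newFieldName => newField
            case other => other
          }
        } else {
          existingFields :+ newField
        }
        addOrReplace(merged, addOrReplaceFields.tail)
      }
    }
  }

  // AddOrReplaceSketch.addOrReplace(Seq("a" -> 1), Seq("b" -> 2, "a" -> 10))
  // ==> List(("a", 10), ("b", 2))
  ```

  Binding `newField @ (newFieldName, _)` keeps both the whole pair and its field name in scope without a second tuple access.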

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org