Posted to reviews@spark.apache.org by "LuciferYang (via GitHub)" <gi...@apache.org> on 2023/05/08 16:04:17 UTC

[GitHub] [spark] LuciferYang commented on a diff in pull request #40355: [SPARK-42604][CONNECT] Implement functions.typedlit

LuciferYang commented on code in PR #40355:
URL: https://github.com/apache/spark/pull/40355#discussion_r1162267083


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala:
##########
@@ -106,6 +106,37 @@ object functions {
       case _ => createLiteral(toLiteralProtoBuilder(literal))
     }
   }
+
+  /**
+   * Creates a [[Column]] of literal value.
+   *
+   * An alias of `typedlit`, and it is encouraged to use `typedlit` directly.
+   *
+   * @group normal_funcs
+   * @since 3.4.0

Review Comment:
   The `@since` tag should be changed to 3.5.0.



##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala:
##########
@@ -106,6 +106,37 @@ object functions {
       case _ => createLiteral(toLiteralProtoBuilder(literal))
     }
   }
+
+  /**
+   * Creates a [[Column]] of literal value.
+   *
+   * An alias of `typedlit`, and it is encouraged to use `typedlit` directly.
+   *
+   * @group normal_funcs
+   * @since 3.4.0
+   */
+  def typedLit[T: TypeTag](literal: T): Column = typedlit(literal)
+
+  /**
+   * Creates a [[Column]] of literal value.
+   *
+   * The passed in object is returned directly if it is already a [[Column]]. If the object is a
+   * Scala Symbol, it is converted into a [[Column]] also. Otherwise, a new [[Column]] is created
+   * to represent the literal value. The difference between this function and [[lit]] is that this
+   * function can handle parameterized scala types e.g.: List, Seq and Map.
+   *
+   * @note
+   *   `typedlit` will call expensive Scala reflection APIs. `lit` is preferred if parameterized
+   *   Scala types are not used.
+   *
+   * @group normal_funcs
+   * @since 3.4.0

Review Comment:
   The `@since` tag should be changed to 3.5.0.
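
   To illustrate the distinction the docstring above draws between `lit` and `typedlit`: `lit` covers plain values, while `typedlit` uses `TypeTag`-based reflection and can therefore also encode parameterized Scala types as literal columns. A minimal sketch, assuming the Connect client's `functions` object is imported (the values are only illustrative):

   ```scala
   import org.apache.spark.sql.functions._

   // `lit` is cheap and covers plain values.
   val simple = lit(42)

   // `typedlit` uses TypeTag-based reflection, so it can also encode
   // parameterized Scala types such as Seq and Map as literal columns.
   val seqLit = typedlit(Seq(1, 2, 3))
   val mapLit = typedlit(Map("a" -> 1, "b" -> 2))

   // `typedLit` is just an alias of `typedlit`.
   val same = typedLit(Seq(1, 2, 3))
   ```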



##########
connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala:
##########
@@ -211,46 +349,48 @@ object LiteralValueProtoConverter {
       builder.result()
     }
 
-    val elementType = array.getElementType
-    if (elementType.hasShort) {
-      makeArrayData(v => v.getShort.toShort)
-    } else if (elementType.hasInteger) {
-      makeArrayData(v => v.getInteger)
-    } else if (elementType.hasLong) {
-      makeArrayData(v => v.getLong)
-    } else if (elementType.hasDouble) {
-      makeArrayData(v => v.getDouble)
-    } else if (elementType.hasByte) {
-      makeArrayData(v => v.getByte.toByte)
-    } else if (elementType.hasFloat) {
-      makeArrayData(v => v.getFloat)
-    } else if (elementType.hasBoolean) {
-      makeArrayData(v => v.getBoolean)
-    } else if (elementType.hasString) {
-      makeArrayData(v => v.getString)
-    } else if (elementType.hasBinary) {
-      makeArrayData(v => v.getBinary.toByteArray)
-    } else if (elementType.hasDate) {
-      makeArrayData(v => DateTimeUtils.toJavaDate(v.getDate))
-    } else if (elementType.hasTimestamp) {
-      makeArrayData(v => DateTimeUtils.toJavaTimestamp(v.getTimestamp))
-    } else if (elementType.hasTimestampNtz) {
-      makeArrayData(v => DateTimeUtils.microsToLocalDateTime(v.getTimestampNtz))
-    } else if (elementType.hasDayTimeInterval) {
-      makeArrayData(v => IntervalUtils.microsToDuration(v.getDayTimeInterval))
-    } else if (elementType.hasYearMonthInterval) {
-      makeArrayData(v => IntervalUtils.monthsToPeriod(v.getYearMonthInterval))
-    } else if (elementType.hasDecimal) {
-      makeArrayData(v => Decimal(v.getDecimal.getValue))
-    } else if (elementType.hasCalendarInterval) {
-      makeArrayData(v => {
-        val interval = v.getCalendarInterval
-        new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds)
-      })
-    } else if (elementType.hasArray) {
-      makeArrayData(v => toCatalystArray(v.getArray))
-    } else {
-      throw new UnsupportedOperationException(s"Unsupported Literal Type: $elementType)")
+    makeArrayData(getConverter(array.getElementType))
+  }
+
+  def toCatalystMap(map: proto.Expression.Literal.Map): mutable.Map[_, _] = {
+    def makeMapData[K, V](
+        keyConverter: proto.Expression.Literal => K,
+        valueConverter: proto.Expression.Literal => V)(implicit
+        tagK: ClassTag[K],
+        tagV: ClassTag[V]): mutable.Map[K, V] = {
+      val builder = mutable.HashMap.empty[K, V]
+      val keys = map.getKeysList.asScala
+      val values = map.getValuesList.asScala
+      builder.sizeHint(keys.size)
+      keys.zip(values).foreach { case (key, value) =>
+        builder += ((keyConverter(key), valueConverter(value)))
+      }
+      builder
     }
+
+    makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType))
+  }
+
+  def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any = {
+    def toTuple[A <: Object](data: Seq[A]): Product = {
+      try {
+        val tupleClass = Utils.classForName(s"scala.Tuple${data.length}")
+        tupleClass.getConstructors.head.newInstance(data: _*).asInstanceOf[Product]
+      } catch {
+        case _: Exception =>
+          throw InvalidPlanInput(s"Unsupported Literal: ${data.mkString("Array(", ", ", ")")})")
+      }
+    }
+
+    val elements = struct.getElementsList.asScala
+    val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
+    val structData = elements
+      .zip(dataTypes)
+      .map { case (element, dataType) =>
+        getConverter(dataType)(element)
+      }
+      .asInstanceOf[Seq[Object]]

Review Comment:
   Otherwise, without the `.toSeq` conversion suggested below, the Scala 2.13 test will fail:
   
   ```
   [info] - function_typedLit *** FAILED *** (14 milliseconds)
   [info]   java.lang.ClassCastException: scala.collection.mutable.ArrayBuffer cannot be cast to scala.collection.immutable.Seq
   [info]   at org.apache.spark.sql.connect.common.LiteralValueProtoConverter$.toCatalystStruct(LiteralValueProtoConverter.scala:389)
   [info]   at org.apache.spark.sql.connect.planner.LiteralExpressionProtoConverter$.toCatalystExpression(LiteralExpressionProtoConverter.scala:114)
   [info]   at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformLiteral(SparkConnectPlanner.scala:1250)
   [info]   at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformExpression(SparkConnectPlanner.scala:1183)
   [info]   at org.apache.spark.sql.connect.planner.SparkConnectPlanner.$anonfun$transformProject$1(SparkConnectPlanner.scala:1165)
   [info]   at scala.collection.immutable.List.map(List.scala:250)
   [info]   at scala.collection.immutable.List.map(List.scala:79)
   [info]   at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformProject(SparkConnectPlanner.scala:1165)
   [info]   at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:96)
   [info]   at org.apache.spark.sql.connect.ProtoToParsedPlanTestSuite.$anonfun$createTest$2(ProtoToParsedPlanTestSuite.scala:167)
   
   ```
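
   The root cause is a Scala 2.12/2.13 difference: `getElementsList.asScala` returns a `mutable.Buffer`, and on 2.13 the default `Seq` alias is `scala.collection.immutable.Seq`, so the direct `asInstanceOf[Seq[Object]]` cast throws. A minimal standalone sketch of the behaviour (not the PR code):

   ```scala
   import scala.collection.mutable

   val buf: mutable.ArrayBuffer[AnyRef] = mutable.ArrayBuffer("a", "b")

   // On Scala 2.13 the next cast throws ClassCastException, because an
   // ArrayBuffer is not a scala.collection.immutable.Seq:
   // buf.asInstanceOf[Seq[AnyRef]]

   // Casting to the cross-version scala.collection.Seq and copying with
   // .toSeq works on both 2.12 and 2.13.
   val ok: Seq[AnyRef] = buf.asInstanceOf[scala.collection.Seq[AnyRef]].toSeq
   ```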



##########
connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala:
##########
@@ -211,46 +349,48 @@ object LiteralValueProtoConverter {
       builder.result()
     }
 
-    val elementType = array.getElementType
-    if (elementType.hasShort) {
-      makeArrayData(v => v.getShort.toShort)
-    } else if (elementType.hasInteger) {
-      makeArrayData(v => v.getInteger)
-    } else if (elementType.hasLong) {
-      makeArrayData(v => v.getLong)
-    } else if (elementType.hasDouble) {
-      makeArrayData(v => v.getDouble)
-    } else if (elementType.hasByte) {
-      makeArrayData(v => v.getByte.toByte)
-    } else if (elementType.hasFloat) {
-      makeArrayData(v => v.getFloat)
-    } else if (elementType.hasBoolean) {
-      makeArrayData(v => v.getBoolean)
-    } else if (elementType.hasString) {
-      makeArrayData(v => v.getString)
-    } else if (elementType.hasBinary) {
-      makeArrayData(v => v.getBinary.toByteArray)
-    } else if (elementType.hasDate) {
-      makeArrayData(v => DateTimeUtils.toJavaDate(v.getDate))
-    } else if (elementType.hasTimestamp) {
-      makeArrayData(v => DateTimeUtils.toJavaTimestamp(v.getTimestamp))
-    } else if (elementType.hasTimestampNtz) {
-      makeArrayData(v => DateTimeUtils.microsToLocalDateTime(v.getTimestampNtz))
-    } else if (elementType.hasDayTimeInterval) {
-      makeArrayData(v => IntervalUtils.microsToDuration(v.getDayTimeInterval))
-    } else if (elementType.hasYearMonthInterval) {
-      makeArrayData(v => IntervalUtils.monthsToPeriod(v.getYearMonthInterval))
-    } else if (elementType.hasDecimal) {
-      makeArrayData(v => Decimal(v.getDecimal.getValue))
-    } else if (elementType.hasCalendarInterval) {
-      makeArrayData(v => {
-        val interval = v.getCalendarInterval
-        new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds)
-      })
-    } else if (elementType.hasArray) {
-      makeArrayData(v => toCatalystArray(v.getArray))
-    } else {
-      throw new UnsupportedOperationException(s"Unsupported Literal Type: $elementType)")
+    makeArrayData(getConverter(array.getElementType))
+  }
+
+  def toCatalystMap(map: proto.Expression.Literal.Map): mutable.Map[_, _] = {
+    def makeMapData[K, V](
+        keyConverter: proto.Expression.Literal => K,
+        valueConverter: proto.Expression.Literal => V)(implicit
+        tagK: ClassTag[K],
+        tagV: ClassTag[V]): mutable.Map[K, V] = {
+      val builder = mutable.HashMap.empty[K, V]
+      val keys = map.getKeysList.asScala
+      val values = map.getValuesList.asScala
+      builder.sizeHint(keys.size)
+      keys.zip(values).foreach { case (key, value) =>
+        builder += ((keyConverter(key), valueConverter(value)))
+      }
+      builder
     }
+
+    makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType))
+  }
+
+  def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any = {
+    def toTuple[A <: Object](data: Seq[A]): Product = {
+      try {
+        val tupleClass = Utils.classForName(s"scala.Tuple${data.length}")
+        tupleClass.getConstructors.head.newInstance(data: _*).asInstanceOf[Product]
+      } catch {
+        case _: Exception =>
+          throw InvalidPlanInput(s"Unsupported Literal: ${data.mkString("Array(", ", ", ")")})")
+      }
+    }
+
+    val elements = struct.getElementsList.asScala
+    val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
+    val structData = elements
+      .zip(dataTypes)
+      .map { case (element, dataType) =>
+        getConverter(dataType)(element)
+      }
+      .asInstanceOf[Seq[Object]]

Review Comment:
   ```suggestion
         .asInstanceOf[scala.collection.Seq[Object]].toSeq
   ```
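
   The wrapper returned by `asScala` is a `scala.collection.Seq` on both Scala versions, so this cast succeeds everywhere, and the trailing `.toSeq` then produces the immutable `Seq` that 2.13 expects (on 2.12 it should be essentially a no-op).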



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

