You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2023/02/28 05:33:08 UTC
[spark] branch branch-3.4 updated: [SPARK-42610][CONNECT] Add encoders to SQLImplicits
This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 816774aa532 [SPARK-42610][CONNECT] Add encoders to SQLImplicits
816774aa532 is described below
commit 816774aa532ae4e017c937c86fdb784df200ee0e
Author: Herman van Hovell <he...@databricks.com>
AuthorDate: Mon Feb 27 21:32:49 2023 -0800
[SPARK-42610][CONNECT] Add encoders to SQLImplicits
### What changes were proposed in this pull request?
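Add implicit encoder resolution to the `SQLImplicits` class of the Spark Connect Scala client, so that `import spark.implicits._` supplies `Encoder`s for common Scala and Java types.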
### Why are the changes needed?
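API parity with the `SQLImplicits` of the regular (non-Connect) Spark SQL API, which already provides these implicit encoders.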
### Does this PR introduce _any_ user-facing change?
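Yes. Users of the Spark Connect Scala client can now resolve implicit encoders through `spark.implicits._` for the following types:
- primitives and their boxed counterparts;
- `String`;
- Java and Scala decimals;
- date, time and interval types;
- Java enums;
- `Product` types;
- `Seq`, `Map`, `Set` and `Array` collections.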
### How was this patch tested?
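Added a round-trip test (`test implicit encoder resolution`) to `SQLImplicitsTestSuite`. For each type, it serializes and deserializes a value with the implicitly resolved encoder and checks that the result equals the input.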
Closes #40205 from hvanhovell/SPARK-42610.
Authored-by: Herman van Hovell <he...@databricks.com>
Signed-off-by: Dongjoon Hyun <do...@apache.org>
(cherry picked from commit 968f280fd0d488372b0b09738ff9728b45499bef)
Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
.../scala/org/apache/spark/sql/SQLImplicits.scala | 240 ++++++++++++++++++++-
.../scala/org/apache/spark/sql/SparkSession.scala | 2 +-
.../apache/spark/sql/SQLImplicitsTestSuite.scala | 95 ++++++++
3 files changed, 334 insertions(+), 3 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index e63c9481da5..8f429541def 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -17,13 +17,20 @@
package org.apache.spark.sql
import scala.language.implicitConversions
+import scala.reflect.classTag
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
/**
- * A collection of implicit methods for converting names and Symbols into [[Column]]s.
+ * A collection of implicit methods for converting names and Symbols into [[Column]]s, and for
+ * converting common Scala objects into [[Dataset]]s.
*
* @since 3.4.0
*/
-abstract class SQLImplicits {
+abstract class SQLImplicits extends LowPrioritySQLImplicits {
/**
* Converts $"col name" into a [[Column]].
@@ -41,4 +48,233 @@ abstract class SQLImplicits {
* @since 3.4.0
*/
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
+
+ /** @since 3.4.0 */
+ implicit val newIntEncoder: Encoder[Int] = PrimitiveIntEncoder
+
+ /** @since 3.4.0 */
+ implicit val newLongEncoder: Encoder[Long] = PrimitiveLongEncoder
+
+ /** @since 3.4.0 */
+ implicit val newDoubleEncoder: Encoder[Double] = PrimitiveDoubleEncoder
+
+ /** @since 3.4.0 */
+ implicit val newFloatEncoder: Encoder[Float] = PrimitiveFloatEncoder
+
+ /** @since 3.4.0 */
+ implicit val newByteEncoder: Encoder[Byte] = PrimitiveByteEncoder
+
+ /** @since 3.4.0 */
+ implicit val newShortEncoder: Encoder[Short] = PrimitiveShortEncoder
+
+ /** @since 3.4.0 */
+ implicit val newBooleanEncoder: Encoder[Boolean] = PrimitiveBooleanEncoder
+
+ /** @since 3.4.0 */
+ implicit val newStringEncoder: Encoder[String] = StringEncoder
+
+ /** @since 3.4.0 */
+ implicit val newJavaDecimalEncoder: Encoder[java.math.BigDecimal] =
+ AgnosticEncoders.DEFAULT_JAVA_DECIMAL_ENCODER
+
+ /** @since 3.4.0 */
+ implicit val newScalaDecimalEncoder: Encoder[scala.math.BigDecimal] =
+ AgnosticEncoders.DEFAULT_SCALA_DECIMAL_ENCODER
+
+ /** @since 3.4.0 */
+ implicit val newDateEncoder: Encoder[java.sql.Date] = AgnosticEncoders.STRICT_DATE_ENCODER
+
+ /** @since 3.4.0 */
+ implicit val newLocalDateEncoder: Encoder[java.time.LocalDate] =
+ AgnosticEncoders.STRICT_LOCAL_DATE_ENCODER
+
+ /** @since 3.4.0 */
+ implicit val newLocalDateTimeEncoder: Encoder[java.time.LocalDateTime] =
+ AgnosticEncoders.LocalDateTimeEncoder
+
+ /** @since 3.4.0 */
+ implicit val newTimeStampEncoder: Encoder[java.sql.Timestamp] =
+ AgnosticEncoders.STRICT_TIMESTAMP_ENCODER
+
+ /** @since 3.4.0 */
+ implicit val newInstantEncoder: Encoder[java.time.Instant] =
+ AgnosticEncoders.STRICT_INSTANT_ENCODER
+
+ /** @since 3.4.0 */
+ implicit val newDurationEncoder: Encoder[java.time.Duration] = DayTimeIntervalEncoder
+
+ /** @since 3.4.0 */
+ implicit val newPeriodEncoder: Encoder[java.time.Period] = YearMonthIntervalEncoder
+
+ /** @since 3.4.0 */
+ implicit def newJavaEnumEncoder[A <: java.lang.Enum[_]: TypeTag]: Encoder[A] = {
+ ScalaReflection.encoderFor[A]
+ }
+
+ // Boxed primitives
+
+ /** @since 3.4.0 */
+ implicit val newBoxedIntEncoder: Encoder[java.lang.Integer] = BoxedIntEncoder
+
+ /** @since 3.4.0 */
+ implicit val newBoxedLongEncoder: Encoder[java.lang.Long] = BoxedLongEncoder
+
+ /** @since 3.4.0 */
+ implicit val newBoxedDoubleEncoder: Encoder[java.lang.Double] = BoxedDoubleEncoder
+
+ /** @since 3.4.0 */
+ implicit val newBoxedFloatEncoder: Encoder[java.lang.Float] = BoxedFloatEncoder
+
+ /** @since 3.4.0 */
+ implicit val newBoxedByteEncoder: Encoder[java.lang.Byte] = BoxedByteEncoder
+
+ /** @since 3.4.0 */
+ implicit val newBoxedShortEncoder: Encoder[java.lang.Short] = BoxedShortEncoder
+
+ /** @since 3.4.0 */
+ implicit val newBoxedBooleanEncoder: Encoder[java.lang.Boolean] = BoxedBooleanEncoder
+
+ // Seqs
+ private def newSeqEncoder[E](elementEncoder: AgnosticEncoder[E]): AgnosticEncoder[Seq[E]] = {
+ IterableEncoder(
+ classTag[Seq[E]],
+ elementEncoder,
+ elementEncoder.nullable,
+ elementEncoder.lenientSerialization)
+ }
+
+ /**
+ * @since 3.4.0
+ * @deprecated
+ * use [[newSequenceEncoder]]
+ */
+ val newIntSeqEncoder: Encoder[Seq[Int]] = newSeqEncoder(PrimitiveIntEncoder)
+
+ /**
+ * @since 3.4.0
+ * @deprecated
+ * use [[newSequenceEncoder]]
+ */
+ val newLongSeqEncoder: Encoder[Seq[Long]] = newSeqEncoder(PrimitiveLongEncoder)
+
+ /**
+ * @since 3.4.0
+ * @deprecated
+ * use [[newSequenceEncoder]]
+ */
+ val newDoubleSeqEncoder: Encoder[Seq[Double]] = newSeqEncoder(PrimitiveDoubleEncoder)
+
+ /**
+ * @since 3.4.0
+ * @deprecated
+ * use [[newSequenceEncoder]]
+ */
+ val newFloatSeqEncoder: Encoder[Seq[Float]] = newSeqEncoder(PrimitiveFloatEncoder)
+
+ /**
+ * @since 3.4.0
+ * @deprecated
+ * use [[newSequenceEncoder]]
+ */
+ val newByteSeqEncoder: Encoder[Seq[Byte]] = newSeqEncoder(PrimitiveByteEncoder)
+
+ /**
+ * @since 3.4.0
+ * @deprecated
+ * use [[newSequenceEncoder]]
+ */
+ val newShortSeqEncoder: Encoder[Seq[Short]] = newSeqEncoder(PrimitiveShortEncoder)
+
+ /**
+ * @since 3.4.0
+ * @deprecated
+ * use [[newSequenceEncoder]]
+ */
+ val newBooleanSeqEncoder: Encoder[Seq[Boolean]] = newSeqEncoder(PrimitiveBooleanEncoder)
+
+ /**
+ * @since 3.4.0
+ * @deprecated
+ * use [[newSequenceEncoder]]
+ */
+ val newStringSeqEncoder: Encoder[Seq[String]] = newSeqEncoder(StringEncoder)
+
+ /**
+ * @since 3.4.0
+ * @deprecated
+ * use [[newSequenceEncoder]]
+ */
+ def newProductSeqEncoder[A <: Product: TypeTag]: Encoder[Seq[A]] =
+ newSeqEncoder(ScalaReflection.encoderFor[A])
+
+ /** @since 3.4.0 */
+ implicit def newSequenceEncoder[T <: Seq[_]: TypeTag]: Encoder[T] =
+ ScalaReflection.encoderFor[T]
+
+ // Maps
+ /** @since 3.4.0 */
+ implicit def newMapEncoder[T <: Map[_, _]: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T]
+
+ /**
+ * Notice that we serialize `Set` to Catalyst array. The set property is only kept when
+ * manipulating the domain objects. The serialization format doesn't keep the set property. When
+ * we have a Catalyst array which contains duplicated elements and convert it to
+ * `Dataset[Set[T]]` by using the encoder, the elements will be de-duplicated.
+ *
+ * @since 3.4.0
+ */
+ implicit def newSetEncoder[T <: Set[_]: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T]
+
+ // Arrays
+ private def newArrayEncoder[E](
+ elementEncoder: AgnosticEncoder[E]): AgnosticEncoder[Array[E]] = {
+ ArrayEncoder(elementEncoder, elementEncoder.nullable)
+ }
+
+ /** @since 3.4.0 */
+ implicit val newIntArrayEncoder: Encoder[Array[Int]] = newArrayEncoder(PrimitiveIntEncoder)
+
+ /** @since 3.4.0 */
+ implicit val newLongArrayEncoder: Encoder[Array[Long]] = newArrayEncoder(PrimitiveLongEncoder)
+
+ /** @since 3.4.0 */
+ implicit val newDoubleArrayEncoder: Encoder[Array[Double]] =
+ newArrayEncoder(PrimitiveDoubleEncoder)
+
+ /** @since 3.4.0 */
+ implicit val newFloatArrayEncoder: Encoder[Array[Float]] = newArrayEncoder(
+ PrimitiveFloatEncoder)
+
+ /** @since 3.4.0 */
+ implicit val newByteArrayEncoder: Encoder[Array[Byte]] = BinaryEncoder
+
+ /** @since 3.4.0 */
+ implicit val newShortArrayEncoder: Encoder[Array[Short]] = newArrayEncoder(
+ PrimitiveShortEncoder)
+
+ /** @since 3.4.0 */
+ implicit val newBooleanArrayEncoder: Encoder[Array[Boolean]] =
+ newArrayEncoder(PrimitiveBooleanEncoder)
+
+ /** @since 3.4.0 */
+ implicit val newStringArrayEncoder: Encoder[Array[String]] = newArrayEncoder(StringEncoder)
+
+ /** @since 3.4.0 */
+ implicit def newProductArrayEncoder[A <: Product: TypeTag]: Encoder[Array[A]] = {
+ newArrayEncoder(ScalaReflection.encoderFor[A])
+ }
+}
+
+/**
+ * Lower priority implicit methods for converting Scala objects into [[Dataset]]s. Conflicting
+ * implicits are placed here to disambiguate resolution.
+ *
+ * Reasons for including specific implicits: newProductEncoder - to disambiguate for `List`s which
+ * are both `Seq` and `Product`
+ */
+trait LowPrioritySQLImplicits {
+
+ /** @since 3.4.0 */
+ implicit def newProductEncoder[T <: Product: TypeTag]: Encoder[T] =
+ ScalaReflection.encoderFor[T]
}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 3aed781855c..fa13af00f14 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -207,7 +207,7 @@ class SparkSession(
// Disable style checker so "implicits" object can start with lowercase i
/**
* (Scala-specific) Implicit methods available in Scala for converting common names and
- * [[Symbol]]s into [[Column]]s.
+ * [[Symbol]]s into [[Column]]s, and for converting common Scala objects into `DataFrame`s.
*
* {{{
* val sparkSession = SparkSession.builder.getOrCreate()
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
index 1f141d7c71a..3fcc135a22e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
@@ -16,15 +16,21 @@
*/
package org.apache.spark.sql
+import java.sql.{Date, Timestamp}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
import java.util.concurrent.atomic.AtomicLong
import io.grpc.inprocess.InProcessChannelBuilder
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder}
import org.apache.spark.sql.connect.client.SparkConnectClient
import org.apache.spark.sql.connect.client.util.ConnectFunSuite
+/**
+ * Test suite for SQL implicits.
+ */
class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll {
private var session: SparkSession = _
@@ -44,4 +50,93 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll {
assertEqual($"x", Column("x"))
assertEqual('y, Column("y"))
}
+
+ test("test implicit encoder resolution") {
+ val spark = session
+ import spark.implicits._
+ def testImplicit[T: Encoder](expected: T): Unit = {
+ val encoder = implicitly[Encoder[T]].asInstanceOf[AgnosticEncoder[T]]
+ val expressionEncoder = ExpressionEncoder(encoder).resolveAndBind()
+ val serializer = expressionEncoder.createSerializer()
+ val deserializer = expressionEncoder.createDeserializer()
+ val actual = deserializer(serializer(expected))
+ assert(actual === expected)
+ }
+
+ val booleans = Array(false, true, false, false)
+ testImplicit(booleans.head)
+ testImplicit(java.lang.Boolean.valueOf(booleans.head))
+ testImplicit(booleans)
+ testImplicit(booleans.toSeq)
+ testImplicit(booleans.toSeq)(newBooleanSeqEncoder)
+
+ val bytes = Array(76.toByte, 59.toByte, 121.toByte)
+ testImplicit(bytes.head)
+ testImplicit(java.lang.Byte.valueOf(bytes.head))
+ testImplicit(bytes)
+ testImplicit(bytes.toSeq)
+ testImplicit(bytes.toSeq)(newByteSeqEncoder)
+
+ val shorts = Array(21.toShort, (-213).toShort, 14876.toShort)
+ testImplicit(shorts.head)
+ testImplicit(java.lang.Short.valueOf(shorts.head))
+ testImplicit(shorts)
+ testImplicit(shorts.toSeq)
+ testImplicit(shorts.toSeq)(newShortSeqEncoder)
+
+ val ints = Array(4, 6, 5)
+ testImplicit(ints.head)
+ testImplicit(java.lang.Integer.valueOf(ints.head))
+ testImplicit(ints)
+ testImplicit(ints.toSeq)
+ testImplicit(ints.toSeq)(newIntSeqEncoder)
+
+ val longs = Array(System.nanoTime(), System.currentTimeMillis())
+ testImplicit(longs.head)
+ testImplicit(java.lang.Long.valueOf(longs.head))
+ testImplicit(longs)
+ testImplicit(longs.toSeq)
+ testImplicit(longs.toSeq)(newLongSeqEncoder)
+
+ val floats = Array(3f, 10.9f)
+ testImplicit(floats.head)
+ testImplicit(java.lang.Float.valueOf(floats.head))
+ testImplicit(floats)
+ testImplicit(floats.toSeq)
+ testImplicit(floats.toSeq)(newFloatSeqEncoder)
+
+ val doubles = Array(23.78d, -329.6d)
+ testImplicit(doubles.head)
+ testImplicit(java.lang.Double.valueOf(doubles.head))
+ testImplicit(doubles)
+ testImplicit(doubles.toSeq)
+ testImplicit(doubles.toSeq)(newDoubleSeqEncoder)
+
+ val strings = Array("foo", "baz", "bar")
+ testImplicit(strings.head)
+ testImplicit(strings)
+ testImplicit(strings.toSeq)
+ testImplicit(strings.toSeq)(newStringSeqEncoder)
+
+ val myTypes = Array(MyType(12L, Math.E, Math.PI), MyType(0, 0, 0))
+ testImplicit(myTypes.head)
+ testImplicit(myTypes)
+ testImplicit(myTypes.toSeq)
+ testImplicit(myTypes.toSeq)(newProductSeqEncoder[MyType])
+
+ // Others.
+ val decimal = java.math.BigDecimal.valueOf(3141527000000000000L, 18)
+ testImplicit(decimal)
+ testImplicit(BigDecimal(decimal))
+ testImplicit(Date.valueOf(LocalDate.now()))
+ testImplicit(LocalDate.now())
+ testImplicit(LocalDateTime.now())
+ testImplicit(Instant.now())
+ testImplicit(Timestamp.from(Instant.now()))
+ testImplicit(Period.ofYears(2))
+ testImplicit(Duration.ofMinutes(77))
+ testImplicit(SaveMode.Append)
+ testImplicit(Map(("key", "value"), ("foo", "baz")))
+ testImplicit(Set(1, 2, 4))
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org