You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/03/13 06:29:56 UTC
[spark] branch branch-3.4 updated: [SPARK-42755][CONNECT] Factor literal value conversion out to `connect-common`
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new def02cb9da3 [SPARK-42755][CONNECT] Factor literal value conversion out to `connect-common`
def02cb9da3 is described below
commit def02cb9da38bb2029e2859cd06f517d547f482f
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Mon Mar 13 14:29:15 2023 +0800
[SPARK-42755][CONNECT] Factor literal value conversion out to `connect-common`
### What changes were proposed in this pull request?
Factor literal value conversion out to `connect-common`.
### Why are the changes needed?
When trying to build protos of literal arrays on the server side for ML, I found we have two implementations:
`LiteralExpressionProtoConverter.toConnectProtoValue` in the server module, but it doesn't support arrays;
`LiteralProtoConverter.toLiteralProto` in the client module, which supports more types.
We'd better factor it out into the common module, so that both the client and the server can leverage it.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
existing UT
Closes #40375 from zhengruifeng/connect_mv_literal_common.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
(cherry picked from commit 43caae31dfa05b3d237acfa3115bd0e7b4e540ed)
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../scala/org/apache/spark/sql/functions.scala | 2 +-
.../common/LiteralValueProtoConverter.scala} | 10 +++---
.../org/apache/spark/sql/connect/dsl/package.scala | 14 ++++----
...scala => LiteralExpressionProtoConverter.scala} | 22 +------------
.../sql/connect/planner/SparkConnectPlanner.scala | 2 +-
.../service/SparkConnectStreamHandler.scala | 4 +--
... => LiteralExpressionProtoConverterSuite.scala} | 7 ++--
.../connect/planner/SparkConnectPlannerSuite.scala | 38 +++++++++-------------
.../connect/planner/SparkConnectProtoSuite.scala | 14 ++++----
9 files changed, 44 insertions(+), 69 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 8ce90886e0f..29c2e89c537 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -23,8 +23,8 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveLongEncoder
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter._
import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, UserDefinedFunction}
-import org.apache.spark.sql.expressions.LiteralProtoConverter._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.types.DataType.parseTypeWithFallback
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
similarity index 95%
rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala
rename to connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
index daddfa9b5af..ceef9b21244 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
@@ -14,7 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.expressions
+
+package org.apache.spark.sql.connect.common
import java.lang.{Boolean => JBoolean, Byte => JByte, Character => JChar, Double => JDouble, Float => JFloat, Integer => JInteger, Long => JLong, Short => JShort}
import java.math.{BigDecimal => JBigDecimal}
@@ -25,12 +26,11 @@ import com.google.protobuf.ByteString
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
-import org.apache.spark.sql.connect.client.unsupported
import org.apache.spark.sql.connect.common.DataTypeProtoConverter._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
-object LiteralProtoConverter {
+object LiteralValueProtoConverter {
private lazy val nullType =
proto.DataType.newBuilder().setNull(proto.DataType.NULL.getDefaultInstance).build()
@@ -93,7 +93,7 @@ object LiteralProtoConverter {
case v: CalendarInterval =>
builder.setCalendarInterval(calendarIntervalBuilder(v.months, v.days, v.microseconds))
case null => builder.setNull(nullType)
- case _ => unsupported(s"literal $literal not supported (yet).")
+ case _ => throw new UnsupportedOperationException(s"literal $literal not supported (yet).")
}
}
@@ -103,7 +103,7 @@ object LiteralProtoConverter {
* @return
* proto.Expression.Literal
*/
- private def toLiteralProto(literal: Any): proto.Expression.Literal =
+ def toLiteralProto(literal: Any): proto.Expression.Literal =
toLiteralProtoBuilder(literal).build()
private def toDataType(clz: Class[_]): DataType = clz match {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 7e60c5f9a28..21b9180ccfb 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -26,8 +26,8 @@ import org.apache.spark.connect.proto.Join.JoinType
import org.apache.spark.connect.proto.SetOperation.SetOpType
import org.apache.spark.sql.{Observation, SaveMode}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
import org.apache.spark.sql.connect.planner.{SaveModeConverter, TableSaveMethodConverter}
-import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@@ -342,7 +342,7 @@ package object dsl {
proto.NAFill
.newBuilder()
.setInput(logicalPlan)
- .addAllValues(Seq(toConnectProtoValue(value)).asJava)
+ .addAllValues(Seq(toLiteralProto(value)).asJava)
.build())
.build()
}
@@ -355,13 +355,13 @@ package object dsl {
.newBuilder()
.setInput(logicalPlan)
.addAllCols(cols.asJava)
- .addAllValues(Seq(toConnectProtoValue(value)).asJava)
+ .addAllValues(Seq(toLiteralProto(value)).asJava)
.build())
.build()
}
def fillValueMap(valueMap: Map[String, Any]): Relation = {
- val (cols, values) = valueMap.mapValues(toConnectProtoValue).toSeq.unzip
+ val (cols, values) = valueMap.mapValues(toLiteralProto).toSeq.unzip
Relation
.newBuilder()
.setFillNa(
@@ -422,8 +422,8 @@ package object dsl {
replace.addReplacements(
proto.NAReplace.Replacement
.newBuilder()
- .setOldValue(toConnectProtoValue(oldValue))
- .setNewValue(toConnectProtoValue(newValue)))
+ .setOldValue(toLiteralProto(oldValue))
+ .setNewValue(toLiteralProto(newValue)))
}
Relation
@@ -978,7 +978,7 @@ package object dsl {
def hint(name: String, parameters: Any*): Relation = {
val expressions = parameters.map { parameter =>
- proto.Expression.newBuilder().setLiteral(toConnectProtoValue(parameter)).build()
+ proto.Expression.newBuilder().setLiteral(toLiteralProto(parameter)).build()
}
Relation
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
similarity index 87%
rename from connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
rename to connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
index 7a580913867..9f2baea5737 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanI
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-object LiteralValueProtoConverter {
+object LiteralExpressionProtoConverter {
/**
* Transforms the protocol buffers literals into the appropriate Catalyst literal expression.
@@ -121,25 +121,6 @@ object LiteralValueProtoConverter {
}
}
- def toConnectProtoValue(value: Any): proto.Expression.Literal = {
- value match {
- case null =>
- proto.Expression.Literal
- .newBuilder()
- .setNull(DataTypeProtoConverter.toConnectProtoType(NullType))
- .build()
- case b: Boolean => proto.Expression.Literal.newBuilder().setBoolean(b).build()
- case b: Byte => proto.Expression.Literal.newBuilder().setByte(b).build()
- case s: Short => proto.Expression.Literal.newBuilder().setShort(s).build()
- case i: Int => proto.Expression.Literal.newBuilder().setInteger(i).build()
- case l: Long => proto.Expression.Literal.newBuilder().setLong(l).build()
- case f: Float => proto.Expression.Literal.newBuilder().setFloat(f).build()
- case d: Double => proto.Expression.Literal.newBuilder().setDouble(d).build()
- case s: String => proto.Expression.Literal.newBuilder().setString(s).build()
- case o => throw new Exception(s"Unsupported value type: $o")
- }
- }
-
private def toArrayData(array: proto.Expression.Literal.Array): Any = {
def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit
tag: ClassTag[T]): Array[T] = {
@@ -195,5 +176,4 @@ object LiteralValueProtoConverter {
throw InvalidPlanInput(s"Unsupported Literal Type: $elementType)")
}
}
-
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 24717e07b00..a057bd8d6c1 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResul
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, UdfPacket}
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
-import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystExpression, toCatalystValue}
+import org.apache.spark.sql.connect.planner.LiteralExpressionProtoConverter.{toCatalystExpression, toCatalystValue}
import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
import org.apache.spark.sql.connect.service.SparkConnectStreamHandler
import org.apache.spark.sql.errors.QueryCompilationErrors
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 0dd1741f099..104d840ed52 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -28,8 +28,8 @@ import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
-import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.service.SparkConnectStreamHandler.processAsArrowBatches
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
@@ -216,7 +216,7 @@ object SparkConnectStreamHandler {
sessionId: String,
dataframe: DataFrame): ExecutePlanResponse = {
val observedMetrics = dataframe.queryExecution.observedMetrics.map { case (name, row) =>
- val cols = (0 until row.length).map(i => toConnectProtoValue(row(i)))
+ val cols = (0 until row.length).map(i => toLiteralProto(row(i)))
ExecutePlanResponse.ObservedMetrics
.newBuilder()
.setName(name)
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
similarity index 76%
rename from connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala
rename to connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
index 7c8ee6209ac..c3479456617 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
@@ -19,14 +19,15 @@ package org.apache.spark.sql.connect.planner
import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
-import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystValue, toConnectProtoValue}
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
+import org.apache.spark.sql.connect.planner.LiteralExpressionProtoConverter.toCatalystValue
-class LiteralValueProtoConverterSuite extends AnyFunSuite { // scalastyle:ignore funsuite
+class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:ignore funsuite
test("basic proto value and catalyst value conversion") {
val values = Array(null, true, 1.toByte, 1.toShort, 1, 1L, 1.1d, 1.1f, "spark")
for (v <- values) {
- assertResult(v)(toCatalystValue(toConnectProtoValue(v)))
+ assertResult(v)(toCatalystValue(toLiteralProto(v)))
}
}
}
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index b79d91d2d10..b6b214c839d 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.connect.common.InvalidPlanInput
-import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
@@ -602,13 +602,11 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
val logical = transform(
proto.Relation
.newBuilder()
- .setHint(
- proto.Hint
- .newBuilder()
- .setInput(input)
- .setName("REPARTITION")
- .addParameters(
- proto.Expression.newBuilder().setLiteral(toConnectProtoValue(10000)).build()))
+ .setHint(proto.Hint
+ .newBuilder()
+ .setInput(input)
+ .setName("REPARTITION")
+ .addParameters(proto.Expression.newBuilder().setLiteral(toLiteralProto(10000)).build()))
.build())
val df = Dataset.ofRows(spark, logical)
@@ -648,13 +646,11 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
val logical = transform(
proto.Relation
.newBuilder()
- .setHint(
- proto.Hint
- .newBuilder()
- .setInput(input)
- .setName("REPARTITION")
- .addParameters(
- proto.Expression.newBuilder().setLiteral(toConnectProtoValue("id")).build()))
+ .setHint(proto.Hint
+ .newBuilder()
+ .setInput(input)
+ .setName("REPARTITION")
+ .addParameters(proto.Expression.newBuilder().setLiteral(toLiteralProto("id")).build()))
.build())
assert(10 === Dataset.ofRows(spark, logical).count())
}
@@ -671,13 +667,11 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
val logical = transform(
proto.Relation
.newBuilder()
- .setHint(
- proto.Hint
- .newBuilder()
- .setInput(input)
- .setName("REPARTITION")
- .addParameters(
- proto.Expression.newBuilder().setLiteral(toConnectProtoValue(true)).build()))
+ .setHint(proto.Hint
+ .newBuilder()
+ .setInput(input)
+ .setName("REPARTITION")
+ .addParameters(proto.Expression.newBuilder().setLiteral(toLiteralProto(true)).build()))
.build())
intercept[AnalysisException](Dataset.ofRows(spark, logical))
}
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 00ff6ac2fb6..9cc714d630b 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -32,11 +32,11 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInt
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, LocalRelation, LogicalPlan}
import org.apache.spark.sql.connect.common.InvalidPlanInput
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
import org.apache.spark.sql.connect.dsl.MockRemoteSession
import org.apache.spark.sql.connect.dsl.commands._
import org.apache.spark.sql.connect.dsl.expressions._
import org.apache.spark.sql.connect.dsl.plans._
-import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, TableCatalog}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
import org.apache.spark.sql.execution.arrow.ArrowConverters
@@ -245,7 +245,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
val connectPlan3 =
connectTestRelation.rollup("id".protoAttr, "name".protoAttr)(
- proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build())
+ proto_min(proto.Expression.newBuilder().setLiteral(toLiteralProto(1)).build())
.as("agg1"))
val sparkPlan3 =
sparkTestRelation
@@ -269,7 +269,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
val connectPlan3 =
connectTestRelation.cube("id".protoAttr, "name".protoAttr)(
- proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build())
+ proto_min(proto.Expression.newBuilder().setLiteral(toLiteralProto(1)).build())
.as("agg1"))
val sparkPlan3 =
sparkTestRelation
@@ -282,8 +282,8 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
val connectPlan1 =
connectTestRelation.pivot("id".protoAttr)(
"name".protoAttr,
- Seq("a", "b", "c").map(toConnectProtoValue))(
- proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build())
+ Seq("a", "b", "c").map(toLiteralProto))(
+ proto_min(proto.Expression.newBuilder().setLiteral(toLiteralProto(1)).build())
.as("agg1"))
val sparkPlan1 =
sparkTestRelation
@@ -295,8 +295,8 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
val connectPlan2 =
connectTestRelation.pivot("name".protoAttr)(
"id".protoAttr,
- Seq(1, 2, 3).map(toConnectProtoValue))(
- proto_min(proto.Expression.newBuilder().setLiteral(toConnectProtoValue(1)).build())
+ Seq(1, 2, 3).map(toLiteralProto))(
+ proto_min(proto.Expression.newBuilder().setLiteral(toLiteralProto(1)).build())
.as("agg1"))
val sparkPlan2 =
sparkTestRelation
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org