[spark] branch master updated: [SPARK-42580][CONNECT] Scala client add client side typed APIs
This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5243d0be2c1 [SPARK-42580][CONNECT] Scala client add client side typed APIs
5243d0be2c1 is described below
commit 5243d0be2c15e3af36e981a9487ea600ab86a808
Author: Herman van Hovell <he...@databricks.com>
AuthorDate: Mon Feb 27 09:15:09 2023 -0400
[SPARK-42580][CONNECT] Scala client add client side typed APIs
### What changes were proposed in this pull request?
This PR adds the client-side typed APIs to the Spark Connect Scala Client.
### Why are the changes needed?
We want to reach API parity with the existing Spark SQL Dataset API.
### Does this PR introduce _any_ user-facing change?
Yes, it adds user-facing typed APIs.
### How was this patch tested?
Added tests to `ClientE2ETestSuite` and updated existing tests.
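
For illustration, a minimal sketch of the new user-facing surface, distilled from the tests below. It assumes a connected `spark` session and an implicit product `Encoder` in scope (`ClientE2ETestSuite` derives one via `ScalaReflection.encoderFor` until that lands in `SQLImplicits`); `Record` is a hypothetical case class, not part of this commit:

```scala
import org.apache.spark.sql.functions.{col, lit}

// Hypothetical case class, used only for this sketch.
case class Record(id: Long, a: Double, b: Double)

// range(...) now returns Dataset[java.lang.Long] instead of Dataset[Row].
val ids: Array[java.lang.Long] = spark.range(3).collect()

// as[U] changes the typed view of the data; collect() then yields U, not Row.
val records: Array[Record] = spark
  .range(3)
  .select(col("id"), (col("id") % 2).cast("double").as("a"), (col("id") / lit(10.0d)).as("b"))
  .as[Record]
  .collect()
```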
Closes #40175 from hvanhovell/SPARK-42580.
Authored-by: Herman van Hovell <he...@databricks.com>
Signed-off-by: Herman van Hovell <he...@databricks.com>
---
.../org/apache/spark/sql/DataFrameReader.scala | 4 +-
.../main/scala/org/apache/spark/sql/Dataset.scala | 122 +++++++++++++--------
.../spark/sql/RelationalGroupedDataset.scala | 2 +-
.../scala/org/apache/spark/sql/SparkSession.scala | 29 +++--
.../spark/sql/connect/client/SparkResult.scala | 37 ++++---
.../org/apache/spark/sql/ClientE2ETestSuite.scala | 80 ++++++++++++--
.../scala/org/apache/spark/sql/DatasetSuite.scala | 6 +-
.../apache/spark/sql/PlanGenerationTestSuite.scala | 2 +-
.../sql/catalyst/encoders/AgnosticEncoder.scala | 11 +-
9 files changed, 204 insertions(+), 89 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 5a486efee31..3e17b03173b 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -171,7 +171,7 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
*/
@scala.annotation.varargs
def load(paths: String*): DataFrame = {
- sparkSession.newDataset { builder =>
+ sparkSession.newDataFrame { builder =>
val dataSourceBuilder = builder.getReadBuilder.getDataSourceBuilder
assertSourceFormatSpecified()
dataSourceBuilder.setFormat(source)
@@ -308,7 +308,7 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
* @since 3.4.0
*/
def table(tableName: String): DataFrame = {
- sparkSession.newDataset { builder =>
+ sparkSession.newDataFrame { builder =>
builder.getReadBuilder.getNamedTableBuilder.setUnparsedIdentifier(tableName)
}
}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index dcc770dfe55..73de35456fc 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -23,6 +23,8 @@ import scala.collection.mutable
import scala.util.control.NonFatal
import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveLongEncoder, StringEncoder, UnboundRowEncoder}
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.connect.client.SparkResult
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
@@ -116,7 +118,10 @@ import org.apache.spark.util.Utils
*
* @since 3.4.0
*/
-class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val plan: proto.Plan)
+class Dataset[T] private[sql] (
+ val sparkSession: SparkSession,
+ private[sql] val plan: proto.Plan,
+ val encoder: AgnosticEncoder[T])
extends Serializable {
// Make sure we don't forget to set plan id.
assert(plan.getRoot.getCommon.hasPlanId)
@@ -151,9 +156,32 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @group basic
* @since 3.4.0
*/
- def toDF(): DataFrame = {
- // Note this will change as soon as we add the typed APIs.
- this.asInstanceOf[Dataset[Row]]
+ def toDF(): DataFrame = new Dataset(sparkSession, plan, UnboundRowEncoder)
+
+ /**
+ * Returns a new Dataset where each record has been mapped on to the specified type. The method
+ * used to map columns depend on the type of `U`: <ul> <li>When `U` is a class, fields for the
+ * class will be mapped to columns of the same name (case sensitivity is determined by
+ * `spark.sql.caseSensitive`).</li> <li>When `U` is a tuple, the columns will be mapped by
+ * ordinal (i.e. the first column will be assigned to `_1`).</li> <li>When `U` is a primitive
+ * type (i.e. String, Int, etc), then the first column of the `DataFrame` will be used.</li>
+ * </ul>
+ *
+ * If the schema of the Dataset does not match the desired `U` type, you can use `select` along
+ * with `alias` or `as` to rearrange or rename as required.
+ *
+ * Note that `as[]` only changes the view of the data that is passed into typed operations, such
+ * as `map()`, and does not eagerly project away any columns that are not present in the
+ * specified class.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def as[U: Encoder]: Dataset[U] = {
+ val encoder = implicitly[Encoder[U]].asInstanceOf[AgnosticEncoder[U]]
+ // We should add some validation/coercion here. We cannot use `to`
+ // because that does not work with positional arguments.
+ new Dataset[U](sparkSession, plan, encoder)
}
/**
@@ -170,7 +198,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @since 3.4.0
*/
@scala.annotation.varargs
- def toDF(colNames: String*): DataFrame = sparkSession.newDataset { builder =>
+ def toDF(colNames: String*): DataFrame = sparkSession.newDataFrame { builder =>
builder.getToDfBuilder
.setInput(plan.getRoot)
.addAllColumnNames(colNames.asJava)
@@ -192,7 +220,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @group basic
* @since 3.4.0
*/
- def to(schema: StructType): DataFrame = sparkSession.newDataset { builder =>
+ def to(schema: StructType): DataFrame = sparkSession.newDataFrame { builder =>
builder.getToSchemaBuilder
.setInput(plan.getRoot)
.setSchema(DataTypeProtoConverter.toConnectProtoType(schema))
@@ -205,7 +233,11 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @since 3.4.0
*/
def schema: StructType = {
- DataTypeProtoConverter.toCatalystType(analyze.getSchema).asInstanceOf[StructType]
+ if (encoder == UnboundRowEncoder) {
+ DataTypeProtoConverter.toCatalystType(analyze.getSchema).asInstanceOf[StructType]
+ } else {
+ encoder.schema
+ }
}
/**
@@ -469,7 +501,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @since 3.4.0
*/
def show(numRows: Int, truncate: Int, vertical: Boolean): Unit = {
- val df = sparkSession.newDataset { builder =>
+ val df = sparkSession.newDataset(StringEncoder) { builder =>
builder.getShowStringBuilder
.setInput(plan.getRoot)
.setNumRows(numRows)
@@ -480,13 +512,13 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
assert(result.length == 1)
assert(result.schema.size == 1)
// scalastyle:off println
- println(result.toArray.head.getString(0))
+ println(result.toArray.head)
// scalastyle:on println
}
}
private def buildJoin(right: Dataset[_])(f: proto.Join.Builder => Unit): DataFrame = {
- sparkSession.newDataset { builder =>
+ sparkSession.newDataFrame { builder =>
val joinBuilder = builder.getJoinBuilder
joinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot)
f(joinBuilder)
@@ -752,7 +784,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
}
private def buildSort(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
- sparkSession.newDataset { builder =>
+ sparkSession.newDataset(encoder) { builder =>
builder.getSortBuilder
.setInput(plan.getRoot)
.setIsGlobal(global)
@@ -860,11 +892,12 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @since 3.4.0
*/
@scala.annotation.varargs
- def hint(name: String, parameters: Any*): Dataset[T] = sparkSession.newDataset { builder =>
- builder.getHintBuilder
- .setInput(plan.getRoot)
- .setName(name)
- .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
+ def hint(name: String, parameters: Any*): Dataset[T] = sparkSession.newDataset(encoder) {
+ builder =>
+ builder.getHintBuilder
+ .setInput(plan.getRoot)
+ .setName(name)
+ .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
}
/**
@@ -900,7 +933,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @group typedrel
* @since 3.4.0
*/
- def as(alias: String): Dataset[T] = sparkSession.newDataset { builder =>
+ def as(alias: String): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
builder.getSubqueryAliasBuilder
.setInput(plan.getRoot)
.setAlias(alias)
@@ -940,7 +973,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @since 3.4.0
*/
@scala.annotation.varargs
- def select(cols: Column*): DataFrame = sparkSession.newDataset { builder =>
+ def select(cols: Column*): DataFrame = sparkSession.newDataFrame { builder =>
builder.getProjectBuilder
.setInput(plan.getRoot)
.addAllExpressions(cols.map(_.expr).asJava)
@@ -990,7 +1023,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @group typedrel
* @since 3.4.0
*/
- def filter(condition: Column): Dataset[T] = sparkSession.newDataset { builder =>
+ def filter(condition: Column): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
}
@@ -1033,7 +1066,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
ids: Array[Column],
valuesOption: Option[Array[Column]],
variableColumnName: String,
- valueColumnName: String): DataFrame = sparkSession.newDataset { builder =>
+ valueColumnName: String): DataFrame = sparkSession.newDataFrame { builder =>
val unpivot = builder.getUnpivotBuilder
.setInput(plan.getRoot)
.addAllIds(ids.toSeq.map(_.expr).asJava)
@@ -1423,7 +1456,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @group typedrel
* @since 3.4.0
*/
- def limit(n: Int): Dataset[T] = sparkSession.newDataset { builder =>
+ def limit(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
builder.getLimitBuilder
.setInput(plan.getRoot)
.setLimit(n)
@@ -1435,7 +1468,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @group typedrel
* @since 3.4.0
*/
- def offset(n: Int): Dataset[T] = sparkSession.newDataset { builder =>
+ def offset(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
builder.getOffsetBuilder
.setInput(plan.getRoot)
.setOffset(n)
@@ -1443,7 +1476,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
private def buildSetOp(right: Dataset[T], setOpType: proto.SetOperation.SetOpType)(
f: proto.SetOperation.Builder => Unit): Dataset[T] = {
- sparkSession.newDataset { builder =>
+ sparkSession.newDataset(encoder) { builder =>
f(
builder.getSetOpBuilder
.setSetOpType(setOpType)
@@ -1707,7 +1740,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @since 3.4.0
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
- sparkSession.newDataset { builder =>
+ sparkSession.newDataset(encoder) { builder =>
builder.getSampleBuilder
.setInput(plan.getRoot)
.setWithReplacement(withReplacement)
@@ -1775,7 +1808,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
normalizedCumWeights
.sliding(2)
.map { case Array(low, high) =>
- sparkSession.newDataset[T] { builder =>
+ sparkSession.newDataset(encoder) { builder =>
builder.getSampleBuilder
.setInput(sortedInput)
.setWithReplacement(false)
@@ -1819,7 +1852,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
val aliases = values.zip(names).map { case (value, name) =>
value.name(name).expr.getAlias
}
- sparkSession.newDataset { builder =>
+ sparkSession.newDataFrame { builder =>
builder.getWithColumnsBuilder
.setInput(plan.getRoot)
.addAllAliases(aliases.asJava)
@@ -1910,7 +1943,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @since 3.4.0
*/
def withColumnsRenamed(colsMap: java.util.Map[String, String]): DataFrame = {
- sparkSession.newDataset { builder =>
+ sparkSession.newDataFrame { builder =>
builder.getWithColumnsRenamedBuilder
.setInput(plan.getRoot)
.putAllRenameColumnsMap(colsMap)
@@ -1929,7 +1962,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
.setExpr(col(columnName).expr)
.addName(columnName)
.setMetadata(metadata.json)
- sparkSession.newDataset { builder =>
+ sparkSession.newDataFrame { builder =>
builder.getWithColumnsBuilder
.setInput(plan.getRoot)
.addAliases(newAlias)
@@ -2083,7 +2116,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
@scala.annotation.varargs
def drop(col: Column, cols: Column*): DataFrame = buildDrop(col +: cols)
- private def buildDrop(cols: Seq[Column]): DataFrame = sparkSession.newDataset { builder =>
+ private def buildDrop(cols: Seq[Column]): DataFrame = sparkSession.newDataFrame { builder =>
builder.getDropBuilder
.setInput(plan.getRoot)
.addAllCols(cols.map(_.expr).asJava)
@@ -2096,7 +2129,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @group typedrel
* @since 3.4.0
*/
- def dropDuplicates(): Dataset[T] = sparkSession.newDataset { builder =>
+ def dropDuplicates(): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
builder.getDeduplicateBuilder
.setInput(plan.getRoot)
.setAllColumnsAsKeys(true)
@@ -2109,10 +2142,11 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @group typedrel
* @since 3.4.0
*/
- def dropDuplicates(colNames: Seq[String]): Dataset[T] = sparkSession.newDataset { builder =>
- builder.getDeduplicateBuilder
- .setInput(plan.getRoot)
- .addAllColumnNames(colNames.asJava)
+ def dropDuplicates(colNames: Seq[String]): Dataset[T] = sparkSession.newDataset(encoder) {
+ builder =>
+ builder.getDeduplicateBuilder
+ .setInput(plan.getRoot)
+ .addAllColumnNames(colNames.asJava)
}
/**
@@ -2166,7 +2200,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @since 3.4.0
*/
@scala.annotation.varargs
- def describe(cols: String*): DataFrame = sparkSession.newDataset { builder =>
+ def describe(cols: String*): DataFrame = sparkSession.newDataFrame { builder =>
builder.getDescribeBuilder
.setInput(plan.getRoot)
.addAllCols(cols.asJava)
@@ -2241,7 +2275,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @since 3.4.0
*/
@scala.annotation.varargs
- def summary(statistics: String*): DataFrame = sparkSession.newDataset { builder =>
+ def summary(statistics: String*): DataFrame = sparkSession.newDataFrame { builder =>
builder.getSummaryBuilder
.setInput(plan.getRoot)
.addAllStatistics(statistics.asJava)
@@ -2309,7 +2343,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @since 3.4.0
*/
def tail(n: Int): Array[T] = {
- val lastN = sparkSession.newDataset[T] { builder =>
+ val lastN = sparkSession.newDataset(encoder) { builder =>
builder.getTailBuilder
.setInput(plan.getRoot)
.setLimit(n)
@@ -2340,7 +2374,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @since 3.4.0
*/
def collect(): Array[T] = withResult { result =>
- result.toArray.asInstanceOf[Array[T]]
+ result.toArray
}
/**
@@ -2368,7 +2402,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
*/
def toLocalIterator(): java.util.Iterator[T] = {
// TODO make this a destructive iterator.
- collectResult().iterator.asInstanceOf[java.util.Iterator[T]]
+ collectResult().iterator
}
/**
@@ -2377,11 +2411,11 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
* @since 3.4.0
*/
def count(): Long = {
- groupBy().count().collect().head.getLong(0)
+ groupBy().count().as(PrimitiveLongEncoder).collect().head
}
private def buildRepartition(numPartitions: Int, shuffle: Boolean): Dataset[T] = {
- sparkSession.newDataset { builder =>
+ sparkSession.newDataset(encoder) { builder =>
builder.getRepartitionBuilder
.setInput(plan.getRoot)
.setNumPartitions(numPartitions)
@@ -2391,7 +2425,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
private def buildRepartitionByExpression(
numPartitions: Option[Int],
- partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset { builder =>
+ partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
val repartitionBuilder = builder.getRepartitionByExpressionBuilder
.setInput(plan.getRoot)
.addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
@@ -2651,9 +2685,9 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
sparkSession.analyze(plan, proto.Explain.ExplainMode.SIMPLE)
}
- def collectResult(): SparkResult = sparkSession.execute(plan)
+ def collectResult(): SparkResult[T] = sparkSession.execute(plan, encoder)
- private[sql] def withResult[E](f: SparkResult => E): E = {
+ private[sql] def withResult[E](f: SparkResult[T] => E): E = {
val result = collectResult()
try f(result)
finally {
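
A note on the pattern above: operations that preserve the element type (`filter`, `limit`, `sort`, `hint`, ...) thread the existing `encoder` through `sparkSession.newDataset(encoder)`, while schema-changing operations (`select`, `toDF`, `drop`, ...) fall back to `sparkSession.newDataFrame`, i.e. `UnboundRowEncoder`. A minimal sketch of the resulting client-side behavior, assuming a connected `spark` session:

```scala
// Typed operations keep the element type of the Dataset.
val longs: Dataset[java.lang.Long] = spark.range(10).filter(col("id") > 5)

// Schema-changing operations return a DataFrame (Dataset[Row] with UnboundRowEncoder).
val df: DataFrame = longs.select(col("id").as("renamed"))

// count() now collects a Long through PrimitiveLongEncoder instead of Row.getLong(0).
val n: Long = longs.count()
```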
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 76d3ab5cf09..89bc5bfec57 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -42,7 +42,7 @@ class RelationalGroupedDataset protected[sql] (
pivot: Option[proto.Aggregate.Pivot] = None) {
private[this] def toDF(aggExprs: Seq[Column]): DataFrame = {
- df.sparkSession.newDataset { builder =>
+ df.sparkSession.newDataFrame { builder =>
builder.getAggregateBuilder
.setInput(df.plan.getRoot)
.addAllGroupingExpressions(groupingExprs.asJava)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index e85c7008ca9..3aed781855c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -28,6 +28,8 @@ import org.apache.spark.SPARK_VERSION
import org.apache.spark.annotation.Experimental
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder}
import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.util.Cleaner
@@ -118,7 +120,7 @@ class SparkSession(
* @since 3.4.0
*/
@Experimental
- def sql(sqlText: String, args: java.util.Map[String, String]): DataFrame = newDataset {
+ def sql(sqlText: String, args: java.util.Map[String, String]): DataFrame = newDataFrame {
builder =>
builder
.setSql(proto.SQL.newBuilder().setQuery(sqlText).putAllArgs(args))
@@ -169,7 +171,7 @@ class SparkSession(
*
* @since 3.4.0
*/
- def range(end: Long): Dataset[Row] = range(0, end)
+ def range(end: Long): Dataset[java.lang.Long] = range(0, end)
/**
* Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
@@ -177,7 +179,7 @@ class SparkSession(
*
* @since 3.4.0
*/
- def range(start: Long, end: Long): Dataset[Row] = {
+ def range(start: Long, end: Long): Dataset[java.lang.Long] = {
range(start, end, step = 1)
}
@@ -187,7 +189,7 @@ class SparkSession(
*
* @since 3.4.0
*/
- def range(start: Long, end: Long, step: Long): Dataset[Row] = {
+ def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
range(start, end, step, None)
}
@@ -197,7 +199,7 @@ class SparkSession(
*
* @since 3.4.0
*/
- def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[Row] = {
+ def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = {
range(start, end, step, Option(numPartitions))
}
@@ -221,8 +223,8 @@ class SparkSession(
start: Long,
end: Long,
step: Long,
- numPartitions: Option[Int]): Dataset[Row] = {
- newDataset { builder =>
+ numPartitions: Option[Int]): Dataset[java.lang.Long] = {
+ newDataset(BoxedLongEncoder) { builder =>
val rangeBuilder = builder.getRangeBuilder
.setStart(start)
.setEnd(end)
@@ -231,12 +233,17 @@ class SparkSession(
}
}
- private[sql] def newDataset[T](f: proto.Relation.Builder => Unit): Dataset[T] = {
+ private[sql] def newDataFrame(f: proto.Relation.Builder => Unit): DataFrame = {
+ newDataset(UnboundRowEncoder)(f)
+ }
+
+ private[sql] def newDataset[T](encoder: AgnosticEncoder[T])(
+ f: proto.Relation.Builder => Unit): Dataset[T] = {
val builder = proto.Relation.newBuilder()
f(builder)
builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
val plan = proto.Plan.newBuilder().setRoot(builder).build()
- new Dataset[T](this, plan)
+ new Dataset[T](this, plan, encoder)
}
private[sql] def newCommand[T](f: proto.Command.Builder => Unit): proto.Command = {
@@ -250,9 +257,9 @@ class SparkSession(
mode: proto.Explain.ExplainMode): proto.AnalyzePlanResponse =
client.analyze(plan, mode)
- private[sql] def execute(plan: proto.Plan): SparkResult = {
+ private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = {
val value = client.execute(plan)
- val result = new SparkResult(value, allocator)
+ val result = new SparkResult(value, allocator, encoder)
cleaner.register(result)
result
}
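
With this change `newDataFrame` is simply `newDataset(UnboundRowEncoder)`, and every typed constructor hands its encoder to `execute`, which passes it on to `SparkResult`. A condensed sketch of the two `private[sql]` construction paths (illustration only; step and partition handling omitted):

```scala
// Untyped path: DataFrame is Dataset[Row] built with the unbound row encoder.
val df: DataFrame = session.newDataFrame { builder =>
  builder.getRangeBuilder.setStart(0).setEnd(10)
}

// Typed path: the caller supplies the AgnosticEncoder up front.
val ds: Dataset[java.lang.Long] = session.newDataset(BoxedLongEncoder) { builder =>
  builder.getRangeBuilder.setStart(0).setEnd(10)
}
```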
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 317c20cad3e..80db558918b 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -26,26 +26,37 @@ import org.apache.arrow.vector.FieldVector
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.connect.proto
-import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer
import org.apache.spark.sql.connect.client.util.{AutoCloseables, Cleanable}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
-private[sql] class SparkResult(
+private[sql] class SparkResult[T](
responses: java.util.Iterator[proto.ExecutePlanResponse],
- allocator: BufferAllocator)
+ allocator: BufferAllocator,
+ encoder: AgnosticEncoder[T])
extends AutoCloseable
with Cleanable {
private[this] var numRecords: Int = 0
private[this] var structType: StructType = _
- private[this] var encoder: ExpressionEncoder[Row] = _
+ private[this] var boundEncoder: ExpressionEncoder[T] = _
private[this] val batches = mutable.Buffer.empty[ColumnarBatch]
+ private def createEncoder(schema: StructType): ExpressionEncoder[T] = {
+ val agnosticEncoder = if (encoder == UnboundRowEncoder) {
+ // Create a row encoder based on the schema.
+ RowEncoder.encoderFor(schema).asInstanceOf[AgnosticEncoder[T]]
+ } else {
+ encoder
+ }
+ ExpressionEncoder(agnosticEncoder)
+ }
+
private def processResponses(stopOnFirstNonEmptyResponse: Boolean): Boolean = {
while (responses.hasNext) {
val response = responses.next()
@@ -57,7 +68,7 @@ private[sql] class SparkResult(
if (batches.isEmpty) {
structType = ArrowUtils.fromArrowSchema(root.getSchema)
// TODO: create encoders that directly operate on arrow vectors.
- encoder = RowEncoder(structType).resolveAndBind(structType.toAttributes)
+ boundEncoder = createEncoder(structType).resolveAndBind(structType.toAttributes)
}
while (reader.loadNextBatch()) {
val rowCount = root.getRowCount
@@ -108,8 +119,8 @@ private[sql] class SparkResult(
/**
* Create an Array with the contents of the result.
*/
- def toArray: Array[Row] = {
- val result = new Array[Row](length)
+ def toArray: Array[T] = {
+ val result = encoder.clsTag.newArray(length)
val rows = iterator
var i = 0
while (rows.hasNext) {
@@ -123,11 +134,11 @@ private[sql] class SparkResult(
/**
* Returns an iterator over the contents of the result.
*/
- def iterator: java.util.Iterator[Row] with AutoCloseable = {
- new java.util.Iterator[Row] with AutoCloseable {
+ def iterator: java.util.Iterator[T] with AutoCloseable = {
+ new java.util.Iterator[T] with AutoCloseable {
private[this] var batchIndex: Int = -1
private[this] var iterator: java.util.Iterator[InternalRow] = Collections.emptyIterator()
- private[this] var deserializer: Deserializer[Row] = _
+ private[this] var deserializer: Deserializer[T] = _
override def hasNext: Boolean = {
if (iterator.hasNext) {
return true
@@ -142,13 +153,13 @@ private[sql] class SparkResult(
batchIndex = nextBatchIndex
iterator = batches(nextBatchIndex).rowIterator()
if (deserializer == null) {
- deserializer = encoder.createDeserializer()
+ deserializer = boundEncoder.createDeserializer()
}
}
hasNextBatch
}
- override def next(): Row = {
+ override def next(): T = {
if (!hasNext) {
throw new NoSuchElementException
}
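
The encoder is bound lazily: `SparkResult` creates the `ExpressionEncoder` only when the first Arrow batch (and hence the schema) arrives, and `UnboundRowEncoder` acts as a sentinel meaning "derive a row encoder from the result schema". A simplified restatement of `createEncoder` above, for clarity:

```scala
// Simplified sketch: resolve the agnostic encoder once the Arrow schema is known.
def resolve[T](encoder: AgnosticEncoder[T], schema: StructType): AgnosticEncoder[T] =
  if (encoder == UnboundRowEncoder) {
    // Untyped result: build a row encoder matching the schema sent by the server.
    RowEncoder.encoderFor(schema).asInstanceOf[AgnosticEncoder[T]]
  } else {
    // Typed result: use the encoder the Dataset already carries.
    encoder
  }
```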
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 122e7d5d271..debb314f8c3 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -17,16 +17,18 @@
package org.apache.spark.sql
import java.io.{ByteArrayOutputStream, PrintStream}
+import java.nio.file.Files
import scala.collection.JavaConverters._
+import scala.reflect.runtime.universe.TypeTag
import io.grpc.StatusRuntimeException
-import java.nio.file.Files
import org.apache.commons.io.FileUtils
import org.apache.commons.io.output.TeeOutputStream
import org.scalactic.TolerantNumerics
import org.apache.spark.SPARK_VERSION
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
import org.apache.spark.sql.functions.{aggregate, array, col, lit, rand, sequence, shuffle, transform, udf}
import org.apache.spark.sql.types._
@@ -54,9 +56,9 @@ class ClientE2ETestSuite extends RemoteSparkSession {
val df = spark.range(10).limit(3)
val result = df.collect()
assert(result.length == 3)
- assert(result(0).getLong(0) == 0)
- assert(result(1).getLong(0) == 1)
- assert(result(2).getLong(0) == 2)
+ assert(result(0) == 0)
+ assert(result(1) == 1)
+ assert(result(2) == 2)
}
test("simple udf") {
@@ -237,30 +239,40 @@ class ClientE2ETestSuite extends RemoteSparkSession {
checkFragments(result, fragmentsToCheck)
}
- private val simpleSchema = new StructType().add("id", "long", nullable = false)
+ private val simpleSchema = new StructType().add("value", "long", nullable = true)
// Dataset tests
test("Dataset inspection") {
val df = spark.range(10)
- val local = spark.newDataset { builder =>
+ val local = spark.newDataFrame { builder =>
builder.getLocalRelationBuilder.setSchema(simpleSchema.catalogString)
}
assert(!df.isLocal)
assert(local.isLocal)
assert(!df.isStreaming)
- assert(df.toString.contains("[id: bigint]"))
+ assert(df.toString.contains("[value: bigint]"))
assert(df.inputFiles.isEmpty)
}
test("Dataset schema") {
val df = spark.range(10)
assert(df.schema === simpleSchema)
- assert(df.dtypes === Array(("id", "LongType")))
- assert(df.columns === Array("id"))
+ assert(df.dtypes === Array(("value", "LongType")))
+ assert(df.columns === Array("value"))
testCapturedStdOut(df.printSchema(), simpleSchema.treeString)
testCapturedStdOut(df.printSchema(5), simpleSchema.treeString(5))
}
+ test("Dataframe schema") {
+ val df = spark.sql("select * from range(10)")
+ val expectedSchema = new StructType().add("id", "long", nullable = false)
+ assert(df.schema === expectedSchema)
+ assert(df.dtypes === Array(("id", "LongType")))
+ assert(df.columns === Array("id"))
+ testCapturedStdOut(df.printSchema(), expectedSchema.treeString)
+ testCapturedStdOut(df.printSchema(5), expectedSchema.treeString(5))
+ }
+
test("Dataset explain") {
val df = spark.range(10)
val simpleExplainFragments = Seq("== Physical Plan ==")
@@ -282,9 +294,9 @@ class ClientE2ETestSuite extends RemoteSparkSession {
}
test("Dataset result collection") {
- def checkResult(rows: TraversableOnce[Row], expectedValues: Long*): Unit = {
+ def checkResult(rows: TraversableOnce[java.lang.Long], expectedValues: Long*): Unit = {
rows.toIterator.zipAll(expectedValues.iterator, null, null).foreach {
- case (actual, expected) => assert(actual.getLong(0) === expected)
+ case (actual, expected) => assert(actual === expected)
}
}
val df = spark.range(10)
@@ -355,7 +367,11 @@ class ClientE2ETestSuite extends RemoteSparkSession {
implicit val tolerance = TolerantNumerics.tolerantDoubleEquality(0.01)
val df = spark.range(100)
- def checkSample(ds: DataFrame, lower: Double, upper: Double, seed: Long): Unit = {
+ def checkSample(
+ ds: Dataset[java.lang.Long],
+ lower: Double,
+ upper: Double,
+ seed: Long): Unit = {
assert(ds.plan.getRoot.hasSample)
val sample = ds.plan.getRoot.getSample
assert(sample.getSeed === seed)
@@ -375,6 +391,44 @@ class ClientE2ETestSuite extends RemoteSparkSession {
checkSample(datasets.get(3), 6.0 / 10.0, 1.0, 9L)
}
+ test("Dataset count") {
+ assert(spark.range(10).count() === 10)
+ }
+
+ // We can remove this as soon as this is added to SQLImplicits.
+ private implicit def newProductEncoder[T <: Product: TypeTag]: Encoder[T] =
+ ScalaReflection.encoderFor[T]
+
+ test("Dataset collect tuple") {
+ val result = spark
+ .range(3)
+ .select(col("id"), (col("id") % 2).cast("int").as("a"), (col("id") / lit(10.0d)).as("b"))
+ .as[(Long, Int, Double)]
+ .collect()
+ result.zipWithIndex.foreach { case ((id, a, b), i) =>
+ assert(id == i)
+ assert(a == id % 2)
+ assert(b == id / 10.0d)
+ }
+ }
+
+ test("Dataset collect complex type") {
+ val result = spark
+ .range(3)
+ .select(
+ (col("id") / lit(10.0d)).as("b"),
+ col("id"),
+ lit("world").as("d"),
+ (col("id") % 2).cast("int").as("a"))
+ .as[MyType]
+ .collect()
+ result.zipWithIndex.foreach { case (MyType(id, a, b), i) =>
+ assert(id == i)
+ assert(a == id % 2)
+ assert(b == id / 10.0d)
+ }
+ }
+
test("lambda functions") {
// This test is mostly to validate lambda variables are properly resolved.
val result = spark
@@ -447,3 +501,5 @@ class ClientE2ETestSuite extends RemoteSparkSession {
intercept[Exception](spark.conf.set("spark.sql.globalTempDatabase", "/dev/null"))
}
}
+
+private[sql] case class MyType(id: Long, a: Double, b: Double)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 9c07c5abe3c..4a26a32353a 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -69,7 +69,7 @@ class DatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
}
test("write") {
- val df = ss.newDataset(_ => ()).limit(10)
+ val df = ss.newDataFrame(_ => ()).limit(10)
val builder = proto.WriteOperation.newBuilder()
builder
@@ -101,7 +101,7 @@ class DatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
}
test("write V2") {
- val df = ss.newDataset(_ => ()).limit(10)
+ val df = ss.newDataFrame(_ => ()).limit(10)
val builder = proto.WriteOperationV2.newBuilder()
builder
@@ -129,7 +129,7 @@ class DatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
}
test("Pivot") {
- val df = ss.newDataset(_ => ())
+ val df = ss.newDataFrame(_ => ())
intercept[IllegalArgumentException] {
df.groupBy().pivot(Column("c"), Seq(Column("col")))
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 6a789b1494f..67ea148cb87 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -215,7 +215,7 @@ class PlanGenerationTestSuite
private val temporalsSchemaString = temporalsSchema.catalogString
- private def createLocalRelation(schema: String): DataFrame = session.newDataset { builder =>
+ private def createLocalRelation(schema: String): DataFrame = session.newDataFrame { builder =>
// TODO API is not consistent. Now we have two different ways of working with schemas!
builder.getLocalRelationBuilder.setSchema(schema)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 1a3c1089649..24c8bad5c2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -107,13 +107,20 @@ object AgnosticEncoders {
override def dataType: DataType = schema
}
- case class RowEncoder(fields: Seq[EncoderField]) extends AgnosticEncoder[Row] {
+ abstract class BaseRowEncoder extends AgnosticEncoder[Row] {
override def isPrimitive: Boolean = false
- override val schema: StructType = StructType(fields.map(_.structField))
override def dataType: DataType = schema
override def clsTag: ClassTag[Row] = classTag[Row]
}
+ case class RowEncoder(fields: Seq[EncoderField]) extends BaseRowEncoder {
+ override val schema: StructType = StructType(fields.map(_.structField))
+ }
+
+ object UnboundRowEncoder extends BaseRowEncoder {
+ override val schema: StructType = new StructType()
+ }
+
case class JavaBeanEncoder[K](
override val clsTag: ClassTag[K],
fields: Seq[EncoderField])