You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2023/02/27 13:15:23 UTC

[spark] branch master updated: [SPARK-42580][CONNECT] Scala client add client side typed APIs

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5243d0be2c1 [SPARK-42580][CONNECT] Scala client add client side typed APIs
5243d0be2c1 is described below

commit 5243d0be2c15e3af36e981a9487ea600ab86a808
Author: Herman van Hovell <he...@databricks.com>
AuthorDate: Mon Feb 27 09:15:09 2023 -0400

    [SPARK-42580][CONNECT] Scala client add client side typed APIs
    
    ### What changes were proposed in this pull request?
    This PR adds the client-side typed API to the Spark Connect Scala client.
    
    ### Why are the changes needed?
    We want the Spark Connect Scala client to reach API parity with the existing Scala `Dataset` API.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, it adds user API.
    
    ### How was this patch tested?
    Added tests to `ClientE2ETestSuite` and updated existing tests.
    
    Closes #40175 from hvanhovell/SPARK-42580.
    
    Authored-by: Herman van Hovell <he...@databricks.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../org/apache/spark/sql/DataFrameReader.scala     |   4 +-
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 122 +++++++++++++--------
 .../spark/sql/RelationalGroupedDataset.scala       |   2 +-
 .../scala/org/apache/spark/sql/SparkSession.scala  |  29 +++--
 .../spark/sql/connect/client/SparkResult.scala     |  37 ++++---
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  80 ++++++++++++--
 .../scala/org/apache/spark/sql/DatasetSuite.scala  |   6 +-
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   2 +-
 .../sql/catalyst/encoders/AgnosticEncoder.scala    |  11 +-
 9 files changed, 204 insertions(+), 89 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 5a486efee31..3e17b03173b 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -171,7 +171,7 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
    */
   @scala.annotation.varargs
   def load(paths: String*): DataFrame = {
-    sparkSession.newDataset { builder =>
+    sparkSession.newDataFrame { builder =>
       val dataSourceBuilder = builder.getReadBuilder.getDataSourceBuilder
       assertSourceFormatSpecified()
       dataSourceBuilder.setFormat(source)
@@ -308,7 +308,7 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
    * @since 3.4.0
    */
   def table(tableName: String): DataFrame = {
-    sparkSession.newDataset { builder =>
+    sparkSession.newDataFrame { builder =>
       builder.getReadBuilder.getNamedTableBuilder.setUnparsedIdentifier(tableName)
     }
   }
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index dcc770dfe55..73de35456fc 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -23,6 +23,8 @@ import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveLongEncoder, StringEncoder, UnboundRowEncoder}
 import org.apache.spark.sql.catalyst.expressions.RowOrdering
 import org.apache.spark.sql.connect.client.SparkResult
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
@@ -116,7 +118,10 @@ import org.apache.spark.util.Utils
  *
  * @since 3.4.0
  */
-class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val plan: proto.Plan)
+class Dataset[T] private[sql] (
+    val sparkSession: SparkSession,
+    private[sql] val plan: proto.Plan,
+    val encoder: AgnosticEncoder[T])
     extends Serializable {
   // Make sure we don't forget to set plan id.
   assert(plan.getRoot.getCommon.hasPlanId)
@@ -151,9 +156,32 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @group basic
    * @since 3.4.0
    */
-  def toDF(): DataFrame = {
-    // Note this will change as soon as we add the typed APIs.
-    this.asInstanceOf[Dataset[Row]]
+  def toDF(): DataFrame = new Dataset(sparkSession, plan, UnboundRowEncoder)
+
+  /**
+   * Returns a new Dataset where each record has been mapped on to the specified type. The method
+   * used to map columns depend on the type of `U`: <ul> <li>When `U` is a class, fields for the
+   * class will be mapped to columns of the same name (case sensitivity is determined by
+   * `spark.sql.caseSensitive`).</li> <li>When `U` is a tuple, the columns will be mapped by
+   * ordinal (i.e. the first column will be assigned to `_1`).</li> <li>When `U` is a primitive
+   * type (i.e. String, Int, etc), then the first column of the `DataFrame` will be used.</li>
+   * </ul>
+   *
+   * If the schema of the Dataset does not match the desired `U` type, you can use `select` along
+   * with `alias` or `as` to rearrange or rename as required.
+   *
+   * Note that `as[]` only changes the view of the data that is passed into typed operations, such
+   * as `map()`, and does not eagerly project away any columns that are not present in the
+   * specified class.
+   *
+   * @group basic
+   * @since 3.4.0
+   */
+  def as[U: Encoder]: Dataset[U] = {
+    val encoder = implicitly[Encoder[U]].asInstanceOf[AgnosticEncoder[U]]
+    // We should add some validation/coercion here. We cannot use `to`
+    // because that does not work with positional arguments.
+    new Dataset[U](sparkSession, plan, encoder)
   }
 
   /**
@@ -170,7 +198,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @since 3.4.0
    */
   @scala.annotation.varargs
-  def toDF(colNames: String*): DataFrame = sparkSession.newDataset { builder =>
+  def toDF(colNames: String*): DataFrame = sparkSession.newDataFrame { builder =>
     builder.getToDfBuilder
       .setInput(plan.getRoot)
       .addAllColumnNames(colNames.asJava)
@@ -192,7 +220,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @group basic
    * @since 3.4.0
    */
-  def to(schema: StructType): DataFrame = sparkSession.newDataset { builder =>
+  def to(schema: StructType): DataFrame = sparkSession.newDataFrame { builder =>
     builder.getToSchemaBuilder
       .setInput(plan.getRoot)
       .setSchema(DataTypeProtoConverter.toConnectProtoType(schema))
@@ -205,7 +233,11 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @since 3.4.0
    */
   def schema: StructType = {
-    DataTypeProtoConverter.toCatalystType(analyze.getSchema).asInstanceOf[StructType]
+    if (encoder == UnboundRowEncoder) {
+      DataTypeProtoConverter.toCatalystType(analyze.getSchema).asInstanceOf[StructType]
+    } else {
+      encoder.schema
+    }
   }
 
   /**
@@ -469,7 +501,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @since 3.4.0
    */
   def show(numRows: Int, truncate: Int, vertical: Boolean): Unit = {
-    val df = sparkSession.newDataset { builder =>
+    val df = sparkSession.newDataset(StringEncoder) { builder =>
       builder.getShowStringBuilder
         .setInput(plan.getRoot)
         .setNumRows(numRows)
@@ -480,13 +512,13 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
       assert(result.length == 1)
       assert(result.schema.size == 1)
       // scalastyle:off println
-      println(result.toArray.head.getString(0))
+      println(result.toArray.head)
     // scalastyle:on println
     }
   }
 
   private def buildJoin(right: Dataset[_])(f: proto.Join.Builder => Unit): DataFrame = {
-    sparkSession.newDataset { builder =>
+    sparkSession.newDataFrame { builder =>
       val joinBuilder = builder.getJoinBuilder
       joinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot)
       f(joinBuilder)
@@ -752,7 +784,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
   }
 
   private def buildSort(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
-    sparkSession.newDataset { builder =>
+    sparkSession.newDataset(encoder) { builder =>
       builder.getSortBuilder
         .setInput(plan.getRoot)
         .setIsGlobal(global)
@@ -860,11 +892,12 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @since 3.4.0
    */
   @scala.annotation.varargs
-  def hint(name: String, parameters: Any*): Dataset[T] = sparkSession.newDataset { builder =>
-    builder.getHintBuilder
-      .setInput(plan.getRoot)
-      .setName(name)
-      .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
+  def hint(name: String, parameters: Any*): Dataset[T] = sparkSession.newDataset(encoder) {
+    builder =>
+      builder.getHintBuilder
+        .setInput(plan.getRoot)
+        .setName(name)
+        .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
   }
 
   /**
@@ -900,7 +933,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @group typedrel
    * @since 3.4.0
    */
-  def as(alias: String): Dataset[T] = sparkSession.newDataset { builder =>
+  def as(alias: String): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
     builder.getSubqueryAliasBuilder
       .setInput(plan.getRoot)
       .setAlias(alias)
@@ -940,7 +973,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @since 3.4.0
    */
   @scala.annotation.varargs
-  def select(cols: Column*): DataFrame = sparkSession.newDataset { builder =>
+  def select(cols: Column*): DataFrame = sparkSession.newDataFrame { builder =>
     builder.getProjectBuilder
       .setInput(plan.getRoot)
       .addAllExpressions(cols.map(_.expr).asJava)
@@ -990,7 +1023,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @group typedrel
    * @since 3.4.0
    */
-  def filter(condition: Column): Dataset[T] = sparkSession.newDataset { builder =>
+  def filter(condition: Column): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
     builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
   }
 
@@ -1033,7 +1066,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
       ids: Array[Column],
       valuesOption: Option[Array[Column]],
       variableColumnName: String,
-      valueColumnName: String): DataFrame = sparkSession.newDataset { builder =>
+      valueColumnName: String): DataFrame = sparkSession.newDataFrame { builder =>
     val unpivot = builder.getUnpivotBuilder
       .setInput(plan.getRoot)
       .addAllIds(ids.toSeq.map(_.expr).asJava)
@@ -1423,7 +1456,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @group typedrel
    * @since 3.4.0
    */
-  def limit(n: Int): Dataset[T] = sparkSession.newDataset { builder =>
+  def limit(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
     builder.getLimitBuilder
       .setInput(plan.getRoot)
       .setLimit(n)
@@ -1435,7 +1468,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @group typedrel
    * @since 3.4.0
    */
-  def offset(n: Int): Dataset[T] = sparkSession.newDataset { builder =>
+  def offset(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
     builder.getOffsetBuilder
       .setInput(plan.getRoot)
       .setOffset(n)
@@ -1443,7 +1476,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
 
   private def buildSetOp(right: Dataset[T], setOpType: proto.SetOperation.SetOpType)(
       f: proto.SetOperation.Builder => Unit): Dataset[T] = {
-    sparkSession.newDataset { builder =>
+    sparkSession.newDataset(encoder) { builder =>
       f(
         builder.getSetOpBuilder
           .setSetOpType(setOpType)
@@ -1707,7 +1740,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @since 3.4.0
    */
   def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
-    sparkSession.newDataset { builder =>
+    sparkSession.newDataset(encoder) { builder =>
       builder.getSampleBuilder
         .setInput(plan.getRoot)
         .setWithReplacement(withReplacement)
@@ -1775,7 +1808,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
     normalizedCumWeights
       .sliding(2)
       .map { case Array(low, high) =>
-        sparkSession.newDataset[T] { builder =>
+        sparkSession.newDataset(encoder) { builder =>
           builder.getSampleBuilder
             .setInput(sortedInput)
             .setWithReplacement(false)
@@ -1819,7 +1852,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
     val aliases = values.zip(names).map { case (value, name) =>
       value.name(name).expr.getAlias
     }
-    sparkSession.newDataset { builder =>
+    sparkSession.newDataFrame { builder =>
       builder.getWithColumnsBuilder
         .setInput(plan.getRoot)
         .addAllAliases(aliases.asJava)
@@ -1910,7 +1943,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @since 3.4.0
    */
   def withColumnsRenamed(colsMap: java.util.Map[String, String]): DataFrame = {
-    sparkSession.newDataset { builder =>
+    sparkSession.newDataFrame { builder =>
       builder.getWithColumnsRenamedBuilder
         .setInput(plan.getRoot)
         .putAllRenameColumnsMap(colsMap)
@@ -1929,7 +1962,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
       .setExpr(col(columnName).expr)
       .addName(columnName)
       .setMetadata(metadata.json)
-    sparkSession.newDataset { builder =>
+    sparkSession.newDataFrame { builder =>
       builder.getWithColumnsBuilder
         .setInput(plan.getRoot)
         .addAliases(newAlias)
@@ -2083,7 +2116,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
   @scala.annotation.varargs
   def drop(col: Column, cols: Column*): DataFrame = buildDrop(col +: cols)
 
-  private def buildDrop(cols: Seq[Column]): DataFrame = sparkSession.newDataset { builder =>
+  private def buildDrop(cols: Seq[Column]): DataFrame = sparkSession.newDataFrame { builder =>
     builder.getDropBuilder
       .setInput(plan.getRoot)
       .addAllCols(cols.map(_.expr).asJava)
@@ -2096,7 +2129,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @group typedrel
    * @since 3.4.0
    */
-  def dropDuplicates(): Dataset[T] = sparkSession.newDataset { builder =>
+  def dropDuplicates(): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
     builder.getDeduplicateBuilder
       .setInput(plan.getRoot)
       .setAllColumnsAsKeys(true)
@@ -2109,10 +2142,11 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @group typedrel
    * @since 3.4.0
    */
-  def dropDuplicates(colNames: Seq[String]): Dataset[T] = sparkSession.newDataset { builder =>
-    builder.getDeduplicateBuilder
-      .setInput(plan.getRoot)
-      .addAllColumnNames(colNames.asJava)
+  def dropDuplicates(colNames: Seq[String]): Dataset[T] = sparkSession.newDataset(encoder) {
+    builder =>
+      builder.getDeduplicateBuilder
+        .setInput(plan.getRoot)
+        .addAllColumnNames(colNames.asJava)
   }
 
   /**
@@ -2166,7 +2200,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @since 3.4.0
    */
   @scala.annotation.varargs
-  def describe(cols: String*): DataFrame = sparkSession.newDataset { builder =>
+  def describe(cols: String*): DataFrame = sparkSession.newDataFrame { builder =>
     builder.getDescribeBuilder
       .setInput(plan.getRoot)
       .addAllCols(cols.asJava)
@@ -2241,7 +2275,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @since 3.4.0
    */
   @scala.annotation.varargs
-  def summary(statistics: String*): DataFrame = sparkSession.newDataset { builder =>
+  def summary(statistics: String*): DataFrame = sparkSession.newDataFrame { builder =>
     builder.getSummaryBuilder
       .setInput(plan.getRoot)
       .addAllStatistics(statistics.asJava)
@@ -2309,7 +2343,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @since 3.4.0
    */
   def tail(n: Int): Array[T] = {
-    val lastN = sparkSession.newDataset[T] { builder =>
+    val lastN = sparkSession.newDataset(encoder) { builder =>
       builder.getTailBuilder
         .setInput(plan.getRoot)
         .setLimit(n)
@@ -2340,7 +2374,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @since 3.4.0
    */
   def collect(): Array[T] = withResult { result =>
-    result.toArray.asInstanceOf[Array[T]]
+    result.toArray
   }
 
   /**
@@ -2368,7 +2402,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    */
   def toLocalIterator(): java.util.Iterator[T] = {
     // TODO make this a destructive iterator.
-    collectResult().iterator.asInstanceOf[java.util.Iterator[T]]
+    collectResult().iterator
   }
 
   /**
@@ -2377,11 +2411,11 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
    * @since 3.4.0
    */
   def count(): Long = {
-    groupBy().count().collect().head.getLong(0)
+    groupBy().count().as(PrimitiveLongEncoder).collect().head
   }
 
   private def buildRepartition(numPartitions: Int, shuffle: Boolean): Dataset[T] = {
-    sparkSession.newDataset { builder =>
+    sparkSession.newDataset(encoder) { builder =>
       builder.getRepartitionBuilder
         .setInput(plan.getRoot)
         .setNumPartitions(numPartitions)
@@ -2391,7 +2425,7 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
 
   private def buildRepartitionByExpression(
       numPartitions: Option[Int],
-      partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset { builder =>
+      partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
     val repartitionBuilder = builder.getRepartitionByExpressionBuilder
       .setInput(plan.getRoot)
       .addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
@@ -2651,9 +2685,9 @@ class Dataset[T] private[sql] (val sparkSession: SparkSession, private[sql] val
     sparkSession.analyze(plan, proto.Explain.ExplainMode.SIMPLE)
   }
 
-  def collectResult(): SparkResult = sparkSession.execute(plan)
+  def collectResult(): SparkResult[T] = sparkSession.execute(plan, encoder)
 
-  private[sql] def withResult[E](f: SparkResult => E): E = {
+  private[sql] def withResult[E](f: SparkResult[T] => E): E = {
     val result = collectResult()
     try f(result)
     finally {
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 76d3ab5cf09..89bc5bfec57 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -42,7 +42,7 @@ class RelationalGroupedDataset protected[sql] (
     pivot: Option[proto.Aggregate.Pivot] = None) {
 
   private[this] def toDF(aggExprs: Seq[Column]): DataFrame = {
-    df.sparkSession.newDataset { builder =>
+    df.sparkSession.newDataFrame { builder =>
       builder.getAggregateBuilder
         .setInput(df.plan.getRoot)
         .addAllGroupingExpressions(groupingExprs.asJava)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index e85c7008ca9..3aed781855c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -28,6 +28,8 @@ import org.apache.spark.SPARK_VERSION
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder}
 import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.util.Cleaner
 
@@ -118,7 +120,7 @@ class SparkSession(
    * @since 3.4.0
    */
   @Experimental
-  def sql(sqlText: String, args: java.util.Map[String, String]): DataFrame = newDataset {
+  def sql(sqlText: String, args: java.util.Map[String, String]): DataFrame = newDataFrame {
     builder =>
       builder
         .setSql(proto.SQL.newBuilder().setQuery(sqlText).putAllArgs(args))
@@ -169,7 +171,7 @@ class SparkSession(
    *
    * @since 3.4.0
    */
-  def range(end: Long): Dataset[Row] = range(0, end)
+  def range(end: Long): Dataset[java.lang.Long] = range(0, end)
 
   /**
    * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
@@ -177,7 +179,7 @@ class SparkSession(
    *
    * @since 3.4.0
    */
-  def range(start: Long, end: Long): Dataset[Row] = {
+  def range(start: Long, end: Long): Dataset[java.lang.Long] = {
     range(start, end, step = 1)
   }
 
@@ -187,7 +189,7 @@ class SparkSession(
    *
    * @since 3.4.0
    */
-  def range(start: Long, end: Long, step: Long): Dataset[Row] = {
+  def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
     range(start, end, step, None)
   }
 
@@ -197,7 +199,7 @@ class SparkSession(
    *
    * @since 3.4.0
    */
-  def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[Row] = {
+  def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = {
     range(start, end, step, Option(numPartitions))
   }
 
@@ -221,8 +223,8 @@ class SparkSession(
       start: Long,
       end: Long,
       step: Long,
-      numPartitions: Option[Int]): Dataset[Row] = {
-    newDataset { builder =>
+      numPartitions: Option[Int]): Dataset[java.lang.Long] = {
+    newDataset(BoxedLongEncoder) { builder =>
       val rangeBuilder = builder.getRangeBuilder
         .setStart(start)
         .setEnd(end)
@@ -231,12 +233,17 @@ class SparkSession(
     }
   }
 
-  private[sql] def newDataset[T](f: proto.Relation.Builder => Unit): Dataset[T] = {
+  private[sql] def newDataFrame(f: proto.Relation.Builder => Unit): DataFrame = {
+    newDataset(UnboundRowEncoder)(f)
+  }
+
+  private[sql] def newDataset[T](encoder: AgnosticEncoder[T])(
+      f: proto.Relation.Builder => Unit): Dataset[T] = {
     val builder = proto.Relation.newBuilder()
     f(builder)
     builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
     val plan = proto.Plan.newBuilder().setRoot(builder).build()
-    new Dataset[T](this, plan)
+    new Dataset[T](this, plan, encoder)
   }
 
   private[sql] def newCommand[T](f: proto.Command.Builder => Unit): proto.Command = {
@@ -250,9 +257,9 @@ class SparkSession(
       mode: proto.Explain.ExplainMode): proto.AnalyzePlanResponse =
     client.analyze(plan, mode)
 
-  private[sql] def execute(plan: proto.Plan): SparkResult = {
+  private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = {
     val value = client.execute(plan)
-    val result = new SparkResult(value, allocator)
+    val result = new SparkResult(value, allocator, encoder)
     cleaner.register(result)
     result
   }
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 317c20cad3e..80db558918b 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -26,26 +26,37 @@ import org.apache.arrow.vector.FieldVector
 import org.apache.arrow.vector.ipc.ArrowStreamReader
 
 import org.apache.spark.connect.proto
-import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer
 import org.apache.spark.sql.connect.client.util.{AutoCloseables, Cleanable}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
 
-private[sql] class SparkResult(
+private[sql] class SparkResult[T](
     responses: java.util.Iterator[proto.ExecutePlanResponse],
-    allocator: BufferAllocator)
+    allocator: BufferAllocator,
+    encoder: AgnosticEncoder[T])
     extends AutoCloseable
     with Cleanable {
 
   private[this] var numRecords: Int = 0
   private[this] var structType: StructType = _
-  private[this] var encoder: ExpressionEncoder[Row] = _
+  private[this] var boundEncoder: ExpressionEncoder[T] = _
   private[this] val batches = mutable.Buffer.empty[ColumnarBatch]
 
+  private def createEncoder(schema: StructType): ExpressionEncoder[T] = {
+    val agnosticEncoder = if (encoder == UnboundRowEncoder) {
+      // Create a row encoder based on the schema.
+      RowEncoder.encoderFor(schema).asInstanceOf[AgnosticEncoder[T]]
+    } else {
+      encoder
+    }
+    ExpressionEncoder(agnosticEncoder)
+  }
+
   private def processResponses(stopOnFirstNonEmptyResponse: Boolean): Boolean = {
     while (responses.hasNext) {
       val response = responses.next()
@@ -57,7 +68,7 @@ private[sql] class SparkResult(
           if (batches.isEmpty) {
             structType = ArrowUtils.fromArrowSchema(root.getSchema)
             // TODO: create encoders that directly operate on arrow vectors.
-            encoder = RowEncoder(structType).resolveAndBind(structType.toAttributes)
+            boundEncoder = createEncoder(structType).resolveAndBind(structType.toAttributes)
           }
           while (reader.loadNextBatch()) {
             val rowCount = root.getRowCount
@@ -108,8 +119,8 @@ private[sql] class SparkResult(
   /**
    * Create an Array with the contents of the result.
    */
-  def toArray: Array[Row] = {
-    val result = new Array[Row](length)
+  def toArray: Array[T] = {
+    val result = encoder.clsTag.newArray(length)
     val rows = iterator
     var i = 0
     while (rows.hasNext) {
@@ -123,11 +134,11 @@ private[sql] class SparkResult(
   /**
    * Returns an iterator over the contents of the result.
    */
-  def iterator: java.util.Iterator[Row] with AutoCloseable = {
-    new java.util.Iterator[Row] with AutoCloseable {
+  def iterator: java.util.Iterator[T] with AutoCloseable = {
+    new java.util.Iterator[T] with AutoCloseable {
       private[this] var batchIndex: Int = -1
       private[this] var iterator: java.util.Iterator[InternalRow] = Collections.emptyIterator()
-      private[this] var deserializer: Deserializer[Row] = _
+      private[this] var deserializer: Deserializer[T] = _
       override def hasNext: Boolean = {
         if (iterator.hasNext) {
           return true
@@ -142,13 +153,13 @@ private[sql] class SparkResult(
           batchIndex = nextBatchIndex
           iterator = batches(nextBatchIndex).rowIterator()
           if (deserializer == null) {
-            deserializer = encoder.createDeserializer()
+            deserializer = boundEncoder.createDeserializer()
           }
         }
         hasNextBatch
       }
 
-      override def next(): Row = {
+      override def next(): T = {
         if (!hasNext) {
           throw new NoSuchElementException
         }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 122e7d5d271..debb314f8c3 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -17,16 +17,18 @@
 package org.apache.spark.sql
 
 import java.io.{ByteArrayOutputStream, PrintStream}
+import java.nio.file.Files
 
 import scala.collection.JavaConverters._
+import scala.reflect.runtime.universe.TypeTag
 
 import io.grpc.StatusRuntimeException
-import java.nio.file.Files
 import org.apache.commons.io.FileUtils
 import org.apache.commons.io.output.TeeOutputStream
 import org.scalactic.TolerantNumerics
 
 import org.apache.spark.SPARK_VERSION
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
 import org.apache.spark.sql.functions.{aggregate, array, col, lit, rand, sequence, shuffle, transform, udf}
 import org.apache.spark.sql.types._
@@ -54,9 +56,9 @@ class ClientE2ETestSuite extends RemoteSparkSession {
     val df = spark.range(10).limit(3)
     val result = df.collect()
     assert(result.length == 3)
-    assert(result(0).getLong(0) == 0)
-    assert(result(1).getLong(0) == 1)
-    assert(result(2).getLong(0) == 2)
+    assert(result(0) == 0)
+    assert(result(1) == 1)
+    assert(result(2) == 2)
   }
 
   test("simple udf") {
@@ -237,30 +239,40 @@ class ClientE2ETestSuite extends RemoteSparkSession {
     checkFragments(result, fragmentsToCheck)
   }
 
-  private val simpleSchema = new StructType().add("id", "long", nullable = false)
+  private val simpleSchema = new StructType().add("value", "long", nullable = true)
 
   // Dataset tests
   test("Dataset inspection") {
     val df = spark.range(10)
-    val local = spark.newDataset { builder =>
+    val local = spark.newDataFrame { builder =>
       builder.getLocalRelationBuilder.setSchema(simpleSchema.catalogString)
     }
     assert(!df.isLocal)
     assert(local.isLocal)
     assert(!df.isStreaming)
-    assert(df.toString.contains("[id: bigint]"))
+    assert(df.toString.contains("[value: bigint]"))
     assert(df.inputFiles.isEmpty)
   }
 
   test("Dataset schema") {
     val df = spark.range(10)
     assert(df.schema === simpleSchema)
-    assert(df.dtypes === Array(("id", "LongType")))
-    assert(df.columns === Array("id"))
+    assert(df.dtypes === Array(("value", "LongType")))
+    assert(df.columns === Array("value"))
     testCapturedStdOut(df.printSchema(), simpleSchema.treeString)
     testCapturedStdOut(df.printSchema(5), simpleSchema.treeString(5))
   }
 
+  test("Dataframe schema") {
+    val df = spark.sql("select * from range(10)")
+    val expectedSchema = new StructType().add("id", "long", nullable = false)
+    assert(df.schema === expectedSchema)
+    assert(df.dtypes === Array(("id", "LongType")))
+    assert(df.columns === Array("id"))
+    testCapturedStdOut(df.printSchema(), expectedSchema.treeString)
+    testCapturedStdOut(df.printSchema(5), expectedSchema.treeString(5))
+  }
+
   test("Dataset explain") {
     val df = spark.range(10)
     val simpleExplainFragments = Seq("== Physical Plan ==")
@@ -282,9 +294,9 @@ class ClientE2ETestSuite extends RemoteSparkSession {
   }
 
   test("Dataset result collection") {
-    def checkResult(rows: TraversableOnce[Row], expectedValues: Long*): Unit = {
+    def checkResult(rows: TraversableOnce[java.lang.Long], expectedValues: Long*): Unit = {
       rows.toIterator.zipAll(expectedValues.iterator, null, null).foreach {
-        case (actual, expected) => assert(actual.getLong(0) === expected)
+        case (actual, expected) => assert(actual === expected)
       }
     }
     val df = spark.range(10)
@@ -355,7 +367,11 @@ class ClientE2ETestSuite extends RemoteSparkSession {
     implicit val tolerance = TolerantNumerics.tolerantDoubleEquality(0.01)
 
     val df = spark.range(100)
-    def checkSample(ds: DataFrame, lower: Double, upper: Double, seed: Long): Unit = {
+    def checkSample(
+        ds: Dataset[java.lang.Long],
+        lower: Double,
+        upper: Double,
+        seed: Long): Unit = {
       assert(ds.plan.getRoot.hasSample)
       val sample = ds.plan.getRoot.getSample
       assert(sample.getSeed === seed)
@@ -375,6 +391,44 @@ class ClientE2ETestSuite extends RemoteSparkSession {
     checkSample(datasets.get(3), 6.0 / 10.0, 1.0, 9L)
   }
 
+  test("Dataset count") {
+    assert(spark.range(10).count() === 10L)
+  }
+
+  // Remove this as soon as SQLImplicits provides its own newProductEncoder.
+  private implicit def newProductEncoder[T <: Product: TypeTag]: Encoder[T] =
+    ScalaReflection.encoderFor[T]
+
+  test("Dataset collect tuple") {
+    val result = spark
+      .range(3)
+      .select(col("id"), (col("id") % 2).cast("int").as("a"), (col("id") / lit(10.0d)).as("b"))
+      .as[(Long, Int, Double)]
+      .collect()
+    assert(result.map(_._1).toSeq === Seq(0L, 1L, 2L)) // also fails on an empty result
+    result.foreach { case (id, a, b) =>
+      assert(a == id % 2)
+      assert(b == id / 10.0d)
+    }
+  }
+
+  test("Dataset collect complex type") {
+    val result = spark
+      .range(3)
+      .select(
+        (col("id") / lit(10.0d)).as("b"),
+        col("id"),
+        lit("world").as("d"), // no matching MyType field
+        (col("id") % 2).cast("int").as("a"))
+      .as[MyType]
+      .collect()
+    assert(result.map(_.id).toSeq === Seq(0L, 1L, 2L)) // also fails on an empty result
+    result.foreach { case MyType(id, a, b) =>
+      assert(a == id % 2)
+      assert(b == id / 10.0d)
+    }
+  }
+
   test("lambda functions") {
     // This test is mostly to validate lambda variables are properly resolved.
     val result = spark
@@ -447,3 +501,5 @@ class ClientE2ETestSuite extends RemoteSparkSession {
     intercept[Exception](spark.conf.set("spark.sql.globalTempDatabase", "/dev/null"))
   }
 }
+
+private[sql] final case class MyType(id: Long, a: Double, b: Double) // collect test row
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 9c07c5abe3c..4a26a32353a 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -69,7 +69,7 @@ class DatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
   }
 
   test("write") {
-    val df = ss.newDataset(_ => ()).limit(10)
+    val df = ss.newDataFrame(_ => ()).limit(10)
 
     val builder = proto.WriteOperation.newBuilder()
     builder
@@ -101,7 +101,7 @@ class DatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
   }
 
   test("write V2") {
-    val df = ss.newDataset(_ => ()).limit(10)
+    val df = ss.newDataFrame(_ => ()).limit(10)
 
     val builder = proto.WriteOperationV2.newBuilder()
     builder
@@ -129,7 +129,7 @@ class DatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
   }
 
   test("Pivot") {
-    val df = ss.newDataset(_ => ())
+    val df = ss.newDataFrame(_ => ())
     intercept[IllegalArgumentException] {
       df.groupBy().pivot(Column("c"), Seq(Column("col")))
     }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 6a789b1494f..67ea148cb87 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -215,7 +215,7 @@ class PlanGenerationTestSuite
 
   private val temporalsSchemaString = temporalsSchema.catalogString
 
-  private def createLocalRelation(schema: String): DataFrame = session.newDataset { builder =>
+  private def createLocalRelation(schema: String): DataFrame = session.newDataFrame { builder =>
     // TODO API is not consistent. Now we have two different ways of working with schemas!
     builder.getLocalRelationBuilder.setSchema(schema)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 1a3c1089649..24c8bad5c2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -107,13 +107,20 @@ object AgnosticEncoders {
     override def dataType: DataType = schema
   }
 
-  case class RowEncoder(fields: Seq[EncoderField]) extends AgnosticEncoder[Row] {
+  abstract class BaseRowEncoder extends AgnosticEncoder[Row] { // common Row encoder base
     override def isPrimitive: Boolean = false
-    override val schema: StructType = StructType(fields.map(_.structField))
     override def dataType: DataType = schema
     override def clsTag: ClassTag[Row] = classTag[Row]
   }
 
+  case class RowEncoder(fields: Seq[EncoderField]) extends BaseRowEncoder {
+    override val schema: StructType = StructType(fields.map(_.structField))
+  }
+
+  object UnboundRowEncoder extends BaseRowEncoder { // Row encoder with an empty schema
+    override val schema: StructType = new StructType()
+  }
+
   case class JavaBeanEncoder[K](
       override val clsTag: ClassTag[K],
       fields: Seq[EncoderField])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org