You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2015/12/08 19:26:02 UTC

spark git commit: [SPARK-12188][SQL] Code refactoring and comment correction in Dataset APIs

Repository: spark
Updated Branches:
  refs/heads/master c0b13d556 -> 5d96a710a


[SPARK-12188][SQL] Code refactoring and comment correction in Dataset APIs

This PR contains the following updates:

- Created a new `private[sql]` value, `boundTEncoder`, that is shared by the `rdd`, `select`, and `collect` methods.
- Replaced all uses of `queryExecution.analyzed` with `logicalPlan`.
- Fixed API comments that referred to the wrong class names (e.g., `DataFrame`) or parameter names (e.g., `n`).
- Corrected a few wrong API descriptions (e.g., `mapPartitions`, which applies `func` to each partition, not to each element).

marmbrus rxin cloud-fan, could you take a look and check whether these changes are appropriate? Thank you!

Author: gatorsmile <ga...@gmail.com>

Closes #10184 from gatorsmile/datasetClean.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5d96a710
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5d96a710
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5d96a710

Branch: refs/heads/master
Commit: 5d96a710a5ed543ec81e383620fc3b2a808b26a1
Parents: c0b13d5
Author: gatorsmile <ga...@gmail.com>
Authored: Tue Dec 8 10:25:57 2015 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Dec 8 10:25:57 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Dataset.scala    | 80 ++++++++++----------
 1 file changed, 40 insertions(+), 40 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5d96a710/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index d6bb1d2..3bd18a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -67,15 +67,21 @@ class Dataset[T] private[sql](
     tEncoder: Encoder[T]) extends Queryable with Serializable {
 
   /**
-   * An unresolved version of the internal encoder for the type of this dataset.  This one is marked
-   * implicit so that we can use it when constructing new [[Dataset]] objects that have the same
-   * object type (that will be possibly resolved to a different schema).
+   * An unresolved version of the internal encoder for the type of this [[Dataset]].  This one is
+   * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
+   * same object type (that will be possibly resolved to a different schema).
    */
   private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
 
   /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
   private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
-    unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes)
+    unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)
+
+  /**
+   * The encoder where the expressions used to construct an object from an input row have been
+   * bound to the ordinals of the given schema.
+   */
+  private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
 
   private implicit def classTag = resolvedTEncoder.clsTag
 
@@ -89,7 +95,7 @@ class Dataset[T] private[sql](
   override def schema: StructType = resolvedTEncoder.schema
 
   /**
-   * Prints the schema of the underlying [[DataFrame]] to the console in a nice tree format.
+   * Prints the schema of the underlying [[Dataset]] to the console in a nice tree format.
    * @since 1.6.0
    */
   override def printSchema(): Unit = toDF().printSchema()
@@ -111,7 +117,7 @@ class Dataset[T] private[sql](
    * ************* */
 
   /**
-   * Returns a new `Dataset` where each record has been mapped on to the specified type.  The
+   * Returns a new [[Dataset]] where each record has been mapped on to the specified type.  The
    * method used to map columns depend on the type of `U`:
    *  - When `U` is a class, fields for the class will be mapped to columns of the same name
    *    (case sensitivity is determined by `spark.sql.caseSensitive`)
@@ -145,7 +151,7 @@ class Dataset[T] private[sql](
   def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan)
 
   /**
-   * Returns this Dataset.
+   * Returns this [[Dataset]].
    * @since 1.6.0
    */
   // This is declared with parentheses to prevent the Scala compiler from treating
@@ -153,15 +159,12 @@ class Dataset[T] private[sql](
   def toDS(): Dataset[T] = this
 
   /**
-   * Converts this Dataset to an RDD.
+   * Converts this [[Dataset]] to an [[RDD]].
    * @since 1.6.0
    */
   def rdd: RDD[T] = {
-    val tEnc = resolvedTEncoder
-    val input = queryExecution.analyzed.output
     queryExecution.toRdd.mapPartitions { iter =>
-      val bound = tEnc.bind(input)
-      iter.map(bound.fromRow)
+      iter.map(boundTEncoder.fromRow)
     }
   }
 
@@ -189,7 +192,7 @@ class Dataset[T] private[sql](
   def show(numRows: Int): Unit = show(numRows, truncate = true)
 
   /**
-   * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters
+   * Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters
    * will be truncated, and all cells will be aligned right.
    *
    * @since 1.6.0
@@ -197,7 +200,7 @@ class Dataset[T] private[sql](
   def show(): Unit = show(20)
 
   /**
-   * Displays the top 20 rows of [[DataFrame]] in a tabular form.
+   * Displays the top 20 rows of [[Dataset]] in a tabular form.
    *
    * @param truncate Whether truncate long strings. If true, strings more than 20 characters will
    *              be truncated and all cells will be aligned right
@@ -207,7 +210,7 @@ class Dataset[T] private[sql](
   def show(truncate: Boolean): Unit = show(20, truncate)
 
   /**
-   * Displays the [[DataFrame]] in a tabular form. For example:
+   * Displays the [[Dataset]] in a tabular form. For example:
    * {{{
    *   year  month AVG('Adj Close) MAX('Adj Close)
    *   1980  12    0.503218        0.595103
@@ -291,7 +294,7 @@ class Dataset[T] private[sql](
 
   /**
    * (Scala-specific)
-   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+   * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
    * @since 1.6.0
    */
   def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
@@ -307,7 +310,7 @@ class Dataset[T] private[sql](
 
   /**
    * (Java-specific)
-   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+   * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
    * @since 1.6.0
    */
   def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
@@ -341,28 +344,28 @@ class Dataset[T] private[sql](
 
   /**
    * (Scala-specific)
-   * Runs `func` on each element of this Dataset.
+   * Runs `func` on each element of this [[Dataset]].
    * @since 1.6.0
    */
   def foreach(func: T => Unit): Unit = rdd.foreach(func)
 
   /**
    * (Java-specific)
-   * Runs `func` on each element of this Dataset.
+   * Runs `func` on each element of this [[Dataset]].
    * @since 1.6.0
    */
   def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
 
   /**
    * (Scala-specific)
-   * Runs `func` on each partition of this Dataset.
+   * Runs `func` on each partition of this [[Dataset]].
    * @since 1.6.0
    */
   def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)
 
   /**
    * (Java-specific)
-   * Runs `func` on each partition of this Dataset.
+   * Runs `func` on each partition of this [[Dataset]].
    * @since 1.6.0
    */
   def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
@@ -374,7 +377,7 @@ class Dataset[T] private[sql](
 
   /**
    * (Scala-specific)
-   * Reduces the elements of this Dataset using the specified binary function.  The given function
+   * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func`
    * must be commutative and associative or the result may be non-deterministic.
    * @since 1.6.0
    */
@@ -382,7 +385,7 @@ class Dataset[T] private[sql](
 
   /**
    * (Java-specific)
-   * Reduces the elements of this Dataset using the specified binary function.  The given function
+   * Reduces the elements of this Dataset using the specified binary function.  The given `func`
    * must be commutative and associative or the result may be non-deterministic.
    * @since 1.6.0
    */
@@ -390,11 +393,11 @@ class Dataset[T] private[sql](
 
   /**
    * (Scala-specific)
-   * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
+   * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
    * @since 1.6.0
    */
   def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
-    val inputPlan = queryExecution.analyzed
+    val inputPlan = logicalPlan
     val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan)
     val executed = sqlContext.executePlan(withGroupingKey)
 
@@ -429,18 +432,18 @@ class Dataset[T] private[sql](
 
   /**
    * (Java-specific)
-   * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
+   * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
    * @since 1.6.0
    */
-  def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
-    groupBy(f.call(_))(encoder)
+  def groupBy[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
+    groupBy(func.call(_))(encoder)
 
   /* ****************** *
    *  Typed Relational  *
    * ****************** */
 
   /**
-   * Selects a set of column based expressions.
+   * Returns a new [[DataFrame]] by selecting a set of column based expressions.
    * {{{
    *   df.select($"colA", $"colB" + 1)
    * }}}
@@ -464,8 +467,8 @@ class Dataset[T] private[sql](
       sqlContext,
       Project(
         c1.withInputType(
-          resolvedTEncoder.bind(queryExecution.analyzed.output),
-          queryExecution.analyzed.output).named :: Nil,
+          boundTEncoder,
+          logicalPlan.output).named :: Nil,
         logicalPlan))
   }
 
@@ -477,7 +480,7 @@ class Dataset[T] private[sql](
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named)
+      columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
     val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
 
     new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
@@ -654,7 +657,7 @@ class Dataset[T] private[sql](
    * Returns an array that contains all the elements in this [[Dataset]].
    *
    * Running collect requires moving all the data into the application's driver process, and
-   * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+   * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
    *
    * For Java API, use [[collectAsList]].
    * @since 1.6.0
@@ -662,17 +665,14 @@ class Dataset[T] private[sql](
   def collect(): Array[T] = {
     // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
     // to convert the rows into objects of type T.
-    val tEnc = resolvedTEncoder
-    val input = queryExecution.analyzed.output
-    val bound = tEnc.bind(input)
-    queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow)
+    queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
   }
 
   /**
    * Returns an array that contains all the elements in this [[Dataset]].
    *
    * Running collect requires moving all the data into the application's driver process, and
-   * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+   * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
    *
    * For Java API, use [[collectAsList]].
    * @since 1.6.0
@@ -683,7 +683,7 @@ class Dataset[T] private[sql](
    * Returns the first `num` elements of this [[Dataset]] as an array.
    *
    * Running take requires moving data into the application's driver process, and doing so with
-   * a very large `n` can crash the driver process with OutOfMemoryError.
+   * a very large `num` can crash the driver process with OutOfMemoryError.
    * @since 1.6.0
    */
   def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
@@ -692,7 +692,7 @@ class Dataset[T] private[sql](
    * Returns the first `num` elements of this [[Dataset]] as an array.
    *
    * Running take requires moving data into the application's driver process, and doing so with
-   * a very large `n` can crash the driver process with OutOfMemoryError.
+   * a very large `num` can crash the driver process with OutOfMemoryError.
    * @since 1.6.0
    */
   def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org