You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/06/26 07:47:04 UTC

spark git commit: [SPARK-8635] [SQL] improve performance of CatalystTypeConverters

Repository: spark
Updated Branches:
  refs/heads/master 40360112c -> 1a79f0eb8


[SPARK-8635] [SQL] improve performance of CatalystTypeConverters

`CatalystTypeConverters.createToCatalystConverter` already has special handling for primitive types. This change applies the same strategy in more places (array and map converters, and `createToScalaConverter`) to improve performance.

Author: Wenchen Fan <cl...@outlook.com>

Closes #7018 from cloud-fan/converter and squashes the following commits:

8b16630 [Wenchen Fan] another fix
326c82c [Wenchen Fan] optimize type converter


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1a79f0eb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1a79f0eb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1a79f0eb

Branch: refs/heads/master
Commit: 1a79f0eb8da7e850c443383b3bb24e0bf8e1e7cb
Parents: 4036011
Author: Wenchen Fan <cl...@outlook.com>
Authored: Thu Jun 25 22:44:26 2015 -0700
Committer: Davies Liu <da...@databricks.com>
Committed: Thu Jun 25 22:44:26 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/CatalystTypeConverters.scala   | 60 ++++++++++++--------
 .../sql/catalyst/expressions/ScalaUdf.scala     |  3 +-
 .../scala/org/apache/spark/sql/DataFrame.scala  |  4 +-
 .../sql/execution/stat/FrequentItems.scala      |  2 +-
 .../sql/execution/stat/StatFunctions.scala      |  2 +-
 .../spark/sql/sources/DataSourceStrategy.scala  |  2 +-
 .../org/apache/spark/sql/sources/commands.scala |  4 +-
 .../spark/sql/sources/TableScanSuite.scala      |  4 +-
 8 files changed, 48 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1a79f0eb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 429fc40..012f8bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -52,6 +52,13 @@ object CatalystTypeConverters {
     }
   }
 
+  private def isWholePrimitive(dt: DataType): Boolean = dt match {
+    case dt if isPrimitive(dt) => true
+    case ArrayType(elementType, _) => isWholePrimitive(elementType)
+    case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType)
+    case _ => false
+  }
+
   private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = {
     val converter = dataType match {
       case udt: UserDefinedType[_] => UDTConverter(udt)
@@ -148,6 +155,8 @@ object CatalystTypeConverters {
 
     private[this] val elementConverter = getConverterForType(elementType)
 
+    private[this] val isNoChange = isWholePrimitive(elementType)
+
     override def toCatalystImpl(scalaValue: Any): Seq[Any] = {
       scalaValue match {
         case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst)
@@ -166,8 +175,10 @@ object CatalystTypeConverters {
     override def toScala(catalystValue: Seq[Any]): Seq[Any] = {
       if (catalystValue == null) {
         null
+      } else if (isNoChange) {
+        catalystValue
       } else {
-        catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala)
+        catalystValue.map(elementConverter.toScala)
       }
     }
 
@@ -183,6 +194,8 @@ object CatalystTypeConverters {
     private[this] val keyConverter = getConverterForType(keyType)
     private[this] val valueConverter = getConverterForType(valueType)
 
+    private[this] val isNoChange = isWholePrimitive(keyType) && isWholePrimitive(valueType)
+
     override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match {
       case m: Map[_, _] =>
         m.map { case (k, v) =>
@@ -203,6 +216,8 @@ object CatalystTypeConverters {
     override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = {
       if (catalystValue == null) {
         null
+      } else if (isNoChange) {
+        catalystValue
       } else {
         catalystValue.map { case (k, v) =>
           keyConverter.toScala(k) -> valueConverter.toScala(v)
@@ -258,16 +273,13 @@ object CatalystTypeConverters {
       toScala(row(column).asInstanceOf[InternalRow])
   }
 
-  private object StringConverter extends CatalystTypeConverter[Any, String, Any] {
+  private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] {
     override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match {
       case str: String => UTF8String.fromString(str)
       case utf8: UTF8String => utf8
     }
-    override def toScala(catalystValue: Any): String = catalystValue match {
-      case null => null
-      case str: String => str
-      case utf8: UTF8String => utf8.toString()
-    }
+    override def toScala(catalystValue: UTF8String): String =
+      if (catalystValue == null) null else catalystValue.toString
     override def toScalaImpl(row: InternalRow, column: Int): String = row(column).toString
   }
 
@@ -275,7 +287,8 @@ object CatalystTypeConverters {
     override def toCatalystImpl(scalaValue: Date): Int = DateTimeUtils.fromJavaDate(scalaValue)
     override def toScala(catalystValue: Any): Date =
       if (catalystValue == null) null else DateTimeUtils.toJavaDate(catalystValue.asInstanceOf[Int])
-    override def toScalaImpl(row: InternalRow, column: Int): Date = toScala(row.getInt(column))
+    override def toScalaImpl(row: InternalRow, column: Int): Date =
+      DateTimeUtils.toJavaDate(row.getInt(column))
   }
 
   private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] {
@@ -285,7 +298,7 @@ object CatalystTypeConverters {
       if (catalystValue == null) null
       else DateTimeUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long])
     override def toScalaImpl(row: InternalRow, column: Int): Timestamp =
-      toScala(row.getLong(column))
+      DateTimeUtils.toJavaTimestamp(row.getLong(column))
   }
 
   private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
@@ -296,10 +309,7 @@ object CatalystTypeConverters {
     }
     override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
     override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal =
-      row.get(column) match {
-        case d: JavaBigDecimal => d
-        case d: Decimal => d.toJavaBigDecimal
-      }
+      row.get(column).asInstanceOf[Decimal].toJavaBigDecimal
   }
 
   private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {
@@ -363,6 +373,19 @@ object CatalystTypeConverters {
   }
 
   /**
+   * Creates a converter function that will convert Catalyst types to Scala type.
+   * Typical use case would be converting a collection of rows that have the same schema. You will
+   * call this function once to get a converter, and apply it to every row.
+   */
+  private[sql] def createToScalaConverter(dataType: DataType): Any => Any = {
+    if (isPrimitive(dataType)) {
+      identity
+    } else {
+      getConverterForType(dataType).toScala
+    }
+  }
+
+  /**
    *  Converts Scala objects to Catalyst rows / types.
    *
    *  Note: This should be called before do evaluation on Row
@@ -389,15 +412,6 @@ object CatalystTypeConverters {
    * produced by createToScalaConverter.
    */
   def convertToScala(catalystValue: Any, dataType: DataType): Any = {
-    getConverterForType(dataType).toScala(catalystValue)
-  }
-
-  /**
-   * Creates a converter function that will convert Catalyst types to Scala type.
-   * Typical use case would be converting a collection of rows that have the same schema. You will
-   * call this function once to get a converter, and apply it to every row.
-   */
-  private[sql] def createToScalaConverter(dataType: DataType): Any => Any = {
-    getConverterForType(dataType).toScala
+    createToScalaConverter(dataType)(catalystValue)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1a79f0eb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 3992f1f..55df72f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.types.DataType
 
@@ -39,7 +38,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
     (1 to 22).map { x =>
       val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _)
       val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n  " + _)
-      lazy val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n  " + _)
+      val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n  " + _)
       val evals = (0 to x - 1).map(x => s"converter$x(child$x.eval(input))").reduce(_ + ",\n      " + _)
 
       s"""case $x =>

http://git-wip-us.apache.org/repos/asf/spark/blob/1a79f0eb/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index f3f0f53..0db4df3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1418,12 +1418,14 @@ class DataFrame private[sql](
   lazy val rdd: RDD[Row] = {
     // use a local variable to make sure the map closure doesn't capture the whole DataFrame
     val schema = this.schema
-    queryExecution.executedPlan.execute().mapPartitions { rows =>
+    internalRowRdd.mapPartitions { rows =>
       val converter = CatalystTypeConverters.createToScalaConverter(schema)
       rows.map(converter(_).asInstanceOf[Row])
     }
   }
 
+  private[sql] def internalRowRdd = queryExecution.executedPlan.execute()
+
   /**
    * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s.
    * @group rdd

http://git-wip-us.apache.org/repos/asf/spark/blob/1a79f0eb/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 8df1da0..3ebbf96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -90,7 +90,7 @@ private[sql] object FrequentItems extends Logging {
       (name, originalSchema.fields(index).dataType)
     }
 
-    val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
+    val freqItems = df.select(cols.map(Column(_)) : _*).internalRowRdd.aggregate(countMaps)(
       seqOp = (counts, row) => {
         var i = 0
         while (i < numCols) {

http://git-wip-us.apache.org/repos/asf/spark/blob/1a79f0eb/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 93383e5..252c611 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -81,7 +81,7 @@ private[sql] object StatFunctions extends Logging {
         s"with dataType ${data.get.dataType} not supported.")
     }
     val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
-    df.select(columns: _*).rdd.aggregate(new CovarianceCounter)(
+    df.select(columns: _*).internalRowRdd.aggregate(new CovarianceCounter)(
       seqOp = (counter, row) => {
         counter.add(row.getDouble(0), row.getDouble(1))
       },

http://git-wip-us.apache.org/repos/asf/spark/blob/1a79f0eb/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index a8f56f4..ce16e05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -313,7 +313,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
       output: Seq[Attribute],
       rdd: RDD[Row]): RDD[InternalRow] = {
     if (relation.relation.needConversion) {
-      execution.RDDConversions.rowToRowRdd(rdd.asInstanceOf[RDD[Row]], output.map(_.dataType))
+      execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType))
     } else {
       rdd.map(_.asInstanceOf[InternalRow])
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/1a79f0eb/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index fb6173f..dbb369c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -154,7 +154,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
     writerContainer.driverSideSetup()
 
     try {
-      df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
+      df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _)
       writerContainer.commitJob()
       relation.refresh()
     } catch { case cause: Throwable =>
@@ -220,7 +220,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
     writerContainer.driverSideSetup()
 
     try {
-      df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
+      df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _)
       writerContainer.commitJob()
       relation.refresh()
     } catch { case cause: Throwable =>

http://git-wip-us.apache.org/repos/asf/spark/blob/1a79f0eb/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 79eac93..de0ed0c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -88,9 +88,9 @@ case class AllDataTypesScan(
         UTF8String.fromString(s"varchar_$i"),
         Seq(i, i + 1),
         Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))),
-        Map(i -> i.toString),
+        Map(i -> UTF8String.fromString(i.toString)),
         Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)),
-        Row(i, i.toString),
+        Row(i, UTF8String.fromString(i.toString)),
         Row(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")),
           InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1))))))
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org