You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/10/24 07:46:26 UTC

spark git commit: [SPARK-20822][SQL] Generate code to directly get value from ColumnVector for table cache

Repository: spark
Updated Branches:
  refs/heads/master d9798c834 -> c30d5cfc7


[SPARK-20822][SQL] Generate code to directly get value from ColumnVector for table cache

## What changes were proposed in this pull request?

This PR generates Java code that reads a column value directly from a `ColumnVector`, without going through an iterator (e.g. lines 54-69 of the generated code example below), for the table cache (e.g. `dataframe.cache`). It improves runtime performance for primitive types by eliminating the data copy from column-oriented storage to `InternalRow` in a `SpecificColumnarIterator`. Another PR will add support for arrays of primitive types.

Benchmark result: **1.2x**
```
OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic
Intel(R) Xeon(R) CPU E5-2667 v3  3.20GHz
Int Sum with IntDelta cache:             Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
InternalRow codegen                            731 /  812         43.0          23.2       1.0X
ColumnVector codegen                           616 /  772         51.0          19.6       1.2X
```
Benchmark program
```
  intSumBenchmark(sqlContext, 1024 * 1024 * 30)
  def intSumBenchmark(sqlContext: SQLContext, values: Int): Unit = {
    import sqlContext.implicits._
    val benchmarkPT = new Benchmark("Int Sum with IntDelta cache", values, 20)
    Seq(("InternalRow", "false"), ("ColumnVector", "true")).foreach {
      case (str, value) =>
        withSQLConf(sqlContext, SQLConf. COLUMN_VECTOR_CODEGEN.key -> value) { // tentatively added for benchmarking
          val dfPassThrough = sqlContext.sparkContext.parallelize(0 to values - 1, 1).toDF().cache()
          dfPassThrough.count()       // force to create df.cache()
          benchmarkPT.addCase(s"$str codegen") { iter =>
            dfPassThrough.agg(sum("value")).collect
          }
          dfPassThrough.unpersist(true)
        }
    }
    benchmarkPT.run()
  }
```

Motivating example
```
val dsInt = spark.range(3).cache
dsInt.count // force to build cache
dsInt.filter(_ > 0).collect
```
Generated code
```
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inmemorytablescan_input;
/* 009 */   private org.apache.spark.sql.execution.metric.SQLMetric inmemorytablescan_numOutputRows;
/* 010 */   private org.apache.spark.sql.execution.metric.SQLMetric inmemorytablescan_scanTime;
/* 011 */   private long inmemorytablescan_scanTime1;
/* 012 */   private org.apache.spark.sql.execution.vectorized.ColumnarBatch inmemorytablescan_batch;
/* 013 */   private int inmemorytablescan_batchIdx;
/* 014 */   private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector inmemorytablescan_colInstance0;
/* 015 */   private UnsafeRow inmemorytablescan_result;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder inmemorytablescan_holder;
/* 017 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter inmemorytablescan_rowWriter;
/* 018 */   private org.apache.spark.sql.execution.metric.SQLMetric filter_numOutputRows;
/* 019 */   private UnsafeRow filter_result;
/* 020 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder filter_holder;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter filter_rowWriter;
/* 022 */
/* 023 */   public GeneratedIterator(Object[] references) {
/* 024 */     this.references = references;
/* 025 */   }
/* 026 */
/* 027 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 028 */     partitionIndex = index;
/* 029 */     this.inputs = inputs;
/* 030 */     inmemorytablescan_input = inputs[0];
/* 031 */     inmemorytablescan_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 032 */     inmemorytablescan_scanTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
/* 033 */     inmemorytablescan_scanTime1 = 0;
/* 034 */     inmemorytablescan_batch = null;
/* 035 */     inmemorytablescan_batchIdx = 0;
/* 036 */     inmemorytablescan_colInstance0 = null;
/* 037 */     inmemorytablescan_result = new UnsafeRow(1);
/* 038 */     inmemorytablescan_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(inmemorytablescan_result, 0);
/* 039 */     inmemorytablescan_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(inmemorytablescan_holder, 1);
/* 040 */     filter_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
/* 041 */     filter_result = new UnsafeRow(1);
/* 042 */     filter_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(filter_result, 0);
/* 043 */     filter_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(filter_holder, 1);
/* 044 */
/* 045 */   }
/* 046 */
/* 047 */   protected void processNext() throws java.io.IOException {
/* 048 */     if (inmemorytablescan_batch == null) {
/* 049 */       inmemorytablescan_nextBatch();
/* 050 */     }
/* 051 */     while (inmemorytablescan_batch != null) {
/* 052 */       int inmemorytablescan_numRows = inmemorytablescan_batch.numRows();
/* 053 */       int inmemorytablescan_localEnd = inmemorytablescan_numRows - inmemorytablescan_batchIdx;
/* 054 */       for (int inmemorytablescan_localIdx = 0; inmemorytablescan_localIdx < inmemorytablescan_localEnd; inmemorytablescan_localIdx++) {
/* 055 */         int inmemorytablescan_rowIdx = inmemorytablescan_batchIdx + inmemorytablescan_localIdx;
/* 056 */         int inmemorytablescan_value = inmemorytablescan_colInstance0.getInt(inmemorytablescan_rowIdx);
/* 057 */
/* 058 */         boolean filter_isNull = false;
/* 059 */
/* 060 */         boolean filter_value = false;
/* 061 */         filter_value = inmemorytablescan_value > 1;
/* 062 */         if (!filter_value) continue;
/* 063 */
/* 064 */         filter_numOutputRows.add(1);
/* 065 */
/* 066 */         filter_rowWriter.write(0, inmemorytablescan_value);
/* 067 */         append(filter_result);
/* 068 */         if (shouldStop()) { inmemorytablescan_batchIdx = inmemorytablescan_rowIdx + 1; return; }
/* 069 */       }
/* 070 */       inmemorytablescan_batchIdx = inmemorytablescan_numRows;
/* 071 */       inmemorytablescan_batch = null;
/* 072 */       inmemorytablescan_nextBatch();
/* 073 */     }
/* 074 */     inmemorytablescan_scanTime.add(inmemorytablescan_scanTime1 / (1000 * 1000));
/* 075 */     inmemorytablescan_scanTime1 = 0;
/* 076 */   }
/* 077 */
/* 078 */   private void inmemorytablescan_nextBatch() throws java.io.IOException {
/* 079 */     long getBatchStart = System.nanoTime();
/* 080 */     if (inmemorytablescan_input.hasNext()) {
/* 081 */       org.apache.spark.sql.execution.columnar.CachedBatch inmemorytablescan_cachedBatch = (org.apache.spark.sql.execution.columnar.CachedBatch)inmemorytablescan_input.next();
/* 082 */       inmemorytablescan_batch = org.apache.spark.sql.execution.columnar.InMemoryRelation$.MODULE$.createColumn(inmemorytablescan_cachedBatch);
/* 083 */
/* 084 */       inmemorytablescan_numOutputRows.add(inmemorytablescan_batch.numRows());
/* 085 */       inmemorytablescan_batchIdx = 0;
/* 086 */       inmemorytablescan_colInstance0 = (org.apache.spark.sql.execution.vectorized.OnHeapColumnVector) inmemorytablescan_batch.column(0); org.apache.spark.sql.execution.columnar.ColumnAccessor$.MODULE$.decompress(inmemorytablescan_cachedBatch.buffers()[0], (org.apache.spark.sql.execution.vectorized.WritableColumnVector) inmemorytablescan_colInstance0, org.apache.spark.sql.types.DataTypes.IntegerType, inmemorytablescan_cachedBatch.numRows());
/* 087 */
/* 088 */     }
/* 089 */     inmemorytablescan_scanTime1 += System.nanoTime() - getBatchStart;
/* 090 */   }
/* 091 */ }
```

## How was this patch tested?

Added test cases to `DataFrameTungstenSuite` and `WholeStageCodegenSuite`.

Author: Kazuaki Ishizaki <is...@jp.ibm.com>

Closes #18747 from kiszk/SPARK-20822a.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c30d5cfc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c30d5cfc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c30d5cfc

Branch: refs/heads/master
Commit: c30d5cfc7117bdadd63bf730e88398139e0f65f4
Parents: d9798c8
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Authored: Tue Oct 24 08:46:22 2017 +0100
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Oct 24 08:46:22 2017 +0100

----------------------------------------------------------------------
 .../spark/sql/execution/ColumnarBatchScan.scala |  3 --
 .../sql/execution/WholeStageCodegenExec.scala   | 24 +++++----
 .../sql/execution/columnar/ColumnAccessor.scala |  8 +++
 .../columnar/InMemoryTableScanExec.scala        | 57 +++++++++++++++++---
 .../spark/sql/DataFrameTungstenSuite.scala      | 36 +++++++++++++
 .../sql/execution/WholeStageCodegenSuite.scala  | 32 +++++++++++
 6 files changed, 141 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c30d5cfc/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 1afe83e..eb01e12 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector}
 import org.apache.spark.sql.types.DataType
@@ -31,8 +30,6 @@ import org.apache.spark.sql.types.DataType
  */
 private[sql] trait ColumnarBatchScan extends CodegenSupport {
 
-  val inMemoryTableScan: InMemoryTableScanExec = null
-
   def vectorTypes: Option[Seq[String]] = None
 
   override lazy val metrics = Map(

http://git-wip-us.apache.org/repos/asf/spark/blob/c30d5cfc/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 1aaaf89..e37d133 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -282,6 +282,18 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
 
 object WholeStageCodegenExec {
   val PIPELINE_DURATION_METRIC = "duration"
+
+  private def numOfNestedFields(dataType: DataType): Int = dataType match {
+    case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum
+    case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType)
+    case a: ArrayType => numOfNestedFields(a.elementType)
+    case u: UserDefinedType[_] => numOfNestedFields(u.sqlType)
+    case _ => 1
+  }
+
+  def isTooManyFields(conf: SQLConf, dataType: DataType): Boolean = {
+    numOfNestedFields(dataType) > conf.wholeStageMaxNumFields
+  }
 }
 
 /**
@@ -490,22 +502,14 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
     case _ => true
   }
 
-  private def numOfNestedFields(dataType: DataType): Int = dataType match {
-    case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum
-    case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType)
-    case a: ArrayType => numOfNestedFields(a.elementType)
-    case u: UserDefinedType[_] => numOfNestedFields(u.sqlType)
-    case _ => 1
-  }
-
   private def supportCodegen(plan: SparkPlan): Boolean = plan match {
     case plan: CodegenSupport if plan.supportCodegen =>
       val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined)
       // the generated code will be huge if there are too many columns
       val hasTooManyOutputFields =
-        numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields
+        WholeStageCodegenExec.isTooManyFields(conf, plan.schema)
       val hasTooManyInputFields =
-        plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields)
+        plan.children.exists(p => WholeStageCodegenExec.isTooManyFields(conf, p.schema))
       !willFallback && !hasTooManyOutputFields && !hasTooManyInputFields
     case _ => false
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/c30d5cfc/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
index 24c8ac8..445933d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
@@ -163,4 +163,12 @@ private[sql] object ColumnAccessor {
       throw new RuntimeException("Not support non-primitive type now")
     }
   }
+
+  def decompress(
+      array: Array[Byte], columnVector: WritableColumnVector, dataType: DataType, numRows: Int):
+      Unit = {
+    val byteBuffer = ByteBuffer.wrap(array)
+    val columnAccessor = ColumnAccessor(dataType, byteBuffer)
+    decompress(columnAccessor, columnVector, numRows)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c30d5cfc/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 139da1c..43386e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -23,21 +23,66 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
-import org.apache.spark.sql.execution.LeafExecNode
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.UserDefinedType
+import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.vectorized._
+import org.apache.spark.sql.types._
 
 
 case class InMemoryTableScanExec(
     attributes: Seq[Attribute],
     predicates: Seq[Expression],
     @transient relation: InMemoryRelation)
-  extends LeafExecNode {
+  extends LeafExecNode with ColumnarBatchScan {
 
   override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren
 
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+  override def vectorTypes: Option[Seq[String]] =
+    Option(Seq.fill(attributes.length)(classOf[OnHeapColumnVector].getName))
+
+  /**
+   * If true, get data from ColumnVector in ColumnarBatch, which are generally faster.
+   * If false, get data from UnsafeRow build from ColumnVector
+   */
+  override val supportCodegen: Boolean = {
+    // In the initial implementation, for ease of review
+    // support only primitive data types and # of fields is less than wholeStageMaxNumFields
+    relation.schema.fields.forall(f => f.dataType match {
+      case BooleanType | ByteType | ShortType | IntegerType | LongType |
+           FloatType | DoubleType => true
+      case _ => false
+    }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema)
+  }
+
+  private val columnIndices =
+    attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray
+
+  private val relationSchema = relation.schema.toArray
+
+  private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i)))
+
+  private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = {
+    val rowCount = cachedColumnarBatch.numRows
+    val columnVectors = OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema)
+    val columnarBatch = new ColumnarBatch(
+      columnarBatchSchema, columnVectors.asInstanceOf[Array[ColumnVector]], rowCount)
+    columnarBatch.setNumRows(rowCount)
+
+    for (i <- 0 until attributes.length) {
+      ColumnAccessor.decompress(
+        cachedColumnarBatch.buffers(columnIndices(i)),
+        columnarBatch.column(i).asInstanceOf[WritableColumnVector],
+        columnarBatchSchema.fields(i).dataType, rowCount)
+    }
+    columnarBatch
+  }
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    assert(supportCodegen)
+    val buffers = relation.cachedColumnBuffers
+    // HACK ALERT: This is actually an RDD[ColumnarBatch].
+    // We're taking advantage of Scala's type erasure here to pass these batches along.
+    Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]])
+  }
 
   override def output: Seq[Attribute] = attributes
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c30d5cfc/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
index fe6ba83..0881212 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
@@ -73,4 +73,40 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
     val df = spark.createDataFrame(data, schema)
     assert(df.select("b").first() === Row(outerStruct))
   }
+
+  test("primitive data type accesses in persist data") {
+    val data = Seq(true, 1.toByte, 3.toShort, 7, 15.toLong,
+      31.25.toFloat, 63.75, null)
+    val dataTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType,
+      FloatType, DoubleType, IntegerType)
+    val schemas = dataTypes.zipWithIndex.map { case (dataType, index) =>
+      StructField(s"col$index", dataType, true)
+    }
+    val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
+    val df = spark.createDataFrame(rdd, StructType(schemas))
+    val row = df.persist.take(1).apply(0)
+    checkAnswer(df, row)
+  }
+
+  test("access cache multiple times") {
+    val df0 = sparkContext.parallelize(Seq(1, 2, 3), 1).toDF("x").cache
+    df0.count
+    val df1 = df0.filter("x > 1")
+    checkAnswer(df1, Seq(Row(2), Row(3)))
+    val df2 = df0.filter("x > 2")
+    checkAnswer(df2, Row(3))
+
+    val df10 = sparkContext.parallelize(Seq(3, 4, 5, 6), 1).toDF("x").cache
+    for (_ <- 0 to 2) {
+      val df11 = df10.filter("x > 5")
+      checkAnswer(df11, Row(6))
+    }
+  }
+
+  test("access only some column of the all of columns") {
+    val df = spark.range(1, 10).map(i => (i, (i + 1).toDouble)).toDF("l", "d")
+    df.cache
+    df.count
+    assert(df.filter("d < 3").count == 1)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c30d5cfc/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 098e4cf..bc05dca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.{QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.expressions.scalalang.typed
@@ -117,6 +118,37 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
     assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
   }
 
+  test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") {
+    import testImplicits._
+
+    val dsInt = spark.range(3).cache
+    dsInt.count
+    val dsIntFilter = dsInt.filter(_ > 0)
+    val planInt = dsIntFilter.queryExecution.executedPlan
+    assert(planInt.find(p =>
+      p.isInstanceOf[WholeStageCodegenExec] &&
+      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] &&
+      p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child
+        .isInstanceOf[InMemoryTableScanExec] &&
+      p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child
+        .asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined
+    )
+    assert(dsIntFilter.collect() === Array(1, 2))
+
+    // cache for string type is not supported for InMemoryTableScanExec
+    val dsString = spark.range(3).map(_.toString).cache
+    dsString.count
+    val dsStringFilter = dsString.filter(_ == "1")
+    val planString = dsStringFilter.queryExecution.executedPlan
+    assert(planString.find(p =>
+      p.isInstanceOf[WholeStageCodegenExec] &&
+      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] &&
+      !p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child
+        .isInstanceOf[InMemoryTableScanExec]).isDefined
+    )
+    assert(dsStringFilter.collect() === Array("1"))
+  }
+
   test("SPARK-19512 codegen for comparing structs is incorrect") {
     // this would raise CompileException before the fix
     spark.range(10)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org