Posted to commits@spark.apache.org by we...@apache.org on 2017/10/24 07:46:26 UTC
spark git commit: [SPARK-20822][SQL] Generate code to directly get value from ColumnVector for table cache
Repository: spark
Updated Branches:
refs/heads/master d9798c834 -> c30d5cfc7
[SPARK-20822][SQL] Generate code to directly get value from ColumnVector for table cache
## What changes were proposed in this pull request?
This PR generates Java code that gets a value for a column directly from a `ColumnVector`, without going through an iterator (see lines 54-69 in the generated code example below), for the table cache (e.g. `dataframe.cache`). This improves runtime performance by eliminating the data copy from column-oriented storage into an `InternalRow` performed by the `SpecificColumnarIterator`, for primitive types. A follow-up PR will add support for arrays of primitive types.
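For intuition, here is a minimal, self-contained Scala sketch of the difference. `ToyIntVector` and `AccessPatterns` are illustrative stand-ins, not Spark classes; the single-slot array plays the role of the `InternalRow` the old iterator filled:
```
// Toy illustration only; these are NOT Spark classes.
final class ToyIntVector(values: Array[Int]) {
  def numRows: Int = values.length
  def getInt(rowId: Int): Int = values(rowId)
}

object AccessPatterns {
  def main(args: Array[String]): Unit = {
    val col = new ToyIntVector(Array.range(0, 1024))

    // Old path: every value is first copied into a row buffer,
    // the role SpecificColumnarIterator's InternalRow plays.
    val rowBuffer = new Array[Int](1)
    var sumOld = 0L
    for (i <- 0 until col.numRows) {
      rowBuffer(0) = col.getInt(i) // copy column -> "row"
      sumOld += rowBuffer(0)       // consumer reads from the row
    }

    // New path: the generated loop reads the vector in place.
    var sumNew = 0L
    var i = 0
    while (i < col.numRows) {
      sumNew += col.getInt(i)      // no intermediate row copy
      i += 1
    }
    assert(sumOld == sumNew)
  }
}
```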
Benchmark result: **1.2x** (the `Relative` column below: best time 731 ms / 616 ms ≈ 1.2)
```
OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic
Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
Int Sum with IntDelta cache:             Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
InternalRow codegen                            731 /  812         43.0          23.2       1.0X
ColumnVector codegen                           616 /  772         51.0          19.6       1.2X
```
Benchmark program
```
intSumBenchmark(sqlContext, 1024 * 1024 * 30)

def intSumBenchmark(sqlContext: SQLContext, values: Int): Unit = {
  import sqlContext.implicits._
  val benchmarkPT = new Benchmark("Int Sum with IntDelta cache", values, 20)
  Seq(("InternalRow", "false"), ("ColumnVector", "true")).foreach {
    case (str, value) =>
      withSQLConf(sqlContext, SQLConf.COLUMN_VECTOR_CODEGEN.key -> value) { // config tentatively added for benchmarking
        val dfPassThrough = sqlContext.sparkContext.parallelize(0 to values - 1, 1).toDF().cache()
        dfPassThrough.count() // force materialization of the cache
        benchmarkPT.addCase(s"$str codegen") { iter =>
          dfPassThrough.agg(sum("value")).collect
        }
        dfPassThrough.unpersist(true)
      }
  }
  benchmarkPT.run()
}
```
Motivating example
```
val dsInt = spark.range(3).cache
dsInt.count // force to build cache
dsInt.filter(_ > 0).collect
```
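To reproduce the generated code below yourself, Spark's debug helper can print it for any Dataset; this is a sketch assuming the `debugCodegen` implicit from `org.apache.spark.sql.execution.debug._`, available in Spark 2.x:
```
import org.apache.spark.sql.execution.debug._ // adds debugCodegen() to Dataset

val dsInt = spark.range(3).cache
dsInt.count                        // force to build cache
dsInt.filter(_ > 0).debugCodegen() // prints the whole-stage generated Java
```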
Generated code
```
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inmemorytablescan_input;
/* 009 */   private org.apache.spark.sql.execution.metric.SQLMetric inmemorytablescan_numOutputRows;
/* 010 */   private org.apache.spark.sql.execution.metric.SQLMetric inmemorytablescan_scanTime;
/* 011 */   private long inmemorytablescan_scanTime1;
/* 012 */   private org.apache.spark.sql.execution.vectorized.ColumnarBatch inmemorytablescan_batch;
/* 013 */   private int inmemorytablescan_batchIdx;
/* 014 */   private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector inmemorytablescan_colInstance0;
/* 015 */   private UnsafeRow inmemorytablescan_result;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder inmemorytablescan_holder;
/* 017 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter inmemorytablescan_rowWriter;
/* 018 */   private org.apache.spark.sql.execution.metric.SQLMetric filter_numOutputRows;
/* 019 */   private UnsafeRow filter_result;
/* 020 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder filter_holder;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter filter_rowWriter;
/* 022 */
/* 023 */   public GeneratedIterator(Object[] references) {
/* 024 */     this.references = references;
/* 025 */   }
/* 026 */
/* 027 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 028 */     partitionIndex = index;
/* 029 */     this.inputs = inputs;
/* 030 */     inmemorytablescan_input = inputs[0];
/* 031 */     inmemorytablescan_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 032 */     inmemorytablescan_scanTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
/* 033 */     inmemorytablescan_scanTime1 = 0;
/* 034 */     inmemorytablescan_batch = null;
/* 035 */     inmemorytablescan_batchIdx = 0;
/* 036 */     inmemorytablescan_colInstance0 = null;
/* 037 */     inmemorytablescan_result = new UnsafeRow(1);
/* 038 */     inmemorytablescan_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(inmemorytablescan_result, 0);
/* 039 */     inmemorytablescan_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(inmemorytablescan_holder, 1);
/* 040 */     filter_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
/* 041 */     filter_result = new UnsafeRow(1);
/* 042 */     filter_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(filter_result, 0);
/* 043 */     filter_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(filter_holder, 1);
/* 044 */
/* 045 */   }
/* 046 */
/* 047 */   protected void processNext() throws java.io.IOException {
/* 048 */     if (inmemorytablescan_batch == null) {
/* 049 */       inmemorytablescan_nextBatch();
/* 050 */     }
/* 051 */     while (inmemorytablescan_batch != null) {
/* 052 */       int inmemorytablescan_numRows = inmemorytablescan_batch.numRows();
/* 053 */       int inmemorytablescan_localEnd = inmemorytablescan_numRows - inmemorytablescan_batchIdx;
/* 054 */       for (int inmemorytablescan_localIdx = 0; inmemorytablescan_localIdx < inmemorytablescan_localEnd; inmemorytablescan_localIdx++) {
/* 055 */         int inmemorytablescan_rowIdx = inmemorytablescan_batchIdx + inmemorytablescan_localIdx;
/* 056 */         int inmemorytablescan_value = inmemorytablescan_colInstance0.getInt(inmemorytablescan_rowIdx);
/* 057 */
/* 058 */         boolean filter_isNull = false;
/* 059 */
/* 060 */         boolean filter_value = false;
/* 061 */         filter_value = inmemorytablescan_value > 1;
/* 062 */         if (!filter_value) continue;
/* 063 */
/* 064 */         filter_numOutputRows.add(1);
/* 065 */
/* 066 */         filter_rowWriter.write(0, inmemorytablescan_value);
/* 067 */         append(filter_result);
/* 068 */         if (shouldStop()) { inmemorytablescan_batchIdx = inmemorytablescan_rowIdx + 1; return; }
/* 069 */       }
/* 070 */       inmemorytablescan_batchIdx = inmemorytablescan_numRows;
/* 071 */       inmemorytablescan_batch = null;
/* 072 */       inmemorytablescan_nextBatch();
/* 073 */     }
/* 074 */     inmemorytablescan_scanTime.add(inmemorytablescan_scanTime1 / (1000 * 1000));
/* 075 */     inmemorytablescan_scanTime1 = 0;
/* 076 */   }
/* 077 */
/* 078 */   private void inmemorytablescan_nextBatch() throws java.io.IOException {
/* 079 */     long getBatchStart = System.nanoTime();
/* 080 */     if (inmemorytablescan_input.hasNext()) {
/* 081 */       org.apache.spark.sql.execution.columnar.CachedBatch inmemorytablescan_cachedBatch = (org.apache.spark.sql.execution.columnar.CachedBatch)inmemorytablescan_input.next();
/* 082 */       inmemorytablescan_batch = org.apache.spark.sql.execution.columnar.InMemoryRelation$.MODULE$.createColumn(inmemorytablescan_cachedBatch);
/* 083 */
/* 084 */       inmemorytablescan_numOutputRows.add(inmemorytablescan_batch.numRows());
/* 085 */       inmemorytablescan_batchIdx = 0;
/* 086 */       inmemorytablescan_colInstance0 = (org.apache.spark.sql.execution.vectorized.OnHeapColumnVector) inmemorytablescan_batch.column(0); org.apache.spark.sql.execution.columnar.ColumnAccessor$.MODULE$.decompress(inmemorytablescan_cachedBatch.buffers()[0], (org.apache.spark.sql.execution.vectorized.WritableColumnVector) inmemorytablescan_colInstance0, org.apache.spark.sql.types.DataTypes.IntegerType, inmemorytablescan_cachedBatch.numRows());
/* 087 */
/* 088 */     }
/* 089 */     inmemorytablescan_scanTime1 += System.nanoTime() - getBatchStart;
/* 090 */   }
/* 091 */ }
```
## How was this patch tested?
Added test cases to `DataFrameTungstenSuite` and `WholeStageCodegenSuite`.
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Closes #18747 from kiszk/SPARK-20822a.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c30d5cfc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c30d5cfc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c30d5cfc
Branch: refs/heads/master
Commit: c30d5cfc7117bdadd63bf730e88398139e0f65f4
Parents: d9798c8
Author: Kazuaki Ishizaki <is...@jp.ibm.com>
Authored: Tue Oct 24 08:46:22 2017 +0100
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Oct 24 08:46:22 2017 +0100
----------------------------------------------------------------------
 .../spark/sql/execution/ColumnarBatchScan.scala |  3 --
 .../sql/execution/WholeStageCodegenExec.scala   | 24 +++++----
 .../sql/execution/columnar/ColumnAccessor.scala |  8 +++
 .../columnar/InMemoryTableScanExec.scala        | 57 +++++++++++++++++---
 .../spark/sql/DataFrameTungstenSuite.scala      | 36 +++++++++++++
 .../sql/execution/WholeStageCodegenSuite.scala  | 32 +++++++++++
6 files changed, 141 insertions(+), 19 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/c30d5cfc/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 1afe83e..eb01e12 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector}
 import org.apache.spark.sql.types.DataType
@@ -31,8 +30,6 @@ import org.apache.spark.sql.types.DataType
  */
 private[sql] trait ColumnarBatchScan extends CodegenSupport {
 
-  val inMemoryTableScan: InMemoryTableScanExec = null
-
   def vectorTypes: Option[Seq[String]] = None
 
   override lazy val metrics = Map(
http://git-wip-us.apache.org/repos/asf/spark/blob/c30d5cfc/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 1aaaf89..e37d133 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -282,6 +282,18 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
 
 object WholeStageCodegenExec {
   val PIPELINE_DURATION_METRIC = "duration"
+
+  private def numOfNestedFields(dataType: DataType): Int = dataType match {
+    case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum
+    case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType)
+    case a: ArrayType => numOfNestedFields(a.elementType)
+    case u: UserDefinedType[_] => numOfNestedFields(u.sqlType)
+    case _ => 1
+  }
+
+  def isTooManyFields(conf: SQLConf, dataType: DataType): Boolean = {
+    numOfNestedFields(dataType) > conf.wholeStageMaxNumFields
+  }
 }
 
 /**
@@ -490,22 +502,14 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
     case _ => true
   }
 
-  private def numOfNestedFields(dataType: DataType): Int = dataType match {
-    case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum
-    case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType)
-    case a: ArrayType => numOfNestedFields(a.elementType)
-    case u: UserDefinedType[_] => numOfNestedFields(u.sqlType)
-    case _ => 1
-  }
-
   private def supportCodegen(plan: SparkPlan): Boolean = plan match {
     case plan: CodegenSupport if plan.supportCodegen =>
       val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined)
       // the generated code will be huge if there are too many columns
       val hasTooManyOutputFields =
-        numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields
+        WholeStageCodegenExec.isTooManyFields(conf, plan.schema)
       val hasTooManyInputFields =
-        plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields)
+        plan.children.exists(p => WholeStageCodegenExec.isTooManyFields(conf, p.schema))
       !willFallback && !hasTooManyOutputFields && !hasTooManyInputFields
     case _ => false
   }
http://git-wip-us.apache.org/repos/asf/spark/blob/c30d5cfc/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
index 24c8ac8..445933d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
@@ -163,4 +163,12 @@ private[sql] object ColumnAccessor {
       throw new RuntimeException("Not support non-primitive type now")
     }
   }
+
+  def decompress(
+      array: Array[Byte], columnVector: WritableColumnVector, dataType: DataType, numRows: Int):
+    Unit = {
+    val byteBuffer = ByteBuffer.wrap(array)
+    val columnAccessor = ColumnAccessor(dataType, byteBuffer)
+    decompress(columnAccessor, columnVector, numRows)
+  }
 }
http://git-wip-us.apache.org/repos/asf/spark/blob/c30d5cfc/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 139da1c..43386e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -23,21 +23,66 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
-import org.apache.spark.sql.execution.LeafExecNode
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.UserDefinedType
+import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.vectorized._
+import org.apache.spark.sql.types._
 
 case class InMemoryTableScanExec(
     attributes: Seq[Attribute],
     predicates: Seq[Expression],
     @transient relation: InMemoryRelation)
-  extends LeafExecNode {
+  extends LeafExecNode with ColumnarBatchScan {
 
   override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren
 
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+  override def vectorTypes: Option[Seq[String]] =
+    Option(Seq.fill(attributes.length)(classOf[OnHeapColumnVector].getName))
+
+  /**
+   * If true, get data from the ColumnVector in ColumnarBatch, which is generally faster.
+   * If false, get data from an UnsafeRow built from the ColumnVector.
+   */
+  override val supportCodegen: Boolean = {
+    // In the initial implementation, for ease of review, support only primitive data types
+    // and a number of fields no larger than wholeStageMaxNumFields
+    relation.schema.fields.forall(f => f.dataType match {
+      case BooleanType | ByteType | ShortType | IntegerType | LongType |
+           FloatType | DoubleType => true
+      case _ => false
+    }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema)
+  }
+
+  private val columnIndices =
+    attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray
+
+  private val relationSchema = relation.schema.toArray
+
+  private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i)))
+
+  private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = {
+    val rowCount = cachedColumnarBatch.numRows
+    val columnVectors = OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema)
+    val columnarBatch = new ColumnarBatch(
+      columnarBatchSchema, columnVectors.asInstanceOf[Array[ColumnVector]], rowCount)
+    columnarBatch.setNumRows(rowCount)
+
+    for (i <- 0 until attributes.length) {
+      ColumnAccessor.decompress(
+        cachedColumnarBatch.buffers(columnIndices(i)),
+        columnarBatch.column(i).asInstanceOf[WritableColumnVector],
+        columnarBatchSchema.fields(i).dataType, rowCount)
+    }
+    columnarBatch
+  }
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    assert(supportCodegen)
+    val buffers = relation.cachedColumnBuffers
+    // HACK ALERT: This is actually an RDD[ColumnarBatch].
+    // We're taking advantage of Scala's type erasure here to pass these batches along.
+    Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]])
+  }
 
   override def output: Seq[Attribute] = attributes
http://git-wip-us.apache.org/repos/asf/spark/blob/c30d5cfc/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
index fe6ba83..0881212 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
@@ -73,4 +73,40 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
     val df = spark.createDataFrame(data, schema)
     assert(df.select("b").first() === Row(outerStruct))
   }
+
+  test("primitive data type accesses in persist data") {
+    val data = Seq(true, 1.toByte, 3.toShort, 7, 15.toLong,
+      31.25.toFloat, 63.75, null)
+    val dataTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType,
+      FloatType, DoubleType, IntegerType)
+    val schemas = dataTypes.zipWithIndex.map { case (dataType, index) =>
+      StructField(s"col$index", dataType, true)
+    }
+    val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
+    val df = spark.createDataFrame(rdd, StructType(schemas))
+    val row = df.persist.take(1).apply(0)
+    checkAnswer(df, row)
+  }
+
+  test("access cache multiple times") {
+    val df0 = sparkContext.parallelize(Seq(1, 2, 3), 1).toDF("x").cache
+    df0.count
+    val df1 = df0.filter("x > 1")
+    checkAnswer(df1, Seq(Row(2), Row(3)))
+    val df2 = df0.filter("x > 2")
+    checkAnswer(df2, Row(3))
+
+    val df10 = sparkContext.parallelize(Seq(3, 4, 5, 6), 1).toDF("x").cache
+    for (_ <- 0 to 2) {
+      val df11 = df10.filter("x > 5")
+      checkAnswer(df11, Row(6))
+    }
+  }
+
+  test("access only some column of the all of columns") {
+    val df = spark.range(1, 10).map(i => (i, (i + 1).toDouble)).toDF("l", "d")
+    df.cache
+    df.count
+    assert(df.filter("d < 3").count == 1)
+  }
 }
http://git-wip-us.apache.org/repos/asf/spark/blob/c30d5cfc/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 098e4cf..bc05dca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.{QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.expressions.scalalang.typed
@@ -117,6 +118,37 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
     assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
   }
 
+  test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") {
+    import testImplicits._
+
+    val dsInt = spark.range(3).cache
+    dsInt.count
+    val dsIntFilter = dsInt.filter(_ > 0)
+    val planInt = dsIntFilter.queryExecution.executedPlan
+    assert(planInt.find(p =>
+      p.isInstanceOf[WholeStageCodegenExec] &&
+        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] &&
+        p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child
+          .isInstanceOf[InMemoryTableScanExec] &&
+        p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child
+          .asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined
+    )
+    assert(dsIntFilter.collect() === Array(1, 2))
+
+    // cache for string type is not supported for InMemoryTableScanExec
+    val dsString = spark.range(3).map(_.toString).cache
+    dsString.count
+    val dsStringFilter = dsString.filter(_ == "1")
+    val planString = dsStringFilter.queryExecution.executedPlan
+    assert(planString.find(p =>
+      p.isInstanceOf[WholeStageCodegenExec] &&
+        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] &&
+        !p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child
+          .isInstanceOf[InMemoryTableScanExec]).isDefined
+    )
+    assert(dsStringFilter.collect() === Array("1"))
+  }
+
   test("SPARK-19512 codegen for comparing structs is incorrect") {
     // this would raise CompileException before the fix
     spark.range(10)