You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/12/03 16:05:26 UTC

spark git commit: [SPARK-25498][SQL] InterpretedMutableProjection should handle UnsafeRow

Repository: spark
Updated Branches:
  refs/heads/master 5e5b9f2ee -> 04046e543


[SPARK-25498][SQL] InterpretedMutableProjection should handle UnsafeRow

## What changes were proposed in this pull request?
Since `AggregationIterator` uses `MutableProjection` with an `UnsafeRow` target buffer, `InterpretedMutableProjection` needs to support `UnsafeRow` as its internal buffer; this is allowed for fixed-length types only.

## How was this patch tested?
Ran `SQLQueryTestSuite` in both codegen-only and interpreted modes, and added `MutableProjectionSuite`.

Closes #22512 from maropu/InterpreterTest.

Authored-by: Takeshi Yamamuro <ya...@apache.org>
Signed-off-by: Wenchen Fan <we...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/04046e54
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/04046e54
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/04046e54

Branch: refs/heads/master
Commit: 04046e5432acb1132fa567f2230723bc1a92a482
Parents: 5e5b9f2
Author: Takeshi Yamamuro <ya...@apache.org>
Authored: Tue Dec 4 00:05:15 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Dec 4 00:05:15 2018 +0800

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/InternalRow.scala | 22 ++++++
 .../InterpretedMutableProjection.scala          | 23 +++++-
 .../expressions/ExpressionEvalHelper.scala      | 11 +++
 .../expressions/MutableProjectionSuite.scala    | 81 ++++++++++++++++++++
 .../expressions/UnsafeRowConverterSuite.scala   | 15 +---
 .../sql-tests/inputs/change-column.sql          |  1 +
 .../test/resources/sql-tests/inputs/udaf.sql    |  3 +
 .../sql-tests/results/change-column.sql.out     | 10 ++-
 .../resources/sql-tests/results/udaf.sql.out    | 18 ++++-
 .../apache/spark/sql/SQLQueryTestSuite.scala    | 27 ++++++-
 10 files changed, 192 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/04046e54/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index e49c10b..bdab407 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -157,4 +157,26 @@ object InternalRow {
       getValueNullSafe
     }
   }
+
+  /**
+   * Returns a writer that sets a value of the given data type at `ordinal` of an `InternalRow`.
+   */
+  def getWriter(ordinal: Int, dt: DataType): (InternalRow, Any) => Unit = dt match {
+    case BooleanType => (input, v) => input.setBoolean(ordinal, v.asInstanceOf[Boolean])
+    case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte])
+    case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short])
+    case IntegerType | DateType => (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
+    case LongType | TimestampType => (input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
+    case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float])
+    case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double])
+    case DecimalType.Fixed(precision, _) =>  // precision is forwarded to setDecimal
+      (input, v) => input.setDecimal(ordinal, v.asInstanceOf[Decimal], precision)
+    case udt: UserDefinedType[_] => getWriter(ordinal, udt.sqlType)  // via the SQL type
+    case NullType => (input, _) => input.setNullAt(ordinal)  // the value is ignored
+    case StringType => (input, v) => input.update(ordinal, v.asInstanceOf[UTF8String].copy())
+    case _: StructType => (input, v) => input.update(ordinal, v.asInstanceOf[InternalRow].copy())
+    case _: ArrayType => (input, v) => input.update(ordinal, v.asInstanceOf[ArrayData].copy())
+    case _: MapType => (input, v) => input.update(ordinal, v.asInstanceOf[MapData].copy())
+    case _ => (input, v) => input.update(ordinal, v)  // stored as-is, without a copy
+  }  // assumes a non-null `v`: primitive casts unbox null to 0/false
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/04046e54/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
index 0654108..122a564 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
@@ -49,10 +49,31 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable
   def currentValue: InternalRow = mutableRow
 
   override def target(row: InternalRow): MutableProjection = {
+    // An `UnsafeRow` target is accepted only when every output data type is fixed-length
+    require(!row.isInstanceOf[UnsafeRow] ||
+      validExprs.forall { case (e, _) => UnsafeRow.isFixedLength(e.dataType) },
+      "MutableProjection cannot use UnsafeRow for output data types: " +
+        validExprs.map(_._1.dataType).filterNot(UnsafeRow.isFixedLength)
+          .map(_.catalogString).mkString(", "))
     mutableRow = row
     this
   }
 
+  private[this] val fieldWriters: Array[Any => Unit] = validExprs.map { case (e, i) =>
+    val writer = InternalRow.getWriter(i, e.dataType)  // `i`: ordinal paired with `e`
+    if (!e.nullable) {  // never null: skip the null check
+      (v: Any) => writer(mutableRow, v)  // sees the current target() row
+    } else {
+      (v: Any) => {
+        if (v == null) {  // writers don't record null (primitives unbox it to 0)
+          mutableRow.setNullAt(i)
+        } else {
+          writer(mutableRow, v)
+        }
+      }
+    }
+  }.toArray
+
   override def apply(input: InternalRow): InternalRow = {
     var i = 0
     while (i < validExprs.length) {
@@ -64,7 +85,7 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable
     i = 0
     while (i < validExprs.length) {
       val (_, ordinal) = validExprs(i)
-      mutableRow(ordinal) = buffer(ordinal)
+      fieldWriters(i)(buffer(ordinal))
       i += 1
     }
     mutableRow

http://git-wip-us.apache.org/repos/asf/spark/blob/04046e54/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index eb33325..a7282e1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -456,4 +456,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
       diff < eps * math.min(absX, absY)
     }
   }
+
+  def testBothCodegenAndInterpreted(name: String)(f: => Unit): Unit = {
+    val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN)
+    for (fallbackMode <- modes) {  // registers one test per factory mode
+      test(s"$name with $fallbackMode") {
+        withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) {
+          f  // by-name body, re-evaluated under each mode's config
+        }
+      }
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/04046e54/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
new file mode 100644
index 0000000..2db1c3b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+
+class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  val fixedLengthTypes = Array[DataType](  // all accepted by an UnsafeRow target
+    BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
+    DateType, TimestampType)
+
+  val variableLengthTypes = Array(  // projected into the default (non-UnsafeRow) row only
+    StringType, DecimalType.defaultConcreteType, CalendarIntervalType, BinaryType,
+    ArrayType(StringType), MapType(IntegerType, StringType),
+    StructType.fromDDL("a INT, b STRING"), ObjectType(classOf[java.lang.Integer]))
+
+  def createMutableProjection(dataTypes: Array[DataType]): MutableProjection =
+    MutableProjection.create(dataTypes.zipWithIndex.map(x => BoundReference(x._2, x._1, true)))
+
+  testBothCodegenAndInterpreted("fixed-length types") {
+    val inputRow = InternalRow.fromSeq(Seq(true, 3.toByte, 15.toShort, -83, 129L, 1.0f, 5.0, 1, 2L))
+    val proj = createMutableProjection(fixedLengthTypes)
+    assert(proj(inputRow) === inputRow)
+  }
+
+  testBothCodegenAndInterpreted("unsafe buffer") {
+    val inputRow = InternalRow.fromSeq(Seq(false, 1.toByte, 9.toShort, -18, 53L, 3.2f, 7.8, 4, 9L))
+    val numFields = fixedLengthTypes.length
+    // The row needs its null bit set plus one 8-byte word per field, not the bit set alone.
+    val numBytes = UnsafeRow.calculateBitSetWidthInBytes(numFields) + 8 * numFields
+    val unsafeBuffer = UnsafeRow.createFromByteArray(numBytes, numFields)
+    val projUnsafeRow = createMutableProjection(fixedLengthTypes).target(unsafeBuffer)(inputRow)
+    assert(FromUnsafeProjection.apply(fixedLengthTypes)(projUnsafeRow) === inputRow)
+  }
+
+  testBothCodegenAndInterpreted("variable-length types") {
+    val proj = createMutableProjection(variableLengthTypes)
+    val scalaValues = Seq("abc", BigDecimal(10), CalendarInterval.fromString("interval 1 day"),
+      Array[Byte](1, 2), Array("123", "456"), Map(1 -> "a", 2 -> "b"), Row(1, "a"),
+      java.lang.Integer.valueOf(5))
+    val inputRow = InternalRow.fromSeq(scalaValues.zip(variableLengthTypes).map {
+      case (v, dataType) => CatalystTypeConverters.createToCatalystConverter(dataType)(v)
+    })
+    val projRow = proj(inputRow)
+    variableLengthTypes.zipWithIndex.foreach { case (dataType, index) =>
+      val toScala = CatalystTypeConverters.createToScalaConverter(dataType)
+      assert(toScala(projRow.get(index, dataType)) === toScala(inputRow.get(index, dataType)))
+    }
+  }
+
+  test("unsupported types for unsafe buffer") {  // interpreted target() require check
+    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString) {
+      val proj = createMutableProjection(Array(StringType))
+      val errMsg = intercept[IllegalArgumentException] {
+        proj.target(new UnsafeRow(1))
+      }.getMessage
+      assert(errMsg.contains("MutableProjection cannot use UnsafeRow for output data types:"))
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/04046e54/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 5a646d9..268372b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -26,26 +26,15 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.PlanTestBase
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{IntegerType, LongType, _}
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.UTF8String
 
-class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase {
+class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase
+    with ExpressionEvalHelper {
 
   private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size)
 
-  private def testBothCodegenAndInterpreted(name: String)(f: => Unit): Unit = {
-    val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN)
-    for (fallbackMode <- modes) {
-      test(s"$name with $fallbackMode") {
-        withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) {
-          f
-        }
-      }
-    }
-  }
-
   testBothCodegenAndInterpreted("basic conversion with only primitive types") {
     val factory = UnsafeProjection
     val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)

http://git-wip-us.apache.org/repos/asf/spark/blob/04046e54/sql/core/src/test/resources/sql-tests/inputs/change-column.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql
index 2909024..6f5ac22 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql
@@ -54,3 +54,4 @@ ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C';
 -- DROP TEST TABLE
 DROP TABLE test_change;
 DROP TABLE partition_table;
+DROP VIEW global_temp.global_temp_view;

http://git-wip-us.apache.org/repos/asf/spark/blob/04046e54/sql/core/src/test/resources/sql-tests/inputs/udaf.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql
index 2183ba2..58613a1 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql
@@ -11,3 +11,6 @@ SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1;
 CREATE FUNCTION udaf1 AS 'test.non.existent.udaf';
 
 SELECT default.udaf1(int_col1) as udaf1 from t1;
+
+DROP FUNCTION myDoubleAvg;
+DROP FUNCTION udaf1;

http://git-wip-us.apache.org/repos/asf/spark/blob/04046e54/sql/core/src/test/resources/sql-tests/results/change-column.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out
index ff1ecbc..1146178 100644
--- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 33
+-- Number of queries: 34
 
 
 -- !query 0
@@ -313,3 +313,11 @@ DROP TABLE partition_table
 struct<>
 -- !query 32 output
 
+
+
+-- !query 33
+DROP VIEW global_temp.global_temp_view
+-- !query 33 schema
+struct<>
+-- !query 33 output
+

http://git-wip-us.apache.org/repos/asf/spark/blob/04046e54/sql/core/src/test/resources/sql-tests/results/udaf.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out
index 87824ab..f4455bb 100644
--- a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 6
+-- Number of queries: 8
 
 
 -- !query 0
@@ -52,3 +52,19 @@ struct<>
 -- !query 5 output
 org.apache.spark.sql.AnalysisException
 Can not load class 'test.non.existent.udaf' when registering the function 'default.udaf1', please make sure it is on the classpath; line 1 pos 7
+
+
+-- !query 6
+DROP FUNCTION myDoubleAvg
+-- !query 6 schema
+struct<>
+-- !query 6 output
+
+
+
+-- !query 7
+DROP FUNCTION udaf1
+-- !query 7 schema
+struct<>
+-- !query 7 output
+

http://git-wip-us.apache.org/repos/asf/spark/blob/04046e54/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index 6ca3ac5..fd180ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -22,11 +22,13 @@ import java.util.{Locale, TimeZone}
 
 import scala.util.control.NonFatal
 
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile}
 import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeTableCommand}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.StructType
 
@@ -140,6 +142,12 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
     val input = fileToString(new File(testCase.inputFile))
 
     val (comments, code) = input.split("\n").partition(_.startsWith("--"))
+
+    // Runs all the tests on both codegen-only and interpreter modes
+    val codegenConfigSets = Array(CODEGEN_ONLY, NO_CODEGEN).map {
+      case codegenFactoryMode =>
+        Array(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenFactoryMode.toString)
+    }
     val configSets = {
       val configLines = comments.filter(_.startsWith("--SET")).map(_.substring(5))
       val configs = configLines.map(_.split(",").map { confAndValue =>
@@ -148,12 +156,25 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
       })
       // When we are regenerating the golden files we don't need to run all the configs as they
       // all need to return the same result
-      if (regenerateGoldenFiles && configs.nonEmpty) {
-        configs.take(1)
+      if (regenerateGoldenFiles) {
+        if (configs.nonEmpty) {
+          configs.take(1)
+        } else {
+          Array.empty[Array[(String, String)]]
+        }
       } else {
-        configs
+        if (configs.nonEmpty) {
+          codegenConfigSets.flatMap { codegenConfig =>
+            configs.map { config =>
+              config ++ codegenConfig
+            }
+          }
+        } else {
+          codegenConfigSets
+        }
       }
     }
+
     // List of SQL queries to run
     // note: this is not a robust way to split queries using semicolon, but works for now.
     val queries = code.mkString("\n").split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org