Posted to commits@spark.apache.org by db...@apache.org on 2019/03/10 06:01:00 UTC

[spark] branch branch-2.4 updated: [SPARK-27097][CHERRY-PICK 2.4] Avoid embedding platform-dependent offsets literally in whole-stage generated code

This is an automated email from the ASF dual-hosted git repository.

dbtsai pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new a017a1c  [SPARK-27097][CHERRY-PICK 2.4] Avoid embedding platform-dependent offsets literally in whole-stage generated code
a017a1c is described below

commit a017a1c1afc2e49e61df7c1d23c9c5058708fac8
Author: Kris Mok <kr...@databricks.com>
AuthorDate: Sun Mar 10 06:00:36 2019 +0000

    [SPARK-27097][CHERRY-PICK 2.4] Avoid embedding platform-dependent offsets literally in whole-stage generated code
    
    ## What changes were proposed in this pull request?
    
    Spark SQL performs whole-stage code generation to speed up query execution. There are two steps to it:
    - Java source code is generated from the physical query plan on the driver. A single version of the source code is generated from a query plan, and sent to all executors.
      - It is also compiled to bytecode on the driver, but only to catch compilation errors early; only the generated source code is sent to the executors, so the driver-side bytecode compilation exists purely to fail fast.
    - Executors receive the generated source code and compile it to bytecode themselves; the query then runs like a hand-written Java program. A rough sketch of this driver-to-executor flow follows below.
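
    Below is a hedged, self-contained sketch of that flow, using Janino directly rather than Spark's actual `CodeGenerator` machinery (Spark uses Janino for this compilation step); `CodegenFlowSketch`, `GeneratedDemo`, and `answer` are made-up names used only for illustration:
    ```scala
    import org.codehaus.janino.SimpleCompiler

    object CodegenFlowSketch {
      def main(args: Array[String]): Unit = {
        // "Driver" side: one Java source string is produced from the query plan (a stand-in here).
        val source =
          "public class GeneratedDemo { public static long answer() { return 42L; } }"

        // "Driver" side: compile once, purely to fail fast on invalid source before shipping it.
        new SimpleCompiler().cook(source)

        // "Executor" side: each executor compiles the same string on its own JVM and runs it.
        val compiler = new SimpleCompiler()
        compiler.cook(source)
        val cls = compiler.getClassLoader.loadClass("GeneratedDemo")
        println(cls.getMethod("answer").invoke(null)) // prints 42
      }
    }
    ```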
    
    In this model, there is an implicit assumption that the driver and the executors run on similar platforms. Some code paths accidentally embedded platform-dependent object layout information into the generated code, such as:
    ```java
    Platform.putLong(buffer, /* offset */ 24, /* value */ 1);
    ```
    This code expects a field to be at offset +24 of the `buffer` object, and writes a value to that field.
    But whole-stage code generation generally uses platform-dependent information gathered on the driver. If the object layout differs significantly between the driver and the executors, the generated code can read from or write to the wrong offsets on the executors, causing all kinds of data corruption.
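
    As a concrete illustration of why a literal offset is fragile: the base offset of a byte array is a JVM layout detail that changes with settings such as `-XX:+UseCompressedOops` / `-XX:-UseCompressedOops` on 64-bit HotSpot. A minimal sketch, assuming only that `spark-unsafe` is on the classpath (`OffsetDemo` is a made-up name):
    ```scala
    import org.apache.spark.unsafe.Platform

    object OffsetDemo {
      def main(args: Array[String]): Unit = {
        // On 64-bit HotSpot this typically prints 16 with -XX:+UseCompressedOops and 24 with
        // -XX:-UseCompressedOops, so a value baked in on the driver can be wrong on an executor.
        println(s"BYTE_ARRAY_OFFSET on this JVM: ${Platform.BYTE_ARRAY_OFFSET}")
      }
    }
    ```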
    
    One code pattern that leads to such problems is the use of `Platform.XXX` constants in generated code, e.g. `Platform.BYTE_ARRAY_OFFSET`.
    
    Bad:
    ```scala
    val baseOffset = Platform.BYTE_ARRAY_OFFSET
    // codegen template:
    s"Platform.putLong($buffer, $baseOffset, $value);"
    ```
    This embeds the driver's value of `Platform.BYTE_ARRAY_OFFSET` into the generated code.
    
    Good:
    ```scala
    val baseOffset = "Platform.BYTE_ARRAY_OFFSET"
    // codegen template:
    s"Platform.putLong($buffer, $baseOffset, $value);"
    ```
    This emits the offset symbolically, as `Platform.putLong(buffer, Platform.BYTE_ARRAY_OFFSET, value)`, so the correct value is picked up on each executor.
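
    To see the two templates side by side, here is a minimal hedged sketch run on a single JVM (`buffer` and `value` stand in for generated names; `TemplateDemo` is a made-up object):
    ```scala
    import org.apache.spark.unsafe.Platform

    object TemplateDemo {
      def main(args: Array[String]): Unit = {
        val buffer = "buf" // stand-in for the generated buffer variable name
        val value = "1L"   // stand-in for the generated value expression

        // Bad: the driver's numeric value (e.g. 16) is frozen into the emitted source.
        val literalOffset = Platform.BYTE_ARRAY_OFFSET
        println(s"Platform.putLong($buffer, $literalOffset, $value);")

        // Good: the symbol is emitted verbatim and resolved by whichever JVM compiles the code.
        val symbolicOffset = "Platform.BYTE_ARRAY_OFFSET"
        println(s"Platform.putLong($buffer, $symbolicOffset, $value);")
      }
    }
    ```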
    
    Caveat: these offset constants are `static final` fields in Java that are initialized at runtime, so they are not compile-time constants from the Java language's perspective. This slightly increases the size of the generated code, but it is necessary for correctness.
    
    NOTE: other patterns can also generate platform-dependent code on the driver that is invalid on the executors, e.g. if the endianness differs between the driver and the executors and some generated code makes strong assumptions about endianness, that would be just as problematic.
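
    For instance, the native byte order is another runtime property that generated code must not assume; a tiny standalone check using only the standard library (`EndiannessDemo` is a made-up name):
    ```scala
    import java.nio.ByteOrder

    object EndiannessDemo {
      def main(args: Array[String]): Unit = {
        // Typically prints LITTLE_ENDIAN on x86-64 and most AArch64 systems; other platforms may differ.
        println(s"Native byte order: ${ByteOrder.nativeOrder()}")
      }
    }
    ```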
    
    ## How was this patch tested?
    
    Added a new test suite `WholeStageCodegenSparkSubmitSuite`. This test suite needs to set the driver's extraJavaOptions to force the driver and the executors to use different Java object layouts, so it is run as an actual SparkSubmit job.
    
    Authored-by: Kris Mok <kris.mok@databricks.com>
    
    Closes #24032 from gatorsmile/testFailure.
    
    Lead-authored-by: Kris Mok <kr...@databricks.com>
    Co-authored-by: gatorsmile <ga...@gmail.com>
    Signed-off-by: DB Tsai <d_...@apple.com>
---
 .../unsafe/sort/UnsafeSorterSpillReader.java       |  3 +-
 .../codegen/GenerateUnsafeRowJoiner.scala          | 20 ++---
 .../expressions/collectionOperations.scala         |  6 +-
 .../catalyst/expressions/complexTypeCreator.scala  |  2 +-
 .../WholeStageCodegenSparkSubmitSuite.scala        | 93 ++++++++++++++++++++++
 5 files changed, 108 insertions(+), 16 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index fb179d0..bfca670 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -51,7 +51,6 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
 
   private byte[] arr = new byte[1024 * 1024];
   private Object baseObject = arr;
-  private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
   private final TaskContext taskContext = TaskContext.get();
 
   public UnsafeSorterSpillReader(
@@ -132,7 +131,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
 
   @Override
   public long getBaseOffset() {
-    return baseOffset;
+    return Platform.BYTE_ARRAY_OFFSET;
   }
 
   @Override
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index febf7b0..070570d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -55,7 +55,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
 
   def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = {
     val ctx = new CodegenContext
-    val offset = Platform.BYTE_ARRAY_OFFSET
+    val offset = "Platform.BYTE_ARRAY_OFFSET"
     val getLong = "Platform.getLong"
     val putLong = "Platform.putLong"
 
@@ -92,7 +92,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
           s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8})"
         }
       }
-      s"$putLong(buf, ${offset + i * 8}, $bits);\n"
+      s"$putLong(buf, $offset + ${i * 8}, $bits);\n"
     }
 
     val copyBitsets = ctx.splitExpressions(
@@ -102,12 +102,12 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
                   ("java.lang.Object", "obj2") :: ("long", "offset2") :: Nil)
 
     // --------------------- copy fixed length portion from row 1 ----------------------- //
-    var cursor = offset + outputBitsetWords * 8
+    var cursor = outputBitsetWords * 8
     val copyFixedLengthRow1 = s"""
        |// Copy fixed length data for row1
        |Platform.copyMemory(
        |  obj1, offset1 + ${bitset1Words * 8},
-       |  buf, $cursor,
+       |  buf, $offset + $cursor,
        |  ${schema1.size * 8});
      """.stripMargin
     cursor += schema1.size * 8
@@ -117,7 +117,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
        |// Copy fixed length data for row2
        |Platform.copyMemory(
        |  obj2, offset2 + ${bitset2Words * 8},
-       |  buf, $cursor,
+       |  buf, $offset + $cursor,
        |  ${schema2.size * 8});
      """.stripMargin
     cursor += schema2.size * 8
@@ -129,7 +129,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
        |long numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1;
        |Platform.copyMemory(
        |  obj1, offset1 + ${(bitset1Words + schema1.size) * 8},
-       |  buf, $cursor,
+       |  buf, $offset + $cursor,
        |  numBytesVariableRow1);
      """.stripMargin
 
@@ -140,7 +140,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
        |long numBytesVariableRow2 = row2.getSizeInBytes() - $numBytesBitsetAndFixedRow2;
        |Platform.copyMemory(
        |  obj2, offset2 + ${(bitset2Words + schema2.size) * 8},
-       |  buf, $cursor + numBytesVariableRow1,
+       |  buf, $offset + $cursor + numBytesVariableRow1,
        |  numBytesVariableRow2);
      """.stripMargin
 
@@ -161,7 +161,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
           } else {
             s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)"
           }
-        val cursor = offset + outputBitsetWords * 8 + i * 8
+        val cursor = outputBitsetWords * 8 + i * 8
         // UnsafeRow is a little underspecified, so in what follows we'll treat UnsafeRowWriter's
         // output as a de-facto specification for the internal layout of data.
         //
@@ -198,9 +198,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
         // Thus it is safe to perform `existingOffset != 0` checks here in the place of
         // more expensive null-bit checks.
         s"""
-           |existingOffset = $getLong(buf, $cursor);
+           |existingOffset = $getLong(buf, $offset + $cursor);
            |if (existingOffset != 0) {
-           |    $putLong(buf, $cursor, existingOffset + ($shift << 32));
+           |    $putLong(buf, $offset + $cursor, existingOffset + ($shift << 32));
            |}
          """.stripMargin
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 0a39b43..7ff4cd3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -666,7 +666,7 @@ case class MapFromEntries(child: Expression) extends UnaryExpression {
     val keyArrayData = ctx.freshName("keyArrayData")
     val valueArrayData = ctx.freshName("valueArrayData")
 
-    val baseOffset = Platform.BYTE_ARRAY_OFFSET
+    val baseOffset = "Platform.BYTE_ARRAY_OFFSET"
     val keySize = dataType.keyType.defaultSize
     val valueSize = dataType.valueType.defaultSize
     val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)"
@@ -696,8 +696,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression {
        |  final byte[] $data = new byte[(int)$byteArraySize];
        |  UnsafeMapData $unsafeMapData = new UnsafeMapData();
        |  Platform.putLong($data, $baseOffset, $keySectionSize);
-       |  Platform.putLong($data, ${baseOffset + 8}, $numEntries);
-       |  Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries);
+       |  Platform.putLong($data, $baseOffset + 8, $numEntries);
+       |  Platform.putLong($data, $baseOffset + 8 + $keySectionSize, $numEntries);
        |  $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize);
        |  ArrayData $keyArrayData = $unsafeMapData.keyArray();
        |  ArrayData $valueArrayData = $unsafeMapData.valueArray();
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index fd8b5e9..16b4a1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -124,7 +124,7 @@ private [sql] object GenArrayData {
       val unsafeArraySizeInBytes =
         UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
         ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
-      val baseOffset = Platform.BYTE_ARRAY_OFFSET
+      val baseOffset = "Platform.BYTE_ARRAY_OFFSET"
 
       val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
       val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala
new file mode 100644
index 0000000..f985386
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSparkSubmitSuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.scalatest.{Assertions, BeforeAndAfterEach, Matchers}
+import org.scalatest.concurrent.TimeLimits
+
+import org.apache.spark.{SparkFunSuite, TestUtils}
+import org.apache.spark.deploy.SparkSubmitSuite
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{LocalSparkSession, QueryTest, Row, SparkSession}
+import org.apache.spark.sql.functions.{array, col, count, lit}
+import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.util.ResetSystemProperties
+
+// Due to the need to set driver's extraJavaOptions, this test needs to use actual SparkSubmit.
+class WholeStageCodegenSparkSubmitSuite extends SparkFunSuite
+  with Matchers
+  with BeforeAndAfterEach
+  with ResetSystemProperties {
+
+  test("Generated code on driver should not embed platform-specific constant") {
+    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+
+    // HotSpot JVM specific: Set up a local cluster with the driver/executor using mismatched
+    // settings of UseCompressedOops JVM option.
+    val argsForSparkSubmit = Seq(
+      "--class", WholeStageCodegenSparkSubmitSuite.getClass.getName.stripSuffix("$"),
+      "--master", "local-cluster[1,1,1024]",
+      "--driver-memory", "1g",
+      "--conf", "spark.ui.enabled=false",
+      "--conf", "spark.master.rest.enabled=false",
+      "--conf", "spark.driver.extraJavaOptions=-XX:-UseCompressedOops",
+      "--conf", "spark.executor.extraJavaOptions=-XX:+UseCompressedOops",
+      unusedJar.toString)
+    SparkSubmitSuite.runSparkSubmit(argsForSparkSubmit, "../..")
+  }
+}
+
+object WholeStageCodegenSparkSubmitSuite extends Assertions with Logging {
+
+  var spark: SparkSession = _
+
+  def main(args: Array[String]): Unit = {
+    TestUtils.configTestLog4j("INFO")
+
+    spark = SparkSession.builder().getOrCreate()
+
+    // Make sure the test is run where the driver and the executors use different object layouts
+    val driverArrayHeaderSize = Platform.BYTE_ARRAY_OFFSET
+    val executorArrayHeaderSize =
+      spark.sparkContext.range(0, 1).map(_ => Platform.BYTE_ARRAY_OFFSET).collect.head.toInt
+    assert(driverArrayHeaderSize > executorArrayHeaderSize)
+
+    val df = spark.range(71773).select((col("id") % lit(10)).cast(IntegerType) as "v")
+      .groupBy(array(col("v"))).agg(count(col("*")))
+    val plan = df.queryExecution.executedPlan
+    assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
+
+    val expectedAnswer =
+      Row(Array(0), 7178) ::
+        Row(Array(1), 7178) ::
+        Row(Array(2), 7178) ::
+        Row(Array(3), 7177) ::
+        Row(Array(4), 7177) ::
+        Row(Array(5), 7177) ::
+        Row(Array(6), 7177) ::
+        Row(Array(7), 7177) ::
+        Row(Array(8), 7177) ::
+        Row(Array(9), 7177) :: Nil
+    val result = df.collect
+    QueryTest.sameRows(result.toSeq, expectedAnswer) match {
+      case Some(errMsg) => fail(errMsg)
+      case _ =>
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org