You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/07/01 01:25:34 UTC

spark git commit: [SPARK-17528][SQL] data should be copied properly before saving into InternalRow

Repository: spark
Updated Branches:
  refs/heads/master fd1325522 -> 4eb41879c


[SPARK-17528][SQL] data should be copied properly before saving into InternalRow

## What changes were proposed in this pull request?

For performance reasons, `UnsafeRow.getString`, `getStruct`, etc. return a "pointer" that points to a memory region of this unsafe row. This makes the unsafe projection a little dangerous, because all of its output rows share one instance.

When we implement SQL operators, we should be careful not to cache the input rows, because they may be produced by an unsafe projection in a child operator and thus their content may change over time.

However, when we update values of an InternalRow (e.g. in mutable projection and safe projection), we only copy UTF8String; we should also copy InternalRow, ArrayData and MapData. This PR fixes this, and also fixes the copy of various InternalRow, ArrayData and MapData implementations.

## How was this patch tested?

new regression tests

Author: Wenchen Fan <we...@databricks.com>

Closes #18483 from cloud-fan/fix-copy.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4eb41879
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4eb41879
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4eb41879

Branch: refs/heads/master
Commit: 4eb41879ce774dec1d16b2281ab1fbf41f9d418a
Parents: fd13255
Author: Wenchen Fan <we...@databricks.com>
Authored: Sat Jul 1 09:25:29 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Sat Jul 1 09:25:29 2017 +0800

----------------------------------------------------------------------
 .../apache/spark/unsafe/types/UTF8String.java   |   6 ++
 .../apache/spark/sql/catalyst/InternalRow.scala |  27 ++++-
 .../spark/sql/catalyst/expressions/Cast.scala   |   2 +-
 .../expressions/SpecificInternalRow.scala       |  12 ---
 .../expressions/aggregate/collect.scala         |   2 +-
 .../expressions/aggregate/interfaces.scala      |   6 ++
 .../expressions/codegen/CodeGenerator.scala     |   6 +-
 .../codegen/GenerateSafeProjection.scala        |   2 -
 .../spark/sql/catalyst/expressions/rows.scala   |  23 ++--
 .../sql/catalyst/util/GenericArrayData.scala    |  10 +-
 .../scala/org/apache/spark/sql/RowTest.scala    |   4 -
 .../sql/catalyst/expressions/MapDataSuite.scala |  57 ----------
 .../codegen/GeneratedProjectionSuite.scala      |  36 +++++++
 .../sql/catalyst/util/ComplexDataSuite.scala    | 107 +++++++++++++++++++
 .../sql/execution/vectorized/ColumnarBatch.java |   2 +-
 .../SortBasedAggregationIterator.scala          |  15 +--
 .../columnar/GenerateColumnAccessor.scala       |   1 -
 .../execution/window/AggregateProcessor.scala   |   7 +-
 18 files changed, 212 insertions(+), 113 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 40b9fc9..9de4ca7 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -1088,6 +1088,12 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
     return fromBytes(getBytes());
   }
 
+  public UTF8String copy() {
+    byte[] bytes = new byte[numBytes];
+    copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes);
+    return fromBytes(bytes);
+  }
+
   @Override
   public int compareTo(@Nonnull final UTF8String other) {
     int len = Math.min(numBytes, other.numBytes);

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 256f64e..2911064 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -18,7 +18,9 @@
 package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types.{DataType, Decimal, StructType}
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * An abstract class for row used internally in Spark SQL, which only contains the columns as
@@ -33,6 +35,10 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
 
   def setNullAt(i: Int): Unit
 
+  /**
+   * Updates the value at column `i`. Note that after updating, the given value will be kept in this
+   * row, and the caller side should guarantee that this value won't be changed afterwards.
+   */
   def update(i: Int, value: Any): Unit
 
   // default implementation (slow)
@@ -58,7 +64,15 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
   def copy(): InternalRow
 
   /** Returns true if there are any NULL values in this row. */
-  def anyNull: Boolean
+  def anyNull: Boolean = {
+    val len = numFields
+    var i = 0
+    while (i < len) {
+      if (isNullAt(i)) { return true }
+      i += 1
+    }
+    false
+  }
 
   /* ---------------------- utility methods for Scala ---------------------- */
 
@@ -94,4 +108,15 @@ object InternalRow {
 
   /** Returns an empty [[InternalRow]]. */
   val empty = apply()
+
+  /**
+   * Copies the given value if it's string/struct/array/map type.
+   */
+  def copyValue(value: Any): Any = value match {
+    case v: UTF8String => v.copy()
+    case v: InternalRow => v.copy()
+    case v: ArrayData => v.copy()
+    case v: MapData => v.copy()
+    case _ => value
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 43df19b..3862e64 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -1047,7 +1047,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         final $rowClass $result = new $rowClass(${fieldsCasts.length});
         final InternalRow $tmpRow = $c;
         $fieldsEvalCode
-        $evPrim = $result.copy();
+        $evPrim = $result;
       """
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
index 74e0b46..75feaf6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 
 /**
@@ -220,17 +219,6 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen
 
   override def isNullAt(i: Int): Boolean = values(i).isNull
 
-  override def copy(): InternalRow = {
-    val newValues = new Array[Any](values.length)
-    var i = 0
-    while (i < values.length) {
-      newValues(i) = values(i).boxed
-      i += 1
-    }
-
-    new GenericInternalRow(newValues)
-  }
-
   override protected def genericGet(i: Int): Any = values(i).boxed
 
   override def update(ordinal: Int, value: Any) {

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 26cd9ab..0d2f988 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -52,7 +52,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
     // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
     // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
     if (value != null) {
-      buffer += value
+      buffer += InternalRow.copyValue(value)
     }
     buffer
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index fffcc7c..7af4901 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -317,6 +317,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac
    * Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`.
    *
    * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
+   *
+   * Note that, the input row may be produced by unsafe projection and it may not be safe to cache
+   * some fields of the input row, as the values can be changed unexpectedly.
    */
   def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit
 
@@ -326,6 +329,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac
    *
    * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
    * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`.
+   *
+   * Note that, the input row may be produced by unsafe projection and it may not be safe to cache
+   * some fields of the input row, as the values can be changed unexpectedly.
    */
   def merge(mutableAggBuffer: InternalRow, inputAggBuffer: InternalRow): Unit
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 5158949..b15bf2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -408,9 +408,11 @@ class CodegenContext {
     dataType match {
       case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
       case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
-      // The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes)
-      case StringType => s"$row.update($ordinal, $value.clone())"
       case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
+      // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy
+      // it to avoid keeping a "pointer" to a memory region which may get updated afterwards.
+      case StringType | _: StructType | _: ArrayType | _: MapType =>
+        s"$row.update($ordinal, $value.copy())"
       case _ => s"$row.update($ordinal, $value)"
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index f708aef..dd0419d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -131,8 +131,6 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
     case s: StructType => createCodeForStruct(ctx, input, s)
     case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
     case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
-    // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe.
-    case StringType => ExprCode("", "false", s"$input.clone()")
     case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
     case _ => ExprCode("", "false", input)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 751b821..65539a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -50,16 +50,6 @@ trait BaseGenericInternalRow extends InternalRow {
   override def getMap(ordinal: Int): MapData = getAs(ordinal)
   override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
 
-  override def anyNull: Boolean = {
-    val len = numFields
-    var i = 0
-    while (i < len) {
-      if (isNullAt(i)) { return true }
-      i += 1
-    }
-    false
-  }
-
   override def toString: String = {
     if (numFields == 0) {
       "[empty row]"
@@ -79,6 +69,17 @@ trait BaseGenericInternalRow extends InternalRow {
     }
   }
 
+  override def copy(): GenericInternalRow = {
+    val len = numFields
+    val newValues = new Array[Any](len)
+    var i = 0
+    while (i < len) {
+      newValues(i) = InternalRow.copyValue(genericGet(i))
+      i += 1
+    }
+    new GenericInternalRow(newValues)
+  }
+
   override def equals(o: Any): Boolean = {
     if (!o.isInstanceOf[BaseGenericInternalRow]) {
       return false
@@ -206,6 +207,4 @@ class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow
   override def setNullAt(i: Int): Unit = { values(i) = null}
 
   override def update(i: Int, value: Any): Unit = { values(i) = value }
-
-  override def copy(): GenericInternalRow = this
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index dd660c8..9e39ed9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -49,7 +49,15 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
 
   def this(seqOrArray: Any) = this(GenericArrayData.anyToSeq(seqOrArray))
 
-  override def copy(): ArrayData = new GenericArrayData(array.clone())
+  override def copy(): ArrayData = {
+    val newValues = new Array[Any](array.length)
+    var i = 0
+    while (i < array.length) {
+      newValues(i) = InternalRow.copyValue(array(i))
+      i += 1
+    }
+    new GenericArrayData(newValues)
+  }
 
   override def numElements(): Int = array.length
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
index c9c9599..25699de 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
@@ -121,10 +121,6 @@ class RowTest extends FunSpec with Matchers {
       externalRow should be theSameInstanceAs externalRow.copy()
     }
 
-    it("copy should return same ref for internal rows") {
-      internalRow should be theSameInstanceAs internalRow.copy()
-    }
-
     it("toSeq should not expose internal state for external rows") {
       val modifiedValues = modifyValues(externalRow.toSeq)
       externalRow.toSeq should not equal modifiedValues

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala
deleted file mode 100644
index 25a675a..0000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import scala.collection._
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
-import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
-import org.apache.spark.unsafe.types.UTF8String
-
-class MapDataSuite extends SparkFunSuite {
-
-  test("inequality tests") {
-    def u(str: String): UTF8String = UTF8String.fromString(str)
-
-    // test data
-    val testMap1 = Map(u("key1") -> 1)
-    val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
-    val testMap3 = Map(u("key1") -> 1)
-    val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)
-
-    // ArrayBasedMapData
-    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
-    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
-    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
-    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
-    assert(testArrayMap1 !== testArrayMap3)
-    assert(testArrayMap2 !== testArrayMap4)
-
-    // UnsafeMapData
-    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
-    val row = new GenericInternalRow(1)
-    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
-      row.update(0, map)
-      val unsafeRow = unsafeConverter.apply(row)
-      unsafeRow.getMap(0).copy
-    }
-    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
-    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
index 58ea5b9..0cd0d88 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
@@ -172,4 +172,40 @@ class GeneratedProjectionSuite extends SparkFunSuite {
     assert(unsafe1 === unsafe3)
     assert(unsafe1.getStruct(1, 7) === unsafe3.getStruct(1, 7))
   }
+
+  test("MutableProjection should not cache content from the input row") {
+    val mutableProj = GenerateMutableProjection.generate(
+      Seq(BoundReference(0, new StructType().add("i", StringType), true)))
+    val row = new GenericInternalRow(1)
+    mutableProj.target(row)
+
+    val unsafeProj = GenerateUnsafeProjection.generate(
+      Seq(BoundReference(0, new StructType().add("i", StringType), true)))
+    val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("a"))))
+
+    mutableProj.apply(unsafeRow)
+    assert(row.getStruct(0, 1).getString(0) == "a")
+
+    // Even if the input row of the mutable projection has been changed, the target mutable row
+    // should keep same.
+    unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b"))))
+    assert(row.getStruct(0, 1).getString(0).toString == "a")
+  }
+
+  test("SafeProjection should not cache content from the input row") {
+    val safeProj = GenerateSafeProjection.generate(
+      Seq(BoundReference(0, new StructType().add("i", StringType), true)))
+
+    val unsafeProj = GenerateUnsafeProjection.generate(
+      Seq(BoundReference(0, new StructType().add("i", StringType), true)))
+    val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("a"))))
+
+    val row = safeProj.apply(unsafeRow)
+    assert(row.getStruct(0, 1).getString(0) == "a")
+
+    // Even if the input row of the mutable projection has been changed, the target mutable row
+    // should keep same.
+    unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b"))))
+    assert(row.getStruct(0, 1).getString(0).toString == "a")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala
new file mode 100644
index 0000000..9d28591
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import scala.collection._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow, SpecificInternalRow, UnsafeMapData, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+class ComplexDataSuite extends SparkFunSuite {
+  def utf8(str: String): UTF8String = UTF8String.fromString(str)
+
+  test("inequality tests for MapData") {
+    // test data
+    val testMap1 = Map(utf8("key1") -> 1)
+    val testMap2 = Map(utf8("key1") -> 1, utf8("key2") -> 2)
+    val testMap3 = Map(utf8("key1") -> 1)
+    val testMap4 = Map(utf8("key1") -> 1, utf8("key2") -> 2)
+
+    // ArrayBasedMapData
+    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
+    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
+    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
+    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
+    assert(testArrayMap1 !== testArrayMap3)
+    assert(testArrayMap2 !== testArrayMap4)
+
+    // UnsafeMapData
+    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
+    val row = new GenericInternalRow(1)
+    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
+      row.update(0, map)
+      val unsafeRow = unsafeConverter.apply(row)
+      unsafeRow.getMap(0).copy
+    }
+    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
+    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
+  }
+
+  test("GenericInternalRow.copy return a new instance that is independent from the old one") {
+    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
+    val unsafeRow = project.apply(InternalRow(utf8("a")))
+
+    val genericRow = new GenericInternalRow(Array[Any](unsafeRow.getUTF8String(0)))
+    val copiedGenericRow = genericRow.copy()
+    assert(copiedGenericRow.getString(0) == "a")
+    project.apply(InternalRow(UTF8String.fromString("b")))
+    // The copied internal row should not be changed externally.
+    assert(copiedGenericRow.getString(0) == "a")
+  }
+
+  test("SpecificMutableRow.copy return a new instance that is independent from the old one") {
+    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
+    val unsafeRow = project.apply(InternalRow(utf8("a")))
+
+    val mutableRow = new SpecificInternalRow(Seq(StringType))
+    mutableRow(0) = unsafeRow.getUTF8String(0)
+    val copiedMutableRow = mutableRow.copy()
+    assert(copiedMutableRow.getString(0) == "a")
+    project.apply(InternalRow(UTF8String.fromString("b")))
+    // The copied internal row should not be changed externally.
+    assert(copiedMutableRow.getString(0) == "a")
+  }
+
+  test("GenericArrayData.copy return a new instance that is independent from the old one") {
+    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
+    val unsafeRow = project.apply(InternalRow(utf8("a")))
+
+    val genericArray = new GenericArrayData(Array[Any](unsafeRow.getUTF8String(0)))
+    val copiedGenericArray = genericArray.copy()
+    assert(copiedGenericArray.getUTF8String(0).toString == "a")
+    project.apply(InternalRow(UTF8String.fromString("b")))
+    // The copied array data should not be changed externally.
+    assert(copiedGenericArray.getUTF8String(0).toString == "a")
+  }
+
+  test("copy on nested complex type") {
+    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
+    val unsafeRow = project.apply(InternalRow(utf8("a")))
+
+    val arrayOfRow = new GenericArrayData(Array[Any](InternalRow(unsafeRow.getUTF8String(0))))
+    val copied = arrayOfRow.copy()
+    assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
+    project.apply(InternalRow(UTF8String.fromString("b")))
+    // The copied data should not be changed externally.
+    assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index e23a643..34dc3af 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -149,7 +149,7 @@ public final class ColumnarBatch {
           } else if (dt instanceof DoubleType) {
             row.setDouble(i, getDouble(i));
           } else if (dt instanceof StringType) {
-            row.update(i, getUTF8String(i));
+            row.update(i, getUTF8String(i).copy());
           } else if (dt instanceof BinaryType) {
             row.update(i, getBinary(i));
           } else if (dt instanceof DecimalType) {

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index bea2dce..a5a444b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -86,17 +86,6 @@ class SortBasedAggregationIterator(
   // The aggregation buffer used by the sort-based aggregation.
   private[this] val sortBasedAggregationBuffer: InternalRow = newBuffer
 
-  // This safe projection is used to turn the input row into safe row. This is necessary
-  // because the input row may be produced by unsafe projection in child operator and all the
-  // produced rows share one byte array. However, when we update the aggregate buffer according to
-  // the input row, we may cache some values from input row, e.g. `Max` will keep the max value from
-  // input row via MutableProjection, `CollectList` will keep all values in an array via
-  // ImperativeAggregate framework. These values may get changed unexpectedly if the underlying
-  // unsafe projection update the shared byte array. By applying a safe projection to the input row,
-  // we can cut down the connection from input row to the shared byte array, and thus it's safe to
-  // cache values from input row while updating the aggregation buffer.
-  private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType))
-
   protected def initialize(): Unit = {
     if (inputIterator.hasNext) {
       initializeBuffer(sortBasedAggregationBuffer)
@@ -119,7 +108,7 @@ class SortBasedAggregationIterator(
     // We create a variable to track if we see the next group.
     var findNextPartition = false
     // firstRowInNextGroup is the first row of this group. We first process it.
-    processRow(sortBasedAggregationBuffer, safeProj(firstRowInNextGroup))
+    processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
 
     // The search will stop when we see the next group or there is no
     // input row left in the iter.
@@ -130,7 +119,7 @@ class SortBasedAggregationIterator(
 
       // Check if the current row belongs the current input row.
       if (currentGroupingKey == groupingKey) {
-        processRow(sortBasedAggregationBuffer, safeProj(currentRow))
+        processRow(sortBasedAggregationBuffer, currentRow)
       } else {
         // We find a new group.
         findNextPartition = true

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index d3fa0dc..fc977f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -56,7 +56,6 @@ class MutableUnsafeRow(val writer: UnsafeRowWriter) extends BaseGenericInternalR
   // all other methods inherited from GenericMutableRow are not need
   override protected def genericGet(ordinal: Int): Any = throw new UnsupportedOperationException
   override def numFields: Int = throw new UnsupportedOperationException
-  override def copy(): InternalRow = throw new UnsupportedOperationException
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/4eb41879/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala
index 2195c6e..bc141b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala
@@ -145,13 +145,10 @@ private[window] final class AggregateProcessor(
 
   /** Update the buffer. */
   def update(input: InternalRow): Unit = {
-    // TODO(hvanhovell) this sacrifices performance for correctness. We should make sure that
-    // MutableProjection makes copies of the complex input objects it buffer.
-    val copy = input.copy()
-    updateProjection(join(buffer, copy))
+    updateProjection(join(buffer, input))
     var i = 0
     while (i < numImperatives) {
-      imperatives(i).update(buffer, copy)
+      imperatives(i).update(buffer, input)
       i += 1
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org