You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2015/02/06 21:38:12 UTC

spark git commit: [SPARK-5595][SPARK-5603][SQL] Add a rule to do PreInsert type casting and field renaming and invalidating in memory cache after INSERT

Repository: spark
Updated Branches:
  refs/heads/master 0b7eb3f3b -> 3eccf29ce


[SPARK-5595][SPARK-5603][SQL] Add a rule to do PreInsert type casting and field renaming and invalidating in memory cache after INSERT

This PR adds a rule to the Analyzer that adds pre-insert data type casting and field renaming to the SELECT clause of an `INSERT INTO/OVERWRITE` statement. With this PR, we also always invalidate the in-memory data cache after inserting into a BaseRelation.

cc marmbrus liancheng

Author: Yin Huai <yh...@databricks.com>

Closes #4373 from yhuai/insertFollowUp and squashes the following commits:

08237a7 [Yin Huai] Merge remote-tracking branch 'upstream/master' into insertFollowUp
316542e [Yin Huai] Doc update.
c9ccfeb [Yin Huai] Revert a unnecessary change.
84aecc4 [Yin Huai] Address comments.
1951fe1 [Yin Huai] Merge remote-tracking branch 'upstream/master'
c18da34 [Yin Huai] Invalidate cache after insert.
727f21a [Yin Huai] Preinsert casting and renaming.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3eccf29c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3eccf29c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3eccf29c

Branch: refs/heads/master
Commit: 3eccf29ce061559c86e6f7338851932fc89a9afd
Parents: 0b7eb3f
Author: Yin Huai <yh...@databricks.com>
Authored: Fri Feb 6 12:38:07 2015 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Fri Feb 6 12:38:07 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/SQLContext.scala |  6 +-
 .../apache/spark/sql/json/JSONRelation.scala    | 12 ++-
 .../spark/sql/sources/DataSourceStrategy.scala  |  2 +-
 .../org/apache/spark/sql/sources/commands.scala | 10 ++-
 .../apache/spark/sql/sources/interfaces.scala   | 16 ++++
 .../org/apache/spark/sql/sources/rules.scala    | 76 +++++++++++++++++++
 .../org/apache/spark/sql/json/JsonSuite.scala   | 25 ++++++
 .../spark/sql/sources/DataSourceTest.scala      |  6 +-
 .../spark/sql/sources/InsertIntoSuite.scala     | 80 ++++++++++++++++++++
 .../org/apache/spark/sql/hive/HiveContext.scala |  1 +
 10 files changed, 227 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3eccf29c/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 706ef6a..bf39906 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -91,7 +91,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
 
   @transient
   protected[sql] lazy val analyzer: Analyzer =
-    new Analyzer(catalog, functionRegistry, caseSensitive = true)
+    new Analyzer(catalog, functionRegistry, caseSensitive = true) {
+      override val extendedRules =
+        sources.PreInsertCastAndRename ::
+        Nil
+    }
 
   @transient
   protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer

http://git-wip-us.apache.org/repos/asf/spark/blob/3eccf29c/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index f27585d..c4e14c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -72,7 +72,6 @@ private[sql] case class JSONRelation(
     userSpecifiedSchema: Option[StructType])(
     @transient val sqlContext: SQLContext)
   extends TableScan with InsertableRelation {
-
   // TODO: Support partitioned JSON relation.
   private def baseRDD = sqlContext.sparkContext.textFile(path)
 
@@ -99,10 +98,21 @@ private[sql] case class JSONRelation(
             s"Unable to clear output directory ${filesystemPath.toString} prior"
               + s" to INSERT OVERWRITE a JSON table:\n${e.toString}")
       }
+      // Write the data.
       data.toJSON.saveAsTextFile(path)
+      // Right now, we assume that the schema is not changed. We will not update the schema.
+      // schema = data.schema
     } else {
       // TODO: Support INSERT INTO
       sys.error("JSON table only support INSERT OVERWRITE for now.")
     }
   }
+
+  // samplingRatio deliberately does not take part in equality or hashing.
+  override def hashCode(): Int = 41 * (41 + path.hashCode) + schema.hashCode()
+
+  override def equals(other: Any): Boolean = other match {
+    case that: JSONRelation => path == that.path && schema == that.schema
+    case _ => false
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3eccf29c/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index d23ffb8..624369a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -59,7 +59,7 @@ private[sql] object DataSourceStrategy extends Strategy {
       if (partition.nonEmpty) {
         sys.error(s"Insert into a partition is not allowed because $l is not partitioned.")
       }
-      execution.ExecutedCommand(InsertIntoRelation(t, query, overwrite)) :: Nil
+      execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil
 
     case _ => Nil
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/3eccf29c/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index d7942dc..c9cd0e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -19,17 +19,21 @@ package org.apache.spark.sql.sources
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.RunnableCommand
+import org.apache.spark.sql.execution.{LogicalRDD, RunnableCommand}
 
-private[sql] case class InsertIntoRelation(
-    relation: InsertableRelation,
+private[sql] case class InsertIntoDataSource(
+    logicalRelation: LogicalRelation,
     query: LogicalPlan,
     overwrite: Boolean)
   extends RunnableCommand {

   override def run(sqlContext: SQLContext) = {
+    val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
     relation.insert(DataFrame(sqlContext, query), overwrite)

+    // Invalidate cached data built on this relation so later queries see the inserted rows.
+    sqlContext.cacheManager.invalidateCache(logicalRelation)
+
     Seq.empty[Row]
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3eccf29c/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 40fc1f2..a640ba5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -158,6 +158,22 @@ trait CatalystScan extends BaseRelation {
 }
 
 @DeveloperApi
+/**
+ * ::DeveloperApi::
+ * A BaseRelation that can be used to insert data into it through the insert method.
+ * If overwrite in insert method is true, the old data in the relation should be overwritten with
+ * the new data. If overwrite in insert method is false, the new data should be appended.
+ *
+ * InsertableRelation has the following three assumptions.
+ * 1. It assumes that the data (Rows in the DataFrame) provided to the insert method
+ * exactly matches the ordinal of fields in the schema of the BaseRelation.
+ * 2. It assumes that the schema of this relation will not be changed.
+ * Even if the insert method updates the schema (e.g. a relation of JSON or Parquet data may have a
+ * schema update after an insert operation), the new schema will not be used.
+ * 3. It assumes that fields of the data provided in the insert method are nullable.
+ * If a data source needs to check the actual nullability of a field, it needs to do it in the
+ * insert method.
+ */
 trait InsertableRelation extends BaseRelation {
   def insert(data: DataFrame, overwrite: Boolean): Unit
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3eccf29c/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
new file mode 100644
index 0000000..4ed22d3
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Alias}
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.DataType
+
+/**
+ * A rule to do pre-insert data type casting and field renaming. Before we insert into
+ * an [[InsertableRelation]], we use this rule to make sure that the columns produced by
+ * the query have the data types and field names expected by the relation's schema.
+ * An error is raised if the query produces a different number of columns than the schema.
+ */
+private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    plan.transform {
+      // Wait until children are resolved.
+      case p: LogicalPlan if !p.childrenResolved => p
+
+      // We are inserting into an InsertableRelation.
+      case i @ InsertIntoTable(l @ LogicalRelation(_: InsertableRelation), _, child, _) => {
+        // First, make sure the data to be inserted has the same number of fields as the
+        // schema of the relation, since columns are matched up by position below.
+        if (l.output.size != child.output.size) {
+          sys.error(
+            s"$l requires that the query in the SELECT clause of the INSERT INTO/OVERWRITE " +
+              s"statement generates the same number of columns as its schema " +
+              s"(expected ${l.output.size}, found ${child.output.size}).")
+        }
+        castAndRenameChildOutput(i, l.output, child)
+      }
+    }
+  }
+
+  /**
+   * If necessary, cast data types and rename fields to the expected types and names. Columns
+   * are paired by position, so callers must first check that both sides have the same arity.
+   */
+  def castAndRenameChildOutput(
+      insertInto: InsertIntoTable,
+      expectedOutput: Seq[Attribute],
+      child: LogicalPlan): InsertIntoTable = {
+    val newChildOutput = expectedOutput.zip(child.output).map {
+      case (expected, actual) =>
+        val needCast = !DataType.equalsIgnoreNullability(expected.dataType, actual.dataType)
+        // Make sure the field names in the data to be inserted exactly match the schema.
+        val needRename = expected.name != actual.name
+        (needCast, needRename) match {
+          case (true, _) => Alias(Cast(actual, expected.dataType), expected.name)()
+          case (false, true) => Alias(actual, expected.name)()
+          case (_, _) => actual
+        }
+    }
+
+    // Leave the plan untouched when no column needs to be cast or renamed.
+    if (newChildOutput == child.output) insertInto
+    else insertInto.copy(child = Project(newChildOutput, child))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/3eccf29c/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 1396c6b..926ba68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.Dsl._
 import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType}
+import org.apache.spark.sql.sources.LogicalRelation
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext._
 import org.apache.spark.sql.types._
@@ -923,6 +924,30 @@ class JsonSuite extends QueryTest {
       sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"),
       Row(5, null)
     )
+  }
 
+  test("JSONRelation equality test") {
+    val relation1 =
+      JSONRelation("path", 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)))(null)
+    val logicalRelation1 = LogicalRelation(relation1)
+    val relation2 =
+      JSONRelation("path", 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)))(
+        org.apache.spark.sql.test.TestSQLContext)
+    val logicalRelation2 = LogicalRelation(relation2)
+    val relation3 =
+      JSONRelation("path", 1.0, Some(StructType(StructField("b", StringType, true) :: Nil)))(null)
+    val logicalRelation3 = LogicalRelation(relation3)
+
+    assert(relation1 === relation2)
+    assert(logicalRelation1.sameResult(logicalRelation2),
+      s"$logicalRelation1 and $logicalRelation2 should be considered having the same result.")
+
+    assert(relation1 !== relation3)
+    assert(!logicalRelation1.sameResult(logicalRelation3),
+      s"$logicalRelation1 and $logicalRelation3 should be considered not having the same result.")
+
+    assert(relation2 !== relation3)
+    assert(!logicalRelation2.sameResult(logicalRelation3),
+      s"$logicalRelation2 and $logicalRelation3 should be considered not having the same result.")
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3eccf29c/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
index 9626252..53f5f74 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -28,7 +28,11 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
   implicit val caseInsensisitiveContext = new SQLContext(TestSQLContext.sparkContext) {
     @transient
     override protected[sql] lazy val analyzer: Analyzer =
-      new Analyzer(catalog, functionRegistry, caseSensitive = false)
+      new Analyzer(catalog, functionRegistry, caseSensitive = false) {
+        override val extendedRules =
+          PreInsertCastAndRename ::
+          Nil
+      }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3eccf29c/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala
index f91cea6..36e504e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala
@@ -63,6 +63,41 @@ class InsertIntoSuite extends DataSourceTest with BeforeAndAfterAll {
     )
   }
 
+  test("PreInsert casting and renaming") {
+    sql(
+      s"""
+        |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, a * 4 FROM jt
+      """.stripMargin)
+
+    checkAnswer(
+      sql("SELECT a, b FROM jsonTable"),
+      (1 to 10).map(i => Row(i * 2, s"${i * 4}"))
+    )
+
+    sql(
+      s"""
+        |INSERT OVERWRITE TABLE jsonTable SELECT a * 4 AS A, a * 6 as c FROM jt
+      """.stripMargin)
+
+    checkAnswer(
+      sql("SELECT a, b FROM jsonTable"),
+      (1 to 10).map(i => Row(i * 4, s"${i * 6}"))
+    )
+  }
+
+  test("SELECT clause generating a different number of columns is not allowed.") {
+    val message = intercept[RuntimeException] {
+      sql(
+        """
+          |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt
+        """.stripMargin)
+    }.getMessage
+    assert(
+      message.contains("generates the same number of columns as its schema"),
+      "SELECT clause generating a different number of columns should not be allowed."
+    )
+  }
+
   test("INSERT OVERWRITE a JSONRelation multiple times") {
     sql(
       s"""
@@ -93,4 +128,49 @@ class InsertIntoSuite extends DataSourceTest with BeforeAndAfterAll {
       """.stripMargin)
     }
   }
+
+  test("Caching")  {
+    // Cached Query Execution
+    cacheTable("jsonTable")
+    assertCached(sql("SELECT * FROM jsonTable"))
+    checkAnswer(
+      sql("SELECT * FROM jsonTable"),
+      (1 to 10).map(i => Row(i, s"str$i")))
+
+    assertCached(sql("SELECT a FROM jsonTable"))
+    checkAnswer(
+      sql("SELECT a FROM jsonTable"),
+      (1 to 10).map(Row(_)).toSeq)
+
+    assertCached(sql("SELECT a FROM jsonTable WHERE a < 5"))
+    checkAnswer(
+      sql("SELECT a FROM jsonTable WHERE a < 5"),
+      (1 to 4).map(Row(_)).toSeq)
+
+    assertCached(sql("SELECT a * 2 FROM jsonTable"))
+    checkAnswer(
+      sql("SELECT a * 2 FROM jsonTable"),
+      (1 to 10).map(i => Row(i * 2)).toSeq)
+
+    assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2)
+    checkAnswer(
+      sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"),
+      (2 to 10).map(i => Row(i, i - 1)).toSeq)
+
+    // Insert overwrite and keep the same schema.
+    sql(
+      s"""
+        |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, b FROM jt
+      """.stripMargin)
+    // jsonTable should be recached.
+    assertCached(sql("SELECT * FROM jsonTable"))
+    // The cached data is the new data.
+    checkAnswer(
+      sql("SELECT a, b FROM jsonTable"),
+      sql("SELECT a * 2, b FROM jt").collect())
+
+    // Verify uncaching
+    uncacheTable("jsonTable")
+    assertCached(sql("SELECT * FROM jsonTable"), 0)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3eccf29c/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index d2371d4..ad37b7d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -324,6 +324,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
         catalog.PreInsertionCasts ::
         ExtractPythonUdfs ::
         ResolveUdtfsAlias ::
+        sources.PreInsertCastAndRename ::
         Nil
     }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org