You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/11/11 01:00:48 UTC

spark git commit: [SPARK-18185] Fix all forms of INSERT / OVERWRITE TABLE for Datasource tables

Repository: spark
Updated Branches:
  refs/heads/master e0deee1f7 -> a3356343c


[SPARK-18185] Fix all forms of INSERT / OVERWRITE TABLE for Datasource tables

## What changes were proposed in this pull request?

As of the current 2.1 branch, INSERT OVERWRITE with dynamic partitions against a Datasource table overwrites the entire table, instead of overwriting only the partitions that match the static keys as Hive does. It also does not respect custom partition locations.

This PR adds support for all of these operations to Datasource tables managed by the Hive metastore. It is implemented as follows:
- During planning time, the full set of partitions affected by an INSERT or OVERWRITE command is read from the Hive metastore.
- The planner identifies any partitions with custom locations and includes them in the write task metadata.
- FileFormatWriter tasks refer to this custom locations map when determining where to write for dynamic partition output.
- When the write job finishes, the set of written partitions is compared against the initial set of matched partitions, and the Hive metastore is updated to reflect the newly added / removed partitions.

It was necessary to add a method to `FileCommitProtocol` for staging files with absolute output paths. These files are not handled by the Hadoop output committer; instead, they are moved to their final locations when the job commits.

The overwrite behavior of legacy Datasource tables also changes: the entire table is no longer overwritten when a partial partition spec is present.

cc cloud-fan yhuai

## How was this patch tested?

Unit tests, existing tests.

Author: Eric Liang <ek...@databricks.com>
Author: Wenchen Fan <we...@databricks.com>

Closes #15814 from ericl/sc-5027.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a3356343
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a3356343
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a3356343

Branch: refs/heads/master
Commit: a3356343cbf58b930326f45721fb4ecade6f8029
Parents: e0deee1
Author: Eric Liang <ek...@databricks.com>
Authored: Thu Nov 10 17:00:43 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Nov 10 17:00:43 2016 -0800

----------------------------------------------------------------------
 .../spark/internal/io/FileCommitProtocol.scala  |  15 ++
 .../io/HadoopMapReduceCommitProtocol.scala      |  63 +++++++-
 .../spark/sql/catalyst/parser/AstBuilder.scala  |  12 +-
 .../plans/logical/basicLogicalOperators.scala   |  10 +-
 .../sql/catalyst/parser/PlanParserSuite.scala   |   4 +-
 .../sql/execution/datasources/DataSource.scala  |  20 +--
 .../datasources/DataSourceStrategy.scala        |  94 +++++++----
 .../datasources/FileFormatWriter.scala          |  26 ++-
 .../InsertIntoHadoopFsRelationCommand.scala     |  61 ++++++-
 .../datasources/PartitioningUtils.scala         |  10 ++
 .../execution/streaming/FileStreamSink.scala    |   2 +-
 .../streaming/ManifestFileCommitProtocol.scala  |   6 +
 .../PartitionProviderCompatibilitySuite.scala   | 161 ++++++++++++++++++-
 13 files changed, 411 insertions(+), 73 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a3356343/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
index fb80205..afd2250 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
@@ -82,10 +82,25 @@ abstract class FileCommitProtocol {
    *
    * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest
    * are left to the commit protocol implementation to decide.
+   *
+   * Important: it is the caller's responsibility to add uniquely identifying content to "ext"
+   * if a task is going to write out multiple files to the same dir. The file commit protocol only
+   * guarantees that files written by different tasks will not conflict.
    */
   def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String
 
   /**
+   * Similar to newTaskTempFile(), but allows files to be committed to an absolute output location.
+   * Depending on the implementation, there may be weaker guarantees around adding files this way.
+   *
+   * Important: it is the caller's responsibility to add uniquely identifying content to "ext"
+   * if a task is going to write out multiple files to the same dir. The file commit protocol only
+   * guarantees that files written by different tasks will not conflict.
+   */
+  def newTaskTempFileAbsPath(
+      taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String
+
+  /**
    * Commits a task after the writes succeed. Must be called on the executors when running tasks.
    */
   def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage

http://git-wip-us.apache.org/repos/asf/spark/blob/a3356343/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
index 6b0bcb8..b2d9b8d 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.internal.io
 
-import java.util.Date
+import java.util.{Date, UUID}
+
+import scala.collection.mutable
 
 import org.apache.hadoop.conf.Configurable
 import org.apache.hadoop.fs.Path
@@ -42,6 +44,19 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
   /** OutputCommitter from Hadoop is not serializable so marking it transient. */
   @transient private var committer: OutputCommitter = _
 
+  /**
+   * Tracks files staged by this task for absolute output paths. These outputs are not managed by
+   * the Hadoop OutputCommitter, so we must move these to their final locations on job commit.
+   *
+   * The mapping is from the temp output path to the final desired output path of the file.
+   */
+  @transient private var addedAbsPathFiles: mutable.Map[String, String] = null
+
+  /**
+   * The staging directory for all files committed with absolute output paths.
+   */
+  private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId)
+
   protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
     val format = context.getOutputFormatClass.newInstance()
     // If OutputFormat is Configurable, we should set conf to it.
@@ -54,11 +69,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
 
   override def newTaskTempFile(
       taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
-    // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
-    // Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
-    // the file name is fine and won't overflow.
-    val split = taskContext.getTaskAttemptID.getTaskID.getId
-    val filename = f"part-$split%05d-$jobId$ext"
+    val filename = getFilename(taskContext, ext)
 
     val stagingDir: String = committer match {
       // For FileOutputCommitter it has its own staging path called "work path".
@@ -73,6 +84,28 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
     }
   }
 
+  override def newTaskTempFileAbsPath(
+      taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = {
+    val filename = getFilename(taskContext, ext)
+    val absOutputPath = new Path(absoluteDir, filename).toString
+
+    // Include a UUID here to prevent file collisions for one task writing to different dirs.
+    // In principle we could include hash(absoluteDir) instead but this is simpler.
+    val tmpOutputPath = new Path(
+      absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString
+
+    addedAbsPathFiles(tmpOutputPath) = absOutputPath
+    tmpOutputPath
+  }
+
+  private def getFilename(taskContext: TaskAttemptContext, ext: String): String = {
+    // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
+    // Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
+    // the file name is fine and won't overflow.
+    val split = taskContext.getTaskAttemptID.getTaskID.getId
+    f"part-$split%05d-$jobId$ext"
+  }
+
   override def setupJob(jobContext: JobContext): Unit = {
     // Setup IDs
     val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0)
@@ -93,26 +126,42 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
 
   override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = {
     committer.commitJob(jobContext)
+    val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]])
+      .foldLeft(Map[String, String]())(_ ++ _)
+    logDebug(s"Committing files staged for absolute locations $filesToMove")
+    val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration)
+    for ((src, dst) <- filesToMove) {
+      fs.rename(new Path(src), new Path(dst))
+    }
+    fs.delete(absPathStagingDir, true)
   }
 
   override def abortJob(jobContext: JobContext): Unit = {
     committer.abortJob(jobContext, JobStatus.State.FAILED)
+    val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration)
+    fs.delete(absPathStagingDir, true)
   }
 
   override def setupTask(taskContext: TaskAttemptContext): Unit = {
     committer = setupCommitter(taskContext)
     committer.setupTask(taskContext)
+    addedAbsPathFiles = mutable.Map[String, String]()
   }
 
   override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = {
     val attemptId = taskContext.getTaskAttemptID
     SparkHadoopMapRedUtil.commitTask(
       committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId)
-    EmptyTaskCommitMessage
+    new TaskCommitMessage(addedAbsPathFiles.toMap)
   }
 
   override def abortTask(taskContext: TaskAttemptContext): Unit = {
     committer.abortTask(taskContext)
+    // best effort cleanup of other staged files
+    for ((src, _) <- addedAbsPathFiles) {
+      val tmp = new Path(src)
+      tmp.getFileSystem(taskContext.getConfiguration).delete(tmp, false)
+    }
   }
 
   /** Whether we are using a direct output committer */

http://git-wip-us.apache.org/repos/asf/spark/blob/a3356343/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 2c4db0d..3fa7bf1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -172,24 +172,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
     val tableIdent = visitTableIdentifier(ctx.tableIdentifier)
     val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
 
-    val dynamicPartitionKeys = partitionKeys.filter(_._2.isEmpty)
+    val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty)
     if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) {
       throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. Specified " +
         "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx)
     }
     val overwrite = ctx.OVERWRITE != null
-    val overwritePartition =
-      if (overwrite && partitionKeys.nonEmpty && dynamicPartitionKeys.isEmpty) {
-        Some(partitionKeys.map(t => (t._1, t._2.get)))
-      } else {
-        None
-      }
+    val staticPartitionKeys: Map[String, String] =
+      partitionKeys.filter(_._2.nonEmpty).map(t => (t._1, t._2.get))
 
     InsertIntoTable(
       UnresolvedRelation(tableIdent, None),
       partitionKeys,
       query,
-      OverwriteOptions(overwrite, overwritePartition),
+      OverwriteOptions(overwrite, if (overwrite) staticPartitionKeys else Map.empty),
       ctx.EXISTS != null)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a3356343/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index dcae7b0..4dcc288 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -349,13 +349,15 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
  * Options for writing new data into a table.
  *
  * @param enabled whether to overwrite existing data in the table.
- * @param specificPartition only data in the specified partition will be overwritten.
+ * @param staticPartitionKeys if non-empty, specifies that we only want to overwrite partitions
+ *                            that match this partial partition spec. If empty, all partitions
+ *                            will be overwritten.
  */
 case class OverwriteOptions(
     enabled: Boolean,
-    specificPartition: Option[CatalogTypes.TablePartitionSpec] = None) {
-  if (specificPartition.isDefined) {
-    assert(enabled, "Overwrite must be enabled when specifying a partition to overwrite.")
+    staticPartitionKeys: CatalogTypes.TablePartitionSpec = Map.empty) {
+  if (staticPartitionKeys.nonEmpty) {
+    assert(enabled, "Overwrite must be enabled when specifying specific partitions.")
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a3356343/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 5f0f6ee..9aae520 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -185,9 +185,9 @@ class PlanParserSuite extends PlanTest {
         OverwriteOptions(
           overwrite,
           if (overwrite && partition.nonEmpty) {
-            Some(partition.map(kv => (kv._1, kv._2.get)))
+            partition.map(kv => (kv._1, kv._2.get))
           } else {
-            None
+            Map.empty
           }),
         ifNotExists)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a3356343/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 5d66394..65422f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -417,15 +417,17 @@ case class DataSource(
         // will be adjusted within InsertIntoHadoopFsRelation.
         val plan =
           InsertIntoHadoopFsRelationCommand(
-            outputPath,
-            columns,
-            bucketSpec,
-            format,
-            _ => Unit, // No existing table needs to be refreshed.
-            options,
-            data.logicalPlan,
-            mode,
-            catalogTable)
+            outputPath = outputPath,
+            staticPartitionKeys = Map.empty,
+            customPartitionLocations = Map.empty,
+            partitionColumns = columns,
+            bucketSpec = bucketSpec,
+            fileFormat = format,
+            refreshFunction = _ => Unit, // No existing table needs to be refreshed.
+            options = options,
+            query = data.logicalPlan,
+            mode = mode,
+            catalogTable = catalogTable)
         sparkSession.sessionState.executePlan(plan).toRdd
         // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it.
         copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation()

http://git-wip-us.apache.org/repos/asf/spark/blob/a3356343/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 739aeac..4f19a2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -24,10 +24,10 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation}
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTablePartition, SimpleCatalogRelation}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
-import org.apache.spark.sql.execution.command.{AlterTableAddPartitionCommand, DDLUtils, ExecutedCommandExec}
+import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -182,41 +182,53 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
           "Cannot overwrite a path that is also being read from.")
       }
 
-      val overwritingSinglePartition =
-        overwrite.specificPartition.isDefined &&
+      val partitionSchema = query.resolve(
+        t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
+      val partitionsTrackedByCatalog =
         t.sparkSession.sessionState.conf.manageFilesourcePartitions &&
+        l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty &&
         l.catalogTable.get.tracksPartitionsInCatalog
 
-      val effectiveOutputPath = if (overwritingSinglePartition) {
-        val partition = t.sparkSession.sessionState.catalog.getPartition(
-          l.catalogTable.get.identifier, overwrite.specificPartition.get)
-        new Path(partition.location)
-      } else {
-        outputPath
-      }
-
-      val effectivePartitionSchema = if (overwritingSinglePartition) {
-        Nil
-      } else {
-        query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
+      var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil
+      var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty
+
+      // When partitions are tracked by the catalog, compute all custom partition locations that
+      // may be relevant to the insertion job.
+      if (partitionsTrackedByCatalog) {
+        val matchingPartitions = t.sparkSession.sessionState.catalog.listPartitions(
+          l.catalogTable.get.identifier, Some(overwrite.staticPartitionKeys))
+        initialMatchingPartitions = matchingPartitions.map(_.spec)
+        customPartitionLocations = getCustomPartitionLocations(
+          t.sparkSession, l.catalogTable.get, outputPath, matchingPartitions)
       }
 
+      // Callback for updating metastore partition metadata after the insertion job completes.
+      // TODO(ekl) consider moving this into InsertIntoHadoopFsRelationCommand
       def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = {
-        if (l.catalogTable.isDefined && updatedPartitions.nonEmpty &&
-            l.catalogTable.get.partitionColumnNames.nonEmpty &&
-            l.catalogTable.get.tracksPartitionsInCatalog) {
-          val metastoreUpdater = AlterTableAddPartitionCommand(
-            l.catalogTable.get.identifier,
-            updatedPartitions.map(p => (p, None)),
-            ifNotExists = true)
-          metastoreUpdater.run(t.sparkSession)
+        if (partitionsTrackedByCatalog) {
+          val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions
+          if (newPartitions.nonEmpty) {
+            AlterTableAddPartitionCommand(
+              l.catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)),
+              ifNotExists = true).run(t.sparkSession)
+          }
+          if (overwrite.enabled) {
+            val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions
+            if (deletedPartitions.nonEmpty) {
+              AlterTableDropPartitionCommand(
+                l.catalogTable.get.identifier, deletedPartitions.toSeq,
+                ifExists = true, purge = true).run(t.sparkSession)
+            }
+          }
         }
         t.location.refresh()
       }
 
       val insertCmd = InsertIntoHadoopFsRelationCommand(
-        effectiveOutputPath,
-        effectivePartitionSchema,
+        outputPath,
+        if (overwrite.enabled) overwrite.staticPartitionKeys else Map.empty,
+        customPartitionLocations,
+        partitionSchema,
         t.bucketSpec,
         t.fileFormat,
         refreshPartitionsCallback,
@@ -227,6 +239,34 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
 
       insertCmd
   }
+
+  /**
+   * Given a set of input partitions, returns those that have locations that differ from the
+   * Hive default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by
+   * the user.
+   *
+   * @return a mapping from partition specs to their custom locations
+   */
+  private def getCustomPartitionLocations(
+      spark: SparkSession,
+      table: CatalogTable,
+      basePath: Path,
+      partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = {
+    val hadoopConf = spark.sessionState.newHadoopConf
+    val fs = basePath.getFileSystem(hadoopConf)
+    val qualifiedBasePath = basePath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+    partitions.flatMap { p =>
+      val defaultLocation = qualifiedBasePath.suffix(
+        "/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema)).toString
+      val catalogLocation = new Path(p.location).makeQualified(
+        fs.getUri, fs.getWorkingDirectory).toString
+      if (catalogLocation != defaultLocation) {
+        Some(p.spec -> catalogLocation)
+      } else {
+        None
+      }
+    }.toMap
+  }
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a3356343/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 69b3fa6..4e4b0e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -47,6 +47,10 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 /** A helper object for writing FileFormat data out to a location. */
 object FileFormatWriter extends Logging {
 
+  /** Describes how output files should be placed in the filesystem. */
+  case class OutputSpec(
+    outputPath: String, customPartitionLocations: Map[TablePartitionSpec, String])
+
   /** A shared job description for all the write tasks. */
   private class WriteJobDescription(
       val uuid: String,  // prevent collision between different (appending) write jobs
@@ -56,7 +60,8 @@ object FileFormatWriter extends Logging {
       val partitionColumns: Seq[Attribute],
       val nonPartitionColumns: Seq[Attribute],
       val bucketSpec: Option[BucketSpec],
-      val path: String)
+      val path: String,
+      val customPartitionLocations: Map[TablePartitionSpec, String])
     extends Serializable {
 
     assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns),
@@ -83,7 +88,7 @@ object FileFormatWriter extends Logging {
       plan: LogicalPlan,
       fileFormat: FileFormat,
       committer: FileCommitProtocol,
-      outputPath: String,
+      outputSpec: OutputSpec,
       hadoopConf: Configuration,
       partitionColumns: Seq[Attribute],
       bucketSpec: Option[BucketSpec],
@@ -93,7 +98,7 @@ object FileFormatWriter extends Logging {
     val job = Job.getInstance(hadoopConf)
     job.setOutputKeyClass(classOf[Void])
     job.setOutputValueClass(classOf[InternalRow])
-    FileOutputFormat.setOutputPath(job, new Path(outputPath))
+    FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
 
     val partitionSet = AttributeSet(partitionColumns)
     val dataColumns = plan.output.filterNot(partitionSet.contains)
@@ -111,7 +116,8 @@ object FileFormatWriter extends Logging {
       partitionColumns = partitionColumns,
       nonPartitionColumns = dataColumns,
       bucketSpec = bucketSpec,
-      path = outputPath)
+      path = outputSpec.outputPath,
+      customPartitionLocations = outputSpec.customPartitionLocations)
 
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
       // This call shouldn't be put into the `try` block below because it only initializes and
@@ -308,7 +314,17 @@ object FileFormatWriter extends Logging {
       }
       val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext)
 
-      val path = committer.newTaskTempFile(taskAttemptContext, partDir, ext)
+      val customPath = partDir match {
+        case Some(dir) =>
+          description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
+        case _ =>
+          None
+      }
+      val path = if (customPath.isDefined) {
+        committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
+      } else {
+        committer.newTaskTempFile(taskAttemptContext, partDir, ext)
+      }
       val newWriter = description.outputWriterFactory.newInstance(
         path = path,
         dataSchema = description.nonPartitionColumns.toStructType,

http://git-wip-us.apache.org/repos/asf/spark/blob/a3356343/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index a0a8cb5..28975e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import java.io.IOException
 
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileSystem, Path}
 
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.sql._
@@ -32,19 +32,32 @@ import org.apache.spark.sql.execution.command.RunnableCommand
 /**
  * A command for writing data to a [[HadoopFsRelation]].  Supports both overwriting and appending.
  * Writing to dynamic partitions is also supported.
+ *
+ * @param staticPartitionKeys partial partitioning spec for write. This defines the scope of
+ *                            partition overwrites: when the spec is empty, all partitions are
+ *                            overwritten. When it covers a prefix of the partition keys, only
+ *                            partitions matching the prefix are overwritten.
+ * @param customPartitionLocations mapping of partition specs to their custom locations. The
+ *                                 caller should guarantee that exactly those table partitions
+ *                                 falling under the specified static partition keys are contained
+ *                                 in this map, and that no other partitions are.
  */
 case class InsertIntoHadoopFsRelationCommand(
     outputPath: Path,
+    staticPartitionKeys: TablePartitionSpec,
+    customPartitionLocations: Map[TablePartitionSpec, String],
     partitionColumns: Seq[Attribute],
     bucketSpec: Option[BucketSpec],
     fileFormat: FileFormat,
-    refreshFunction: (Seq[TablePartitionSpec]) => Unit,
+    refreshFunction: Seq[TablePartitionSpec] => Unit,
     options: Map[String, String],
     @transient query: LogicalPlan,
     mode: SaveMode,
     catalogTable: Option[CatalogTable])
   extends RunnableCommand {
 
+  import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
+
   override protected def innerChildren: Seq[LogicalPlan] = query :: Nil
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
@@ -66,10 +79,7 @@ case class InsertIntoHadoopFsRelationCommand(
       case (SaveMode.ErrorIfExists, true) =>
         throw new AnalysisException(s"path $qualifiedOutputPath already exists.")
       case (SaveMode.Overwrite, true) =>
-        if (!fs.delete(qualifiedOutputPath, true /* recursively */)) {
-          throw new IOException(s"Unable to clear output " +
-            s"directory $qualifiedOutputPath prior to writing to it")
-        }
+        deleteMatchingPartitions(fs, qualifiedOutputPath)
         true
       case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) =>
         true
@@ -93,7 +103,8 @@ case class InsertIntoHadoopFsRelationCommand(
         plan = query,
         fileFormat = fileFormat,
         committer = committer,
-        outputPath = qualifiedOutputPath.toString,
+        outputSpec = FileFormatWriter.OutputSpec(
+          qualifiedOutputPath.toString, customPartitionLocations),
         hadoopConf = hadoopConf,
         partitionColumns = partitionColumns,
         bucketSpec = bucketSpec,
@@ -105,4 +116,40 @@ case class InsertIntoHadoopFsRelationCommand(
 
     Seq.empty[Row]
   }
+
+  /**
+   * Deletes all partition files that match the specified static prefix, e.g. `/table/foo=1`
+   * for static keys `foo -> 1`, or the entire output path when no static keys are given.
+   * Partitions with custom locations are also cleared, using `customPartitionLocations`.
+   *
+   * @throws IOException if a directory that exists cannot be deleted
+   */
+  private def deleteMatchingPartitions(fs: FileSystem, qualifiedOutputPath: Path): Unit = {
+    // Static keys are expected to cover a prefix of `partitionColumns` (see class docs).
+    val staticPartitionPrefix = if (staticPartitionKeys.nonEmpty) {
+      "/" + partitionColumns.flatMap { p =>
+        staticPartitionKeys.get(p.name).map(v => escapePathName(p.name) + "=" + escapePathName(v))
+      }.mkString("/")
+    } else {
+      ""
+    }
+    // first clear the path determined by the static partition keys (e.g. /table/foo=1)
+    val staticPrefixPath = qualifiedOutputPath.suffix(staticPartitionPrefix)
+    if (fs.exists(staticPrefixPath) && !fs.delete(staticPrefixPath, true /* recursively */)) {
+      throw new IOException(s"Unable to clear output " +
+        s"directory $staticPrefixPath prior to writing to it")
+    }
+    // now clear all custom partition locations (e.g. /custom/dir/where/foo=2/bar=4)
+    for ((spec, customLoc) <- customPartitionLocations) {
+      assert(
+        (staticPartitionKeys.toSet -- spec).isEmpty,
+        s"Custom partition location $customLoc for spec $spec did not match static " +
+          s"partitioning keys $staticPartitionKeys")
+      val path = new Path(customLoc)
+      if (fs.exists(path) && !fs.delete(path, true /* recursively */)) {
+        throw new IOException(s"Unable to clear partition " +
+          s"directory $path prior to writing to it")
+      }
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a3356343/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index a28b04c..bf9f318 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -62,6 +62,7 @@ object PartitioningUtils {
   }
 
   import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.DEFAULT_PARTITION_NAME
+  import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
   import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName
 
   /**
@@ -253,6 +254,15 @@ object PartitioningUtils {
   }
 
   /**
+   * Renders `spec` as a directory fragment like `a=1/b=2`, one path-escaped `name=value` per
+   * column in `partitionSchema` order. This is the inverse of parsePathFragment().
+   */
+  def getPathFragment(spec: TablePartitionSpec, partitionSchema: StructType): String =
+    partitionSchema.fieldNames
+      .map(name => s"${escapePathName(name)}=${escapePathName(spec(name))}")
+      .mkString("/")
+
+  /**
    * Normalize the column names in partition specification, w.r.t. the real partition column names
    * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a
    * partition column named `month`, and it's case insensitive, we will normalize `monTh` to

http://git-wip-us.apache.org/repos/asf/spark/blob/a3356343/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index e849caf..f1c5f9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -80,7 +80,7 @@ class FileStreamSink(
         plan = data.logicalPlan,
         fileFormat = fileFormat,
         committer = committer,
-        outputPath = path,
+        outputSpec = FileFormatWriter.OutputSpec(path, Map.empty),
         hadoopConf = hadoopConf,
         partitionColumns = partitionColumns,
         bucketSpec = None,

http://git-wip-us.apache.org/repos/asf/spark/blob/a3356343/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
index 1fe13fa..92191c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
@@ -96,6 +96,12 @@ class ManifestFileCommitProtocol(jobId: String, path: String)
     file
   }
 
+  /** Always throws: files under an absolute directory are not supported by this protocol. */
+  override def newTaskTempFileAbsPath(
+      taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String =
+    throw new UnsupportedOperationException(
+      s"$this does not support adding files with an absolute path")
+
   override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = {
     if (addedFiles.nonEmpty) {
       val fs = new Path(addedFiles.head).getFileSystem(taskContext.getConfiguration)

http://git-wip-us.apache.org/repos/asf/spark/blob/a3356343/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala
index ac435bf..a1aa074 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest}
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
 
 class PartitionProviderCompatibilitySuite
   extends QueryTest with TestHiveSingleton with SQLTestUtils {
@@ -135,7 +136,7 @@ class PartitionProviderCompatibilitySuite
     }
   }
 
-  test("insert overwrite partition of legacy datasource table overwrites entire table") {
+  test("insert overwrite partition of legacy datasource table") {
     withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") {
       withTable("test") {
         withTempDir { dir =>
@@ -144,9 +145,9 @@ class PartitionProviderCompatibilitySuite
             """insert overwrite table test
               |partition (partCol=1)
               |select * from range(100)""".stripMargin)
-          assert(spark.sql("select * from test").count() == 100)
+          assert(spark.sql("select * from test").count() == 104)
 
-          // Dynamic partitions case
+          // Overwriting entire table
           spark.sql("insert overwrite table test select id, id from range(10)".stripMargin)
           assert(spark.sql("select * from test").count() == 10)
         }
@@ -186,4 +187,158 @@ class PartitionProviderCompatibilitySuite
       }
     }
   }
+
+  /**
+   * Creates the table `test` partitioned by (P1, P2), runs `testFn` against it, and then
+   * checks that the custom locations were respected: deleting each one must remove its rows.
+   *
+   * The initial partitioning structure is:
+   *   /P1=0/P2=0  -- custom location a
+   *   /P1=0/P2=1  -- custom location b
+   *   /P1=1/P2=0  -- custom location c
+   *   /P1=1/P2=1  -- default location
+   */
+  private def testCustomLocations(testFn: => Unit): Unit = {
+    val base = Utils.createTempDir(namePrefix = "base")
+    val a = Utils.createTempDir(namePrefix = "a")
+    val b = Utils.createTempDir(namePrefix = "b")
+    val c = Utils.createTempDir(namePrefix = "c")
+    try {
+      spark.sql(s"""
+        |create table test (id long, P1 int, P2 int)
+        |using parquet
+        |options (path "${base.getAbsolutePath}")
+        |partitioned by (P1, P2)""".stripMargin)
+      spark.sql(s"alter table test add partition (P1=0, P2=0) location '${a.getAbsolutePath}'")
+      spark.sql(s"alter table test add partition (P1=0, P2=1) location '${b.getAbsolutePath}'")
+      spark.sql(s"alter table test add partition (P1=1, P2=0) location '${c.getAbsolutePath}'")
+      spark.sql(s"alter table test add partition (P1=1, P2=1)")
+
+      testFn
+
+      // Deleting a custom location must remove exactly the rows of its partition
+      val totalCount = spark.sql("select * from test").count()
+      val numA = spark.sql("select * from test where P1=0 and P2=0").count()
+      val numB = spark.sql("select * from test where P1=0 and P2=1").count()
+      val numC = spark.sql("select * from test where P1=1 and P2=0").count()
+      Utils.deleteRecursively(a)
+      spark.sql("refresh table test")
+      assert(spark.sql("select * from test where P1=0 and P2=0").count() == 0)
+      assert(spark.sql("select * from test").count() == totalCount - numA)
+      Utils.deleteRecursively(b)
+      spark.sql("refresh table test")
+      assert(spark.sql("select * from test where P1=0 and P2=1").count() == 0)
+      assert(spark.sql("select * from test").count() == totalCount - numA - numB)
+      Utils.deleteRecursively(c)
+      spark.sql("refresh table test")
+      assert(spark.sql("select * from test where P1=1 and P2=0").count() == 0)
+      assert(spark.sql("select * from test").count() == totalCount - numA - numB - numC)
+    } finally {
+      spark.sql("drop table if exists test") // first, and tolerant of a failed create
+      Utils.deleteRecursively(base)
+      Utils.deleteRecursively(a)
+      Utils.deleteRecursively(b)
+      Utils.deleteRecursively(c)
+    }
+  }
+
+  test("sanity check table setup") {
+    testCustomLocations { // four partitions are registered up front, but hold no rows yet
+      assert(spark.sql("show partitions test").count() == 4)
+      assert(spark.sql("select * from test").count() == 0)
+    }
+  }
+
+  test("insert into partial dynamic partitions") {
+    testCustomLocations {
+      // Each step appends 10 rows under a static P1 with dynamic P2 = 0..9. Expected partition
+      // counts: P1=0 adds 8 new (P2=0,1 already exist), P1=1 adds 8, P1=2 adds 10.
+      val steps = Seq(
+        (0, 10, 12), // (static P1, expected rows, expected partitions)
+        (0, 20, 12), // re-inserting the same partitions only adds rows
+        (1, 30, 20),
+        (2, 40, 30))
+      for ((p1, expectedRows, expectedPartitions) <- steps) {
+        spark.sql(s"insert into test partition (P1=$p1, P2) select id, id from range(10)")
+        assert(spark.sql("select * from test").count() == expectedRows)
+        assert(spark.sql("show partitions test").count() == expectedPartitions)
+      }
+    }
+  }
+
+  test("insert into fully dynamic partitions") {
+    testCustomLocations {
+      // Rows go to (i, i) for i < 10; all but (0,0) and (1,1) are new, giving 4 + 8 partitions.
+      for (expectedRows <- Seq(10, 20)) {
+        spark.sql("insert into test partition (P1, P2) select id, id, id from range(10)")
+        assert(spark.sql("select * from test").count() == expectedRows)
+        assert(spark.sql("show partitions test").count() == 12)
+      }
+    }
+  }
+
+  test("insert into static partition") {
+    testCustomLocations {
+      val steps = Seq(
+        ("P1=0, P2=0", 10), // (static spec, expected rows after the insert)
+        ("P1=0, P2=0", 20),
+        ("P1=1, P2=1", 30))
+      for ((spec, expectedRows) <- steps) {
+        spark.sql(s"insert into test partition ($spec) select id from range(10)")
+        assert(spark.sql("select * from test").count() == expectedRows)
+        assert(spark.sql("show partitions test").count() == 4) // specs already exist
+      }
+    }
+  }
+
+  test("overwrite partial dynamic partitions") {
+    testCustomLocations {
+      // Each overwrite replaces every partition under the static P1 value with P2 = 0 until n;
+      // partitions under other P1 values are left untouched.
+      val steps = Seq(
+        // (static P1, n, expected rows, expected partitions)
+        (0, 10, 10, 12), // P1=0 gets P2=0..9; (1,0) and (1,1) remain
+        (0, 5, 5, 7), // shrinking the overwrite drops P2=5..9 under P1=0
+        (0, 1, 1, 3),
+        (1, 10, 11, 11), // P1=1 is replaced; the single P1=0 row survives
+        (1, 1, 2, 2),
+        (3, 100, 102, 102)) // a new P1 value adds partitions next to the existing ones
+      for ((p1, n, expectedRows, expectedPartitions) <- steps) {
+        val overwrite =
+          s"insert overwrite table test partition (P1=$p1, P2) select id, id from range($n)"
+        spark.sql(overwrite)
+        // `show partitions` must also drop the partitions the overwrite removed
+        assert(spark.sql("select * from test").count() == expectedRows)
+        assert(spark.sql("show partitions test").count() == expectedPartitions)
+      }
+    }
+  }
+
+  test("overwrite fully dynamic partitions") {
+    testCustomLocations {
+      for (n <- Seq(10, 5)) { // no static keys, so each run replaces the whole table
+        spark.sql(
+          s"insert overwrite table test partition (P1, P2) select id, id, id from range($n)")
+        assert(spark.sql("select * from test").count() == n)
+        assert(spark.sql("show partitions test").count() == n)
+      }
+    }
+  }
+
+  test("overwrite static partition") {
+    testCustomLocations {
+      // Overwriting a fully static spec replaces only that partition's rows; a spec that was
+      // not registered before (P1=1, P2=2) adds a fifth partition.
+      val steps = Seq(
+        ("P1=0, P2=0", 10, 10, 4), // (static spec, n, expected rows, expected partitions)
+        ("P1=0, P2=0", 5, 5, 4),
+        ("P1=1, P2=1", 5, 10, 4),
+        ("P1=1, P2=2", 5, 15, 5))
+      for ((spec, n, expectedRows, expectedPartitions) <- steps) {
+        spark.sql(s"insert overwrite table test partition ($spec) select id from range($n)")
+        assert(spark.sql("select * from test").count() == expectedRows)
+        assert(spark.sql("show partitions test").count() == expectedPartitions)
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org