Posted to commits@spark.apache.org by li...@apache.org on 2017/01/18 07:38:09 UTC
[2/2] spark git commit: [SPARK-18243][SQL] Port Hive writing to use FileFormat interface
[SPARK-18243][SQL] Port Hive writing to use FileFormat interface
## What changes were proposed in this pull request?
Inserting data into Hive tables currently has its own implementation, distinct from the data source write path: `InsertIntoHiveTable`, `SparkHiveWriterContainer` and `SparkHiveDynamicPartitionWriterContainer`.
Note that one other major difference is that data source tables write directly to the final destination without using a staging directory, and then Spark itself adds the partitions/tables to the catalog. Hive tables instead write to a staging directory and then call the Hive metastore's loadPartition/loadTable functions to load that data in. So we still need to keep `InsertIntoHiveTable` to hold this special logic. In the future, we should consider writing to the Hive table location directly, so that we no longer need to call `loadTable`/`loadPartition` at the end and can remove `InsertIntoHiveTable`.
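For readers less familiar with this flow, below is a minimal, self-contained sketch of the staging-then-load idea in plain Scala. Everything in it (the toy catalog, directory names, the `loadPartition` helper) is made up for illustration; the real code builds the staging path via `getExternalTmpPath` and calls the Hive metastore's `loadPartition`/`loadTable`, as shown in the diff below.

```scala
import java.nio.file.{Files, Path, StandardCopyOption}
import scala.collection.JavaConverters._

object StagingWriteSketch {
  // Toy stand-in for the metastore: partition spec -> final data location.
  private val catalog = scala.collection.mutable.Map.empty[String, Path]

  // Stand-in for Hive's loadPartition: move staged files into the final
  // partition directory, then register the partition with the "catalog".
  def loadPartition(staging: Path, tableDir: Path, partSpec: String): Unit = {
    val partDir = Files.createDirectories(tableDir.resolve(partSpec))
    Files.list(staging).iterator().asScala.foreach { f =>
      Files.move(f, partDir.resolve(f.getFileName), StandardCopyOption.REPLACE_EXISTING)
    }
    catalog(partSpec) = partDir
  }

  def main(args: Array[String]): Unit = {
    val tableDir = Files.createTempDirectory("warehouse_tbl")
    // 1. Tasks write their output under a hidden staging directory,
    //    not directly under the table location.
    val staging = Files.createTempDirectory(tableDir, ".hive-staging")
    Files.write(staging.resolve("part-00000"), "1\ta\n".getBytes("UTF-8"))
    // 2. After the job commits, ask the metastore to load the staged data
    //    into the final location.
    loadPartition(staging, tableDir, "ds=2017-01-18")
    println(catalog)
  }
}
```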
This PR removes `SparkHiveWriterContainer` and `SparkHiveDynamicPartitionWriterContainer`, and creates a `HiveFileFormat` to implement the write logic. In the future, we should also implement the read logic in `HiveFileFormat`.
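At a high level, the new write path plugs Hive output into the same driver-side-factory / task-side-writer contract used by other file-based data sources. The sketch below is a heavily simplified, self-contained illustration of that layering; the trait signatures are stand-ins, not Spark's actual `FileFormat`/`OutputWriterFactory`/`OutputWriter` APIs (see `HiveFileFormat.scala` in the diff for the real implementation).

```scala
import java.io.{BufferedWriter, FileWriter}

// Simplified stand-ins for the FileFormat/OutputWriterFactory/OutputWriter
// contract (parameters trimmed for illustration; not Spark's real interfaces).
trait OutputWriter {
  def write(row: Seq[Any]): Unit
  def close(): Unit
}

trait OutputWriterFactory extends Serializable {
  def newInstance(path: String): OutputWriter
}

trait FileFormat {
  // Driver-side setup returns a serializable factory; each task then calls
  // newInstance to get a per-file writer, mirroring HiveFileFormat.prepareWrite.
  def prepareWrite(options: Map[String, String]): OutputWriterFactory
}

// A toy format that writes tab-delimited text, standing in for the
// SerDe-backed writer that HiveOutputWriter wraps.
class TsvFileFormat extends FileFormat {
  override def prepareWrite(options: Map[String, String]): OutputWriterFactory =
    new OutputWriterFactory {
      override def newInstance(path: String): OutputWriter = new OutputWriter {
        private val out = new BufferedWriter(new FileWriter(path))
        override def write(row: Seq[Any]): Unit = {
          out.write(row.mkString("\t"))
          out.newLine()
        }
        override def close(): Unit = out.close()
      }
    }
}

object FileFormatSketch {
  def main(args: Array[String]): Unit = {
    val factory = new TsvFileFormat().prepareWrite(Map.empty) // driver side
    val writer = factory.newInstance("part-00000.tsv")        // task side
    writer.write(Seq(1, "a"))
    writer.close()
  }
}
```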
## How was this patch tested?
Existing tests.
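For context, a minimal end-to-end example of the kind of insertion these code paths serve is sketched below. It is not taken from the test suite; the table, view and column names are made up, and it assumes a Spark build with Hive support.

```scala
import org.apache.spark.sql.SparkSession

object InsertIntoHiveExample {
  def main(args: Array[String]): Unit = {
    // Assumes Hive support is available; all names here are illustrative.
    val spark = SparkSession.builder()
      .appName("hive-insert-example")
      .enableHiveSupport()
      .getOrCreate()

    spark.sql(
      "CREATE TABLE IF NOT EXISTS logs (msg STRING) " +
        "PARTITIONED BY (ds STRING) STORED AS TEXTFILE")

    // Some input rows to insert, exposed as a temp view.
    spark.range(3)
      .selectExpr("concat('event ', cast(id AS string)) AS msg", "'2017-01-18' AS ds")
      .createOrReplaceTempView("staging_logs")

    // Dynamic partition insert; both static and dynamic inserts go through
    // InsertIntoHiveTable.
    spark.sql("SET hive.exec.dynamic.partition=true")
    spark.sql("SET hive.exec.dynamic.partition.mode=nonstrict")
    spark.sql("INSERT OVERWRITE TABLE logs PARTITION (ds) SELECT msg, ds FROM staging_logs")

    spark.stop()
  }
}
```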
Author: Wenchen Fan <we...@databricks.com>
Closes #16517 from cloud-fan/insert-hive.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4494cd97
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4494cd97
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4494cd97
Branch: refs/heads/master
Commit: 4494cd9716d64a6c7cfa548abadb5dd0c4c143a6
Parents: e7f982b
Author: Wenchen Fan <we...@databricks.com>
Authored: Tue Jan 17 23:37:59 2017 -0800
Committer: gatorsmile <ga...@gmail.com>
Committed: Tue Jan 17 23:37:59 2017 -0800
----------------------------------------------------------------------
.../io/HadoopMapReduceCommitProtocol.scala | 2 +-
.../spark/sql/execution/QueryExecution.scala | 33 +-
.../spark/sql/hive/HiveSessionState.scala | 2 +-
.../apache/spark/sql/hive/HiveStrategies.scala | 77 ++--
.../org/apache/spark/sql/hive/TableReader.scala | 4 +-
.../apache/spark/sql/hive/client/package.scala | 2 +-
.../sql/hive/execution/HiveFileFormat.scala | 149 ++++++
.../hive/execution/InsertIntoHiveTable.scala | 187 ++++----
.../hive/execution/ScriptTransformation.scala | 448 -------------------
.../execution/ScriptTransformationExec.scala | 448 +++++++++++++++++++
.../spark/sql/hive/hiveWriterContainers.scala | 356 ---------------
.../spark/sql/hive/client/VersionsSuite.scala | 4 +-
.../sql/hive/execution/HiveComparisonTest.scala | 10 +-
.../spark/sql/hive/execution/HiveDDLSuite.scala | 5 +-
.../sql/hive/execution/HiveQuerySuite.scala | 8 +-
.../execution/ScriptTransformationSuite.scala | 10 +-
16 files changed, 765 insertions(+), 980 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/4494cd97/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
index b2d9b8d..2f33f2e 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
@@ -99,7 +99,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
}
private def getFilename(taskContext: TaskAttemptContext, ext: String): String = {
- // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
+ // The file name looks like part-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003-c000.parquet
// Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
// the file name is fine and won't overflow.
val split = taskContext.getTaskAttemptID.getTaskID.getId
http://git-wip-us.apache.org/repos/asf/spark/blob/4494cd97/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index b3ef29f..dcd9003 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -108,21 +108,18 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
/**
- * Returns the result as a hive compatible sequence of strings. For native commands, the
- * execution is simply passed back to Hive.
+ * Returns the result as a hive compatible sequence of strings. This is for testing only.
*/
def hiveResultString(): Seq[String] = executedPlan match {
case ExecutedCommandExec(desc: DescribeTableCommand) =>
- SQLExecution.withNewExecutionId(sparkSession, this) {
- // If it is a describe command for a Hive table, we want to have the output format
- // be similar with Hive.
- desc.run(sparkSession).map {
- case Row(name: String, dataType: String, comment) =>
- Seq(name, dataType,
- Option(comment.asInstanceOf[String]).getOrElse(""))
- .map(s => String.format(s"%-20s", s))
- .mkString("\t")
- }
+ // If it is a describe command for a Hive table, we want to have the output format
+ // be similar with Hive.
+ desc.run(sparkSession).map {
+ case Row(name: String, dataType: String, comment) =>
+ Seq(name, dataType,
+ Option(comment.asInstanceOf[String]).getOrElse(""))
+ .map(s => String.format(s"%-20s", s))
+ .mkString("\t")
}
// SHOW TABLES in Hive only output table names, while ours outputs database, table name, isTemp.
case command: ExecutedCommandExec if command.cmd.isInstanceOf[ShowTablesCommand] =>
@@ -130,13 +127,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
case command: ExecutedCommandExec =>
command.executeCollect().map(_.getString(0))
case other =>
- SQLExecution.withNewExecutionId(sparkSession, this) {
- val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq
- // We need the types so we can output struct field names
- val types = analyzed.output.map(_.dataType)
- // Reformat to match hive tab delimited output.
- result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")).toSeq
- }
+ val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq
+ // We need the types so we can output struct field names
+ val types = analyzed.output.map(_.dataType)
+ // Reformat to match hive tab delimited output.
+ result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t"))
}
/** Formats a datum (based on the given data type) and returns the string representation. */
http://git-wip-us.apache.org/repos/asf/spark/blob/4494cd97/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index 9b4b8b6..4e30d03 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -66,6 +66,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession)
PreprocessTableInsertion(conf) ::
DataSourceAnalysis(conf) ::
new DetermineHiveSerde(conf) ::
+ new HiveAnalysis(sparkSession) ::
new ResolveDataSource(sparkSession) :: Nil
override val extendedCheckRules = Seq(PreWriteCheck(conf, catalog))
@@ -88,7 +89,6 @@ private[hive] class HiveSessionState(sparkSession: SparkSession)
SpecialLimits,
InMemoryScans,
HiveTableScans,
- DataSinks,
Scripts,
Aggregation,
JoinSelection,
http://git-wip-us.apache.org/repos/asf/spark/blob/4494cd97/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index d1f11e7..7987a0a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -21,14 +21,14 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.CatalogStorageFormat
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, ScriptTransformation}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.command.{DDLUtils, ExecutedCommandExec}
+import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.CreateTable
import org.apache.spark.sql.hive.execution._
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
+import org.apache.spark.sql.types.StructType
/**
@@ -86,6 +86,47 @@ class DetermineHiveSerde(conf: SQLConf) extends Rule[LogicalPlan] {
}
}
+class HiveAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case InsertIntoTable(table: MetastoreRelation, partSpec, query, overwrite, ifNotExists)
+ if hasBeenPreprocessed(table.output, table.partitionKeys.toStructType, partSpec, query) =>
+ InsertIntoHiveTable(table, partSpec, query, overwrite, ifNotExists)
+
+ case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) =>
+ // Currently `DataFrameWriter.saveAsTable` doesn't support the Append mode of hive serde
+ // tables yet.
+ if (mode == SaveMode.Append) {
+ throw new AnalysisException(
+ "CTAS for hive serde tables does not support append semantics.")
+ }
+
+ val dbName = tableDesc.identifier.database.getOrElse(session.catalog.currentDatabase)
+ CreateHiveTableAsSelectCommand(
+ tableDesc.copy(identifier = tableDesc.identifier.copy(database = Some(dbName))),
+ query,
+ mode == SaveMode.Ignore)
+ }
+
+ /**
+ * Returns true if the [[InsertIntoTable]] plan has already been preprocessed by analyzer rule
+ * [[PreprocessTableInsertion]]. It is important that this rule([[HiveAnalysis]]) has to
+ * be run after [[PreprocessTableInsertion]], to normalize the column names in partition spec and
+ * fix the schema mismatch by adding Cast.
+ */
+ private def hasBeenPreprocessed(
+ tableOutput: Seq[Attribute],
+ partSchema: StructType,
+ partSpec: Map[String, Option[String]],
+ query: LogicalPlan): Boolean = {
+ val partColNames = partSchema.map(_.name).toSet
+ query.resolved && partSpec.keys.forall(partColNames.contains) && {
+ val staticPartCols = partSpec.filter(_._2.isDefined).keySet
+ val expectedColumns = tableOutput.filterNot(a => staticPartCols.contains(a.name))
+ expectedColumns.toStructType.sameType(query.schema)
+ }
+ }
+}
+
private[hive] trait HiveStrategies {
// Possibly being too clever with types here... or not clever enough.
self: SparkPlanner =>
@@ -94,35 +135,9 @@ private[hive] trait HiveStrategies {
object Scripts extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.ScriptTransformation(input, script, output, child, ioschema) =>
+ case ScriptTransformation(input, script, output, child, ioschema) =>
val hiveIoSchema = HiveScriptIOSchema(ioschema)
- ScriptTransformation(input, script, output, planLater(child), hiveIoSchema) :: Nil
- case _ => Nil
- }
- }
-
- object DataSinks extends Strategy {
- def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.InsertIntoTable(
- table: MetastoreRelation, partition, child, overwrite, ifNotExists) =>
- InsertIntoHiveTable(
- table, partition, planLater(child), overwrite, ifNotExists) :: Nil
-
- case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) =>
- // Currently `DataFrameWriter.saveAsTable` doesn't support
- // the Append mode of hive serde tables yet.
- if (mode == SaveMode.Append) {
- throw new AnalysisException(
- "CTAS for hive serde tables does not support append semantics.")
- }
-
- val dbName = tableDesc.identifier.database.getOrElse(sparkSession.catalog.currentDatabase)
- val cmd = CreateHiveTableAsSelectCommand(
- tableDesc.copy(identifier = tableDesc.identifier.copy(database = Some(dbName))),
- query,
- mode == SaveMode.Ignore)
- ExecutedCommandExec(cmd) :: Nil
-
+ ScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil
case _ => Nil
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/4494cd97/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index aaf30f4..b4b6303 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -311,10 +311,10 @@ private[hive] object HiveTableUtil {
// that calls Hive.get() which tries to access metastore, but it's not valid in runtime
// it would be fixed in next version of hive but till then, we should use this instead
def configureJobPropertiesForStorageHandler(
- tableDesc: TableDesc, jobConf: JobConf, input: Boolean) {
+ tableDesc: TableDesc, conf: Configuration, input: Boolean) {
val property = tableDesc.getProperties.getProperty(META_TABLE_STORAGE)
val storageHandler =
- org.apache.hadoop.hive.ql.metadata.HiveUtils.getStorageHandler(jobConf, property)
+ org.apache.hadoop.hive.ql.metadata.HiveUtils.getStorageHandler(conf, property)
if (storageHandler != null) {
val jobProperties = new java.util.LinkedHashMap[String, String]
if (input) {
http://git-wip-us.apache.org/repos/asf/spark/blob/4494cd97/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
index b1b8439..4e2193b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.hive
/** Support for interacting with different versions of the HiveMetastoreClient */
package object client {
- private[client] abstract class HiveVersion(
+ private[hive] abstract class HiveVersion(
val fullVersion: String,
val extraDeps: Seq[String] = Nil,
val exclusions: Seq[String] = Nil)
http://git-wip-us.apache.org/repos/asf/spark/blob/4494cd97/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala
new file mode 100644
index 0000000..cc2b60b
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.hive.ql.exec.Utilities
+import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat}
+import org.apache.hadoop.hive.serde2.Serializer
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorUtils, StructObjectInspector}
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapred.{JobConf, Reporter}
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriter, OutputWriterFactory}
+import org.apache.spark.sql.hive.{HiveInspectors, HiveTableUtil}
+import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableJobConf
+
+/**
+ * `FileFormat` for writing Hive tables.
+ *
+ * TODO: implement the read logic.
+ */
+class HiveFileFormat(fileSinkConf: FileSinkDesc) extends FileFormat with Logging {
+ override def inferSchema(
+ sparkSession: SparkSession,
+ options: Map[String, String],
+ files: Seq[FileStatus]): Option[StructType] = {
+ throw new UnsupportedOperationException(s"inferSchema is not supported for hive data source.")
+ }
+
+ override def prepareWrite(
+ sparkSession: SparkSession,
+ job: Job,
+ options: Map[String, String],
+ dataSchema: StructType): OutputWriterFactory = {
+ val conf = job.getConfiguration
+ val tableDesc = fileSinkConf.getTableInfo
+ conf.set("mapred.output.format.class", tableDesc.getOutputFileFormatClassName)
+
+ // When speculation is on and output committer class name contains "Direct", we should warn
+ // users that they may loss data if they are using a direct output committer.
+ val speculationEnabled = sparkSession.sparkContext.conf.getBoolean("spark.speculation", false)
+ val outputCommitterClass = conf.get("mapred.output.committer.class", "")
+ if (speculationEnabled && outputCommitterClass.contains("Direct")) {
+ val warningMessage =
+ s"$outputCommitterClass may be an output committer that writes data directly to " +
+ "the final location. Because speculation is enabled, this output committer may " +
+ "cause data loss (see the case in SPARK-10063). If possible, please use an output " +
+ "committer that does not have this behavior (e.g. FileOutputCommitter)."
+ logWarning(warningMessage)
+ }
+
+ // Add table properties from storage handler to hadoopConf, so any custom storage
+ // handler settings can be set to hadoopConf
+ HiveTableUtil.configureJobPropertiesForStorageHandler(tableDesc, conf, false)
+ Utilities.copyTableJobPropertiesToConf(tableDesc, conf)
+
+ // Avoid referencing the outer object.
+ val fileSinkConfSer = fileSinkConf
+ new OutputWriterFactory {
+ private val jobConf = new SerializableJobConf(new JobConf(conf))
+ @transient private lazy val outputFormat =
+ jobConf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef, Writable]]
+
+ override def getFileExtension(context: TaskAttemptContext): String = {
+ Utilities.getFileExtension(jobConf.value, fileSinkConfSer.getCompressed, outputFormat)
+ }
+
+ override def newInstance(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext): OutputWriter = {
+ new HiveOutputWriter(path, fileSinkConfSer, jobConf.value, dataSchema)
+ }
+ }
+ }
+}
+
+class HiveOutputWriter(
+ path: String,
+ fileSinkConf: FileSinkDesc,
+ jobConf: JobConf,
+ dataSchema: StructType) extends OutputWriter with HiveInspectors {
+
+ private def tableDesc = fileSinkConf.getTableInfo
+
+ private val serializer = {
+ val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer]
+ serializer.initialize(null, tableDesc.getProperties)
+ serializer
+ }
+
+ private val hiveWriter = HiveFileFormatUtils.getHiveRecordWriter(
+ jobConf,
+ tableDesc,
+ serializer.getSerializedClass,
+ fileSinkConf,
+ new Path(path),
+ Reporter.NULL)
+
+ private val standardOI = ObjectInspectorUtils
+ .getStandardObjectInspector(
+ tableDesc.getDeserializer.getObjectInspector,
+ ObjectInspectorCopyOption.JAVA)
+ .asInstanceOf[StructObjectInspector]
+
+ private val fieldOIs =
+ standardOI.getAllStructFieldRefs.asScala.map(_.getFieldObjectInspector).toArray
+ private val dataTypes = dataSchema.map(_.dataType).toArray
+ private val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt) }
+ private val outputData = new Array[Any](fieldOIs.length)
+
+ override def write(row: InternalRow): Unit = {
+ var i = 0
+ while (i < fieldOIs.length) {
+ outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
+ i += 1
+ }
+ hiveWriter.write(serializer.serialize(outputData, standardOI))
+ }
+
+ override def close(): Unit = {
+ // Seems the boolean value passed into close does not matter.
+ hiveWriter.close(false)
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/4494cd97/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index aa858e8..ce418ae 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -24,22 +24,22 @@ import java.util.{Date, Locale, Random}
import scala.util.control.NonFatal
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hive.common.FileUtils
import org.apache.hadoop.hive.ql.exec.TaskRunner
import org.apache.hadoop.hive.ql.ErrorMsg
-import org.apache.hadoop.mapred.{FileOutputFormat, JobConf}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.internal.io.FileCommitProtocol
+import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.command.RunnableCommand
+import org.apache.spark.sql.execution.datasources.FileFormatWriter
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
+import org.apache.spark.sql.hive.client.HiveVersion
import org.apache.spark.SparkException
-import org.apache.spark.util.SerializableJobConf
/**
@@ -69,26 +69,20 @@ import org.apache.spark.util.SerializableJobConf
* {{{
* Map('a' -> Some('1'), 'b' -> None)
* }}}.
- * @param child the logical plan representing data to write to.
+ * @param query the logical plan representing data to write to.
* @param overwrite overwrite existing table or partitions.
* @param ifNotExists If true, only write if the table or partition does not exist.
*/
case class InsertIntoHiveTable(
table: MetastoreRelation,
partition: Map[String, Option[String]],
- child: SparkPlan,
+ query: LogicalPlan,
overwrite: Boolean,
- ifNotExists: Boolean) extends UnaryExecNode {
+ ifNotExists: Boolean) extends RunnableCommand {
- @transient private val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState]
- @transient private val externalCatalog = sqlContext.sharedState.externalCatalog
+ override protected def innerChildren: Seq[LogicalPlan] = query :: Nil
- def output: Seq[Attribute] = Seq.empty
-
- val hadoopConf = sessionState.newHadoopConf()
var createdTempDir: Option[Path] = None
- val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging")
- val scratchDir = hadoopConf.get("hive.exec.scratchdir", "/tmp/hive")
private def executionId: String = {
val rand: Random = new Random
@@ -96,7 +90,10 @@ case class InsertIntoHiveTable(
"hive_" + format.format(new Date) + "_" + Math.abs(rand.nextLong)
}
- private def getStagingDir(inputPath: Path): Path = {
+ private def getStagingDir(
+ inputPath: Path,
+ hadoopConf: Configuration,
+ stagingDir: String): Path = {
val inputPathUri: URI = inputPath.toUri
val inputPathName: String = inputPathUri.getPath
val fs: FileSystem = inputPath.getFileSystem(hadoopConf)
@@ -121,17 +118,27 @@ case class InsertIntoHiveTable(
throw new RuntimeException(
"Cannot create staging directory '" + dir.toString + "': " + e.getMessage, e)
}
- return dir
+ dir
}
- private def getExternalScratchDir(extURI: URI): Path = {
- getStagingDir(new Path(extURI.getScheme, extURI.getAuthority, extURI.getPath))
+ private def getExternalScratchDir(
+ extURI: URI,
+ hadoopConf: Configuration,
+ stagingDir: String): Path = {
+ getStagingDir(
+ new Path(extURI.getScheme, extURI.getAuthority, extURI.getPath),
+ hadoopConf,
+ stagingDir)
}
- def getExternalTmpPath(path: Path): Path = {
+ def getExternalTmpPath(
+ path: Path,
+ hiveVersion: HiveVersion,
+ hadoopConf: Configuration,
+ stagingDir: String,
+ scratchDir: String): Path = {
import org.apache.spark.sql.hive.client.hive._
- val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version
// Before Hive 1.1, when inserting into a table, Hive will create the staging directory under
// a common scratch directory. After the writing is finished, Hive will simply empty the table
// directory and move the staging directory to it.
@@ -142,16 +149,19 @@ case class InsertIntoHiveTable(
// staging directory under the table director for Hive prior to 1.1, the staging directory will
// be removed by Hive when Hive is trying to empty the table directory.
if (hiveVersion == v12 || hiveVersion == v13 || hiveVersion == v14 || hiveVersion == v1_0) {
- oldVersionExternalTempPath(path)
+ oldVersionExternalTempPath(path, hadoopConf, scratchDir)
} else if (hiveVersion == v1_1 || hiveVersion == v1_2) {
- newVersionExternalTempPath(path)
+ newVersionExternalTempPath(path, hadoopConf, stagingDir)
} else {
throw new IllegalStateException("Unsupported hive version: " + hiveVersion.fullVersion)
}
}
// Mostly copied from Context.java#getExternalTmpPath of Hive 0.13
- def oldVersionExternalTempPath(path: Path): Path = {
+ def oldVersionExternalTempPath(
+ path: Path,
+ hadoopConf: Configuration,
+ scratchDir: String): Path = {
val extURI: URI = path.toUri
val scratchPath = new Path(scratchDir, executionId)
var dirPath = new Path(
@@ -176,54 +186,44 @@ case class InsertIntoHiveTable(
}
// Mostly copied from Context.java#getExternalTmpPath of Hive 1.2
- def newVersionExternalTempPath(path: Path): Path = {
+ def newVersionExternalTempPath(
+ path: Path,
+ hadoopConf: Configuration,
+ stagingDir: String): Path = {
val extURI: URI = path.toUri
if (extURI.getScheme == "viewfs") {
- getExtTmpPathRelTo(path.getParent)
+ getExtTmpPathRelTo(path.getParent, hadoopConf, stagingDir)
} else {
- new Path(getExternalScratchDir(extURI), "-ext-10000")
+ new Path(getExternalScratchDir(extURI, hadoopConf, stagingDir), "-ext-10000")
}
}
- def getExtTmpPathRelTo(path: Path): Path = {
- new Path(getStagingDir(path), "-ext-10000") // Hive uses 10000
- }
-
- private def saveAsHiveFile(
- rdd: RDD[InternalRow],
- valueClass: Class[_],
- fileSinkConf: FileSinkDesc,
- conf: SerializableJobConf,
- writerContainer: SparkHiveWriterContainer): Unit = {
- assert(valueClass != null, "Output value class not set")
- conf.value.setOutputValueClass(valueClass)
-
- val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName
- assert(outputFileFormatClassName != null, "Output format class not set")
- conf.value.set("mapred.output.format.class", outputFileFormatClassName)
-
- FileOutputFormat.setOutputPath(
- conf.value,
- SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName(), conf.value))
- log.debug("Saving as hadoop file of type " + valueClass.getSimpleName)
- writerContainer.driverSideSetup()
- sqlContext.sparkContext.runJob(rdd, writerContainer.writeToFile _)
- writerContainer.commitJob()
+ def getExtTmpPathRelTo(
+ path: Path,
+ hadoopConf: Configuration,
+ stagingDir: String): Path = {
+ new Path(getStagingDir(path, hadoopConf, stagingDir), "-ext-10000") // Hive uses 10000
}
/**
* Inserts all the rows in the table into Hive. Row objects are properly serialized with the
* `org.apache.hadoop.hive.serde2.SerDe` and the
* `org.apache.hadoop.mapred.OutputFormat` provided by the table definition.
- *
- * Note: this is run once and then kept to avoid double insertions.
*/
- protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val sessionState = sparkSession.sessionState
+ val externalCatalog = sparkSession.sharedState.externalCatalog
+ val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version
+ val hadoopConf = sessionState.newHadoopConf()
+ val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging")
+ val scratchDir = hadoopConf.get("hive.exec.scratchdir", "/tmp/hive")
+
// Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer
// instances within the closure, since Serializer is not serializable while TableDesc is.
val tableDesc = table.tableDesc
val tableLocation = table.hiveQlTable.getDataLocation
- val tmpLocation = getExternalTmpPath(tableLocation)
+ val tmpLocation =
+ getExternalTmpPath(tableLocation, hiveVersion, hadoopConf, stagingDir, scratchDir)
val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false)
val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean
@@ -276,40 +276,31 @@ case class InsertIntoHiveTable(
}
}
- val jobConf = new JobConf(hadoopConf)
- val jobConfSer = new SerializableJobConf(jobConf)
-
- // When speculation is on and output committer class name contains "Direct", we should warn
- // users that they may loss data if they are using a direct output committer.
- val speculationEnabled = sqlContext.sparkContext.conf.getBoolean("spark.speculation", false)
- val outputCommitterClass = jobConf.get("mapred.output.committer.class", "")
- if (speculationEnabled && outputCommitterClass.contains("Direct")) {
- val warningMessage =
- s"$outputCommitterClass may be an output committer that writes data directly to " +
- "the final location. Because speculation is enabled, this output committer may " +
- "cause data loss (see the case in SPARK-10063). If possible, please use an output " +
- "committer that does not have this behavior (e.g. FileOutputCommitter)."
- logWarning(warningMessage)
+ val committer = FileCommitProtocol.instantiate(
+ sparkSession.sessionState.conf.fileCommitProtocolClass,
+ jobId = java.util.UUID.randomUUID().toString,
+ outputPath = tmpLocation.toString,
+ isAppend = false)
+
+ val partitionAttributes = partitionColumnNames.takeRight(numDynamicPartitions).map { name =>
+ query.resolve(name :: Nil, sparkSession.sessionState.analyzer.resolver).getOrElse {
+ throw new AnalysisException(
+ s"Unable to resolve $name given [${query.output.map(_.name).mkString(", ")}]")
+ }.asInstanceOf[Attribute]
}
- val writerContainer = if (numDynamicPartitions > 0) {
- val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions)
- new SparkHiveDynamicPartitionWriterContainer(
- jobConf,
- fileSinkConf,
- dynamicPartColNames,
- child.output)
- } else {
- new SparkHiveWriterContainer(
- jobConf,
- fileSinkConf,
- child.output)
- }
-
- @transient val outputClass = writerContainer.newSerializer(table.tableDesc).getSerializedClass
- saveAsHiveFile(child.execute(), outputClass, fileSinkConf, jobConfSer, writerContainer)
+ FileFormatWriter.write(
+ sparkSession = sparkSession,
+ queryExecution = Dataset.ofRows(sparkSession, query).queryExecution,
+ fileFormat = new HiveFileFormat(fileSinkConf),
+ committer = committer,
+ outputSpec = FileFormatWriter.OutputSpec(tmpLocation.toString, Map.empty),
+ hadoopConf = hadoopConf,
+ partitionColumns = partitionAttributes,
+ bucketSpec = None,
+ refreshFunction = _ => (),
+ options = Map.empty)
- val outputPath = FileOutputFormat.getOutputPath(jobConf)
// TODO: Correctly set holdDDLTime.
// In most of the time, we should have holdDDLTime = false.
// holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint.
@@ -319,7 +310,7 @@ case class InsertIntoHiveTable(
externalCatalog.loadDynamicPartitions(
db = table.catalogTable.database,
table = table.catalogTable.identifier.table,
- outputPath.toString,
+ tmpLocation.toString,
partitionSpec,
overwrite,
numDynamicPartitions,
@@ -363,7 +354,7 @@ case class InsertIntoHiveTable(
externalCatalog.loadPartition(
table.catalogTable.database,
table.catalogTable.identifier.table,
- outputPath.toString,
+ tmpLocation.toString,
partitionSpec,
isOverwrite = doHiveOverwrite,
holdDDLTime = holdDDLTime,
@@ -375,7 +366,7 @@ case class InsertIntoHiveTable(
externalCatalog.loadTable(
table.catalogTable.database,
table.catalogTable.identifier.table,
- outputPath.toString, // TODO: URI
+ tmpLocation.toString, // TODO: URI
overwrite,
holdDDLTime,
isSrcLocal = false)
@@ -391,21 +382,13 @@ case class InsertIntoHiveTable(
}
// Invalidate the cache.
- sqlContext.sharedState.cacheManager.invalidateCache(table)
- sqlContext.sessionState.catalog.refreshTable(table.catalogTable.identifier)
+ sparkSession.sharedState.cacheManager.invalidateCache(table)
+ sparkSession.sessionState.catalog.refreshTable(table.catalogTable.identifier)
// It would be nice to just return the childRdd unchanged so insert operations could be chained,
// however for now we return an empty list to simplify compatibility checks with hive, which
// does not return anything for insert operations.
// TODO: implement hive compatibility as rules.
- Seq.empty[InternalRow]
- }
-
- override def outputPartitioning: Partitioning = child.outputPartitioning
-
- override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray
-
- protected override def doExecute(): RDD[InternalRow] = {
- sqlContext.sparkContext.parallelize(sideEffectResult.asInstanceOf[Seq[InternalRow]], 1)
+ Seq.empty[Row]
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/4494cd97/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
deleted file mode 100644
index 50855e4..0000000
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ /dev/null
@@ -1,448 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.execution
-
-import java.io._
-import java.nio.charset.StandardCharsets
-import java.util.Properties
-import javax.annotation.Nullable
-
-import scala.collection.JavaConverters._
-import scala.util.control.NonFatal
-
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter}
-import org.apache.hadoop.hive.serde.serdeConstants
-import org.apache.hadoop.hive.serde2.AbstractSerDe
-import org.apache.hadoop.hive.serde2.objectinspector._
-import org.apache.hadoop.io.Writable
-
-import org.apache.spark.{SparkException, TaskContext}
-import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution._
-import org.apache.spark.sql.hive.HiveInspectors
-import org.apache.spark.sql.hive.HiveShim._
-import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils}
-
-/**
- * Transforms the input by forking and running the specified script.
- *
- * @param input the set of expression that should be passed to the script.
- * @param script the command that should be executed.
- * @param output the attributes that are produced by the script.
- */
-case class ScriptTransformation(
- input: Seq[Expression],
- script: String,
- output: Seq[Attribute],
- child: SparkPlan,
- ioschema: HiveScriptIOSchema)
- extends UnaryExecNode {
-
- override def producedAttributes: AttributeSet = outputSet -- inputSet
-
- override def outputPartitioning: Partitioning = child.outputPartitioning
-
- protected override def doExecute(): RDD[InternalRow] = {
- def processIterator(inputIterator: Iterator[InternalRow], hadoopConf: Configuration)
- : Iterator[InternalRow] = {
- val cmd = List("/bin/bash", "-c", script)
- val builder = new ProcessBuilder(cmd.asJava)
-
- val proc = builder.start()
- val inputStream = proc.getInputStream
- val outputStream = proc.getOutputStream
- val errorStream = proc.getErrorStream
-
- // In order to avoid deadlocks, we need to consume the error output of the child process.
- // To avoid issues caused by large error output, we use a circular buffer to limit the amount
- // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang
- // that motivates this.
- val stderrBuffer = new CircularBuffer(2048)
- new RedirectThread(
- errorStream,
- stderrBuffer,
- "Thread-ScriptTransformation-STDERR-Consumer").start()
-
- val outputProjection = new InterpretedProjection(input, child.output)
-
- // This nullability is a performance optimization in order to avoid an Option.foreach() call
- // inside of a loop
- @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null))
-
- // This new thread will consume the ScriptTransformation's input rows and write them to the
- // external process. That process's output will be read by this current thread.
- val writerThread = new ScriptTransformationWriterThread(
- inputIterator,
- input.map(_.dataType),
- outputProjection,
- inputSerde,
- inputSoi,
- ioschema,
- outputStream,
- proc,
- stderrBuffer,
- TaskContext.get(),
- hadoopConf
- )
-
- // This nullability is a performance optimization in order to avoid an Option.foreach() call
- // inside of a loop
- @Nullable val (outputSerde, outputSoi) = {
- ioschema.initOutputSerDe(output).getOrElse((null, null))
- }
-
- val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))
- val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
- var curLine: String = null
- val scriptOutputStream = new DataInputStream(inputStream)
-
- @Nullable val scriptOutputReader =
- ioschema.recordReader(scriptOutputStream, hadoopConf).orNull
-
- var scriptOutputWritable: Writable = null
- val reusedWritableObject: Writable = if (null != outputSerde) {
- outputSerde.getSerializedClass().newInstance
- } else {
- null
- }
- val mutableRow = new SpecificInternalRow(output.map(_.dataType))
-
- @transient
- lazy val unwrappers = outputSoi.getAllStructFieldRefs.asScala.map(unwrapperFor)
-
- private def checkFailureAndPropagate(cause: Throwable = null): Unit = {
- if (writerThread.exception.isDefined) {
- throw writerThread.exception.get
- }
-
- // Checks if the proc is still alive (incase the command ran was bad)
- // The ideal way to do this is to use Java 8's Process#isAlive()
- // but it cannot be used because Spark still supports Java 7.
- // Following is a workaround used to check if a process is alive in Java 7
- // TODO: Once builds are switched to Java 8, this can be changed
- try {
- val exitCode = proc.exitValue()
- if (exitCode != 0) {
- logError(stderrBuffer.toString) // log the stderr circular buffer
- throw new SparkException(s"Subprocess exited with status $exitCode. " +
- s"Error: ${stderrBuffer.toString}", cause)
- }
- } catch {
- case _: IllegalThreadStateException =>
- // This means that the process is still alive. Move ahead
- }
- }
-
- override def hasNext: Boolean = {
- try {
- if (outputSerde == null) {
- if (curLine == null) {
- curLine = reader.readLine()
- if (curLine == null) {
- checkFailureAndPropagate()
- return false
- }
- }
- } else if (scriptOutputWritable == null) {
- scriptOutputWritable = reusedWritableObject
-
- if (scriptOutputReader != null) {
- if (scriptOutputReader.next(scriptOutputWritable) <= 0) {
- checkFailureAndPropagate()
- return false
- }
- } else {
- try {
- scriptOutputWritable.readFields(scriptOutputStream)
- } catch {
- case _: EOFException =>
- // This means that the stdout of `proc` (ie. TRANSFORM process) has exhausted.
- // Ideally the proc should *not* be alive at this point but
- // there can be a lag between EOF being written out and the process
- // being terminated. So explicitly waiting for the process to be done.
- proc.waitFor()
- checkFailureAndPropagate()
- return false
- }
- }
- }
-
- true
- } catch {
- case NonFatal(e) =>
- // If this exception is due to abrupt / unclean termination of `proc`,
- // then detect it and propagate a better exception message for end users
- checkFailureAndPropagate(e)
-
- throw e
- }
- }
-
- override def next(): InternalRow = {
- if (!hasNext) {
- throw new NoSuchElementException
- }
- if (outputSerde == null) {
- val prevLine = curLine
- curLine = reader.readLine()
- if (!ioschema.schemaLess) {
- new GenericInternalRow(
- prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
- .map(CatalystTypeConverters.convertToCatalyst))
- } else {
- new GenericInternalRow(
- prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)
- .map(CatalystTypeConverters.convertToCatalyst))
- }
- } else {
- val raw = outputSerde.deserialize(scriptOutputWritable)
- scriptOutputWritable = null
- val dataList = outputSoi.getStructFieldsDataAsList(raw)
- var i = 0
- while (i < dataList.size()) {
- if (dataList.get(i) == null) {
- mutableRow.setNullAt(i)
- } else {
- unwrappers(i)(dataList.get(i), mutableRow, i)
- }
- i += 1
- }
- mutableRow
- }
- }
- }
-
- writerThread.start()
-
- outputIterator
- }
-
- val broadcastedHadoopConf =
- new SerializableConfiguration(sqlContext.sessionState.newHadoopConf())
-
- child.execute().mapPartitions { iter =>
- if (iter.hasNext) {
- val proj = UnsafeProjection.create(schema)
- processIterator(iter, broadcastedHadoopConf.value).map(proj)
- } else {
- // If the input iterator has no rows then do not launch the external script.
- Iterator.empty
- }
- }
- }
-}
-
-private class ScriptTransformationWriterThread(
- iter: Iterator[InternalRow],
- inputSchema: Seq[DataType],
- outputProjection: Projection,
- @Nullable inputSerde: AbstractSerDe,
- @Nullable inputSoi: ObjectInspector,
- ioschema: HiveScriptIOSchema,
- outputStream: OutputStream,
- proc: Process,
- stderrBuffer: CircularBuffer,
- taskContext: TaskContext,
- conf: Configuration
- ) extends Thread("Thread-ScriptTransformation-Feed") with Logging {
-
- setDaemon(true)
-
- @volatile private var _exception: Throwable = null
-
- /** Contains the exception thrown while writing the parent iterator to the external process. */
- def exception: Option[Throwable] = Option(_exception)
-
- override def run(): Unit = Utils.logUncaughtExceptions {
- TaskContext.setTaskContext(taskContext)
-
- val dataOutputStream = new DataOutputStream(outputStream)
- @Nullable val scriptInputWriter = ioschema.recordWriter(dataOutputStream, conf).orNull
-
- // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so
- // let's use a variable to record whether the `finally` block was hit due to an exception
- var threwException: Boolean = true
- val len = inputSchema.length
- try {
- iter.map(outputProjection).foreach { row =>
- if (inputSerde == null) {
- val data = if (len == 0) {
- ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")
- } else {
- val sb = new StringBuilder
- sb.append(row.get(0, inputSchema(0)))
- var i = 1
- while (i < len) {
- sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
- sb.append(row.get(i, inputSchema(i)))
- i += 1
- }
- sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES"))
- sb.toString()
- }
- outputStream.write(data.getBytes(StandardCharsets.UTF_8))
- } else {
- val writable = inputSerde.serialize(
- row.asInstanceOf[GenericInternalRow].values, inputSoi)
-
- if (scriptInputWriter != null) {
- scriptInputWriter.write(writable)
- } else {
- prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream)
- }
- }
- }
- threwException = false
- } catch {
- case t: Throwable =>
- // An error occurred while writing input, so kill the child process. According to the
- // Javadoc this call will not throw an exception:
- _exception = t
- proc.destroy()
- throw t
- } finally {
- try {
- Utils.tryLogNonFatalError(outputStream.close())
- if (proc.waitFor() != 0) {
- logError(stderrBuffer.toString) // log the stderr circular buffer
- }
- } catch {
- case NonFatal(exceptionFromFinallyBlock) =>
- if (!threwException) {
- throw exceptionFromFinallyBlock
- } else {
- log.error("Exception in finally block", exceptionFromFinallyBlock)
- }
- }
- }
- }
-}
-
-object HiveScriptIOSchema {
- def apply(input: ScriptInputOutputSchema): HiveScriptIOSchema = {
- HiveScriptIOSchema(
- input.inputRowFormat,
- input.outputRowFormat,
- input.inputSerdeClass,
- input.outputSerdeClass,
- input.inputSerdeProps,
- input.outputSerdeProps,
- input.recordReaderClass,
- input.recordWriterClass,
- input.schemaLess)
- }
-}
-
-/**
- * The wrapper class of Hive input and output schema properties
- */
-case class HiveScriptIOSchema (
- inputRowFormat: Seq[(String, String)],
- outputRowFormat: Seq[(String, String)],
- inputSerdeClass: Option[String],
- outputSerdeClass: Option[String],
- inputSerdeProps: Seq[(String, String)],
- outputSerdeProps: Seq[(String, String)],
- recordReaderClass: Option[String],
- recordWriterClass: Option[String],
- schemaLess: Boolean)
- extends HiveInspectors {
-
- private val defaultFormat = Map(
- ("TOK_TABLEROWFORMATFIELD", "\t"),
- ("TOK_TABLEROWFORMATLINES", "\n")
- )
-
- val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
- val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
-
-
- def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = {
- inputSerdeClass.map { serdeClass =>
- val (columns, columnTypes) = parseAttrs(input)
- val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps)
- val fieldObjectInspectors = columnTypes.map(toInspector)
- val objectInspector = ObjectInspectorFactory
- .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava)
- .asInstanceOf[ObjectInspector]
- (serde, objectInspector)
- }
- }
-
- def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = {
- outputSerdeClass.map { serdeClass =>
- val (columns, columnTypes) = parseAttrs(output)
- val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps)
- val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector]
- (serde, structObjectInspector)
- }
- }
-
- private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = {
- val columns = attrs.zipWithIndex.map(e => s"${e._1.prettyName}_${e._2}")
- val columnTypes = attrs.map(_.dataType)
- (columns, columnTypes)
- }
-
- private def initSerDe(
- serdeClassName: String,
- columns: Seq[String],
- columnTypes: Seq[DataType],
- serdeProps: Seq[(String, String)]): AbstractSerDe = {
-
- val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe]
-
- val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",")
-
- var propsMap = serdeProps.toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(","))
- propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)
-
- val properties = new Properties()
- properties.putAll(propsMap.asJava)
- serde.initialize(null, properties)
-
- serde
- }
-
- def recordReader(
- inputStream: InputStream,
- conf: Configuration): Option[RecordReader] = {
- recordReaderClass.map { klass =>
- val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader]
- val props = new Properties()
- props.putAll(outputSerdeProps.toMap.asJava)
- instance.initialize(inputStream, conf, props)
- instance
- }
- }
-
- def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = {
- recordWriterClass.map { klass =>
- val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter]
- instance.initialize(outputStream, conf)
- instance
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/4494cd97/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala
new file mode 100644
index 0000000..e7c165c
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala
@@ -0,0 +1,448 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import java.io._
+import java.nio.charset.StandardCharsets
+import java.util.Properties
+import javax.annotation.Nullable
+
+import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter}
+import org.apache.hadoop.hive.serde.serdeConstants
+import org.apache.hadoop.hive.serde2.AbstractSerDe
+import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.hadoop.io.Writable
+
+import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.hive.HiveInspectors
+import org.apache.spark.sql.hive.HiveShim._
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils}
+
+/**
+ * Transforms the input by forking and running the specified script.
+ *
+ * @param input the set of expression that should be passed to the script.
+ * @param script the command that should be executed.
+ * @param output the attributes that are produced by the script.
+ */
+case class ScriptTransformationExec(
+ input: Seq[Expression],
+ script: String,
+ output: Seq[Attribute],
+ child: SparkPlan,
+ ioschema: HiveScriptIOSchema)
+ extends UnaryExecNode {
+
+ override def producedAttributes: AttributeSet = outputSet -- inputSet
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ def processIterator(inputIterator: Iterator[InternalRow], hadoopConf: Configuration)
+ : Iterator[InternalRow] = {
+ val cmd = List("/bin/bash", "-c", script)
+ val builder = new ProcessBuilder(cmd.asJava)
+
+ val proc = builder.start()
+ val inputStream = proc.getInputStream
+ val outputStream = proc.getOutputStream
+ val errorStream = proc.getErrorStream
+
+ // In order to avoid deadlocks, we need to consume the error output of the child process.
+ // To avoid issues caused by large error output, we use a circular buffer to limit the amount
+ // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang
+ // that motivates this.
+ val stderrBuffer = new CircularBuffer(2048)
+ new RedirectThread(
+ errorStream,
+ stderrBuffer,
+ "Thread-ScriptTransformation-STDERR-Consumer").start()
+
+ val outputProjection = new InterpretedProjection(input, child.output)
+
+ // This nullability is a performance optimization in order to avoid an Option.foreach() call
+ // inside of a loop
+ @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null))
+
+ // This new thread will consume the ScriptTransformation's input rows and write them to the
+ // external process. That process's output will be read by this current thread.
+ val writerThread = new ScriptTransformationWriterThread(
+ inputIterator,
+ input.map(_.dataType),
+ outputProjection,
+ inputSerde,
+ inputSoi,
+ ioschema,
+ outputStream,
+ proc,
+ stderrBuffer,
+ TaskContext.get(),
+ hadoopConf
+ )
+
+ // This nullability is a performance optimization in order to avoid an Option.foreach() call
+ // inside of a loop
+ @Nullable val (outputSerde, outputSoi) = {
+ ioschema.initOutputSerDe(output).getOrElse((null, null))
+ }
+
+ val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))
+ val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
+ var curLine: String = null
+ val scriptOutputStream = new DataInputStream(inputStream)
+
+ @Nullable val scriptOutputReader =
+ ioschema.recordReader(scriptOutputStream, hadoopConf).orNull
+
+ var scriptOutputWritable: Writable = null
+ val reusedWritableObject: Writable = if (null != outputSerde) {
+ outputSerde.getSerializedClass().newInstance
+ } else {
+ null
+ }
+ val mutableRow = new SpecificInternalRow(output.map(_.dataType))
+
+ @transient
+ lazy val unwrappers = outputSoi.getAllStructFieldRefs.asScala.map(unwrapperFor)
+
+ private def checkFailureAndPropagate(cause: Throwable = null): Unit = {
+ if (writerThread.exception.isDefined) {
+ throw writerThread.exception.get
+ }
+
+ // Checks if the proc is still alive (incase the command ran was bad)
+ // The ideal way to do this is to use Java 8's Process#isAlive()
+ // but it cannot be used because Spark still supports Java 7.
+ // Following is a workaround used to check if a process is alive in Java 7
+ // TODO: Once builds are switched to Java 8, this can be changed
+ try {
+ val exitCode = proc.exitValue()
+ if (exitCode != 0) {
+ logError(stderrBuffer.toString) // log the stderr circular buffer
+ throw new SparkException(s"Subprocess exited with status $exitCode. " +
+ s"Error: ${stderrBuffer.toString}", cause)
+ }
+ } catch {
+ case _: IllegalThreadStateException =>
+ // This means that the process is still alive. Move ahead
+ }
+ }
+
+ override def hasNext: Boolean = {
+ try {
+ if (outputSerde == null) {
+ if (curLine == null) {
+ curLine = reader.readLine()
+ if (curLine == null) {
+ checkFailureAndPropagate()
+ return false
+ }
+ }
+ } else if (scriptOutputWritable == null) {
+ scriptOutputWritable = reusedWritableObject
+
+ if (scriptOutputReader != null) {
+ if (scriptOutputReader.next(scriptOutputWritable) <= 0) {
+ checkFailureAndPropagate()
+ return false
+ }
+ } else {
+ try {
+ scriptOutputWritable.readFields(scriptOutputStream)
+ } catch {
+ case _: EOFException =>
+ // This means that the stdout of `proc` (i.e. the TRANSFORM process) has been exhausted.
+ // Ideally the proc should *not* be alive at this point, but
+ // there can be a lag between the EOF being written out and the process
+ // being terminated, so we explicitly wait for the process to finish.
+ proc.waitFor()
+ checkFailureAndPropagate()
+ return false
+ }
+ }
+ }
+
+ true
+ } catch {
+ case NonFatal(e) =>
+ // If this exception is due to abrupt / unclean termination of `proc`,
+ // then detect it and propagate a better exception message for end users
+ checkFailureAndPropagate(e)
+
+ throw e
+ }
+ }
+
+ override def next(): InternalRow = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+ if (outputSerde == null) {
+ val prevLine = curLine
+ curLine = reader.readLine()
+ if (!ioschema.schemaLess) {
+ new GenericInternalRow(
+ prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
+ .map(CatalystTypeConverters.convertToCatalyst))
+ } else {
+ new GenericInternalRow(
+ prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)
+ .map(CatalystTypeConverters.convertToCatalyst))
+ }
+ } else {
+ val raw = outputSerde.deserialize(scriptOutputWritable)
+ scriptOutputWritable = null
+ val dataList = outputSoi.getStructFieldsDataAsList(raw)
+ var i = 0
+ while (i < dataList.size()) {
+ if (dataList.get(i) == null) {
+ mutableRow.setNullAt(i)
+ } else {
+ unwrappers(i)(dataList.get(i), mutableRow, i)
+ }
+ i += 1
+ }
+ mutableRow
+ }
+ }
+ }
+
+ writerThread.start()
+
+ outputIterator
+ }
+
+ val broadcastedHadoopConf =
+ new SerializableConfiguration(sqlContext.sessionState.newHadoopConf())
+
+ child.execute().mapPartitions { iter =>
+ if (iter.hasNext) {
+ val proj = UnsafeProjection.create(schema)
+ processIterator(iter, broadcastedHadoopConf.value).map(proj)
+ } else {
+ // If the input iterator has no rows then do not launch the external script.
+ Iterator.empty
+ }
+ }
+ }
+}
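
For context, this operator backs HiveQL's TRANSFORM/USING syntax. Below is a minimal usage sketch (not part of this commit) that would exercise the code path above, assuming a Hive-enabled SparkSession, that `/bin/cat` is available on the executors, and illustrative names (`TransformExample`, the temp view `src`). With no serde configured, output columns come back as tab-split strings.

    import org.apache.spark.sql.SparkSession

    object TransformExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("script-transformation-example")
          .enableHiveSupport()   // TRANSFORM requires Hive support in this version
          .getOrCreate()
        import spark.implicits._

        Seq((1, "a"), (2, "b")).toDF("key", "value").createOrReplaceTempView("src")

        // Each input row is projected, delimited and fed to the script's stdin by the
        // writer thread; the script's stdout is split back into rows by the output iterator.
        spark.sql(
          """SELECT TRANSFORM (key, value)
            |USING 'cat'
            |AS (k, v)
            |FROM src""".stripMargin).show()
      }
    }
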
+
+private class ScriptTransformationWriterThread(
+ iter: Iterator[InternalRow],
+ inputSchema: Seq[DataType],
+ outputProjection: Projection,
+ @Nullable inputSerde: AbstractSerDe,
+ @Nullable inputSoi: ObjectInspector,
+ ioschema: HiveScriptIOSchema,
+ outputStream: OutputStream,
+ proc: Process,
+ stderrBuffer: CircularBuffer,
+ taskContext: TaskContext,
+ conf: Configuration
+ ) extends Thread("Thread-ScriptTransformation-Feed") with Logging {
+
+ setDaemon(true)
+
+ @volatile private var _exception: Throwable = null
+
+ /** Contains the exception thrown while writing the parent iterator to the external process. */
+ def exception: Option[Throwable] = Option(_exception)
+
+ override def run(): Unit = Utils.logUncaughtExceptions {
+ TaskContext.setTaskContext(taskContext)
+
+ val dataOutputStream = new DataOutputStream(outputStream)
+ @Nullable val scriptInputWriter = ioschema.recordWriter(dataOutputStream, conf).orNull
+
+ // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so
+ // let's use a variable to record whether the `finally` block was hit due to an exception
+ var threwException: Boolean = true
+ val len = inputSchema.length
+ try {
+ iter.map(outputProjection).foreach { row =>
+ if (inputSerde == null) {
+ val data = if (len == 0) {
+ ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")
+ } else {
+ val sb = new StringBuilder
+ sb.append(row.get(0, inputSchema(0)))
+ var i = 1
+ while (i < len) {
+ sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
+ sb.append(row.get(i, inputSchema(i)))
+ i += 1
+ }
+ sb.append(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES"))
+ sb.toString()
+ }
+ outputStream.write(data.getBytes(StandardCharsets.UTF_8))
+ } else {
+ val writable = inputSerde.serialize(
+ row.asInstanceOf[GenericInternalRow].values, inputSoi)
+
+ if (scriptInputWriter != null) {
+ scriptInputWriter.write(writable)
+ } else {
+ prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream)
+ }
+ }
+ }
+ threwException = false
+ } catch {
+ case t: Throwable =>
+ // An error occurred while writing input, so kill the child process. According to the
+ // Javadoc this call will not throw an exception:
+ _exception = t
+ proc.destroy()
+ throw t
+ } finally {
+ try {
+ Utils.tryLogNonFatalError(outputStream.close())
+ if (proc.waitFor() != 0) {
+ logError(stderrBuffer.toString) // log the stderr circular buffer
+ }
+ } catch {
+ case NonFatal(exceptionFromFinallyBlock) =>
+ if (!threwException) {
+ throw exceptionFromFinallyBlock
+ } else {
+ log.error("Exception in finally block", exceptionFromFinallyBlock)
+ }
+ }
+ }
+ }
+}
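
The writer thread, together with the stderr RedirectThread and CircularBuffer above, implements the standard pattern for driving an external process without deadlocking: stdin is fed from a dedicated thread, stdout is consumed by the calling thread, and stderr is drained concurrently into a bounded buffer. The following is a self-contained sketch of that pattern using only JDK and standard Scala APIs; the object and thread names are illustrative and this is not Spark's implementation.

    import java.io.{BufferedReader, InputStreamReader}
    import java.nio.charset.StandardCharsets

    import scala.collection.JavaConverters._

    object ProcessPipeSketch {
      def main(args: Array[String]): Unit = {
        val proc = new ProcessBuilder(List("/bin/bash", "-c", "cat").asJava).start()

        // Drain stderr on a daemon thread so the child can never block on a full stderr
        // pipe (the same reason the operator above redirects stderr to a circular buffer).
        val stderrTail = new StringBuilder
        val stderrDrainer = new Thread("stderr-drainer") {
          setDaemon(true)
          override def run(): Unit = {
            val r = new BufferedReader(
              new InputStreamReader(proc.getErrorStream, StandardCharsets.UTF_8))
            Iterator.continually(r.readLine()).takeWhile(_ != null).foreach { line =>
              stderrTail.append(line).append('\n')
            }
          }
        }
        stderrDrainer.start()

        // Feed stdin from its own thread, mirroring ScriptTransformationWriterThread ...
        val writer = new Thread("stdin-feeder") {
          override def run(): Unit = {
            val out = proc.getOutputStream
            out.write("hello\nworld\n".getBytes(StandardCharsets.UTF_8))
            out.close()
          }
        }
        writer.start()

        // ... while the current thread consumes stdout, mirroring the output iterator above.
        val reader = new BufferedReader(
          new InputStreamReader(proc.getInputStream, StandardCharsets.UTF_8))
        Iterator.continually(reader.readLine()).takeWhile(_ != null).foreach(println)

        val exitCode = proc.waitFor()
        if (exitCode != 0) sys.error(s"Child exited with $exitCode. Stderr: $stderrTail")
      }
    }
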
+
+object HiveScriptIOSchema {
+ def apply(input: ScriptInputOutputSchema): HiveScriptIOSchema = {
+ HiveScriptIOSchema(
+ input.inputRowFormat,
+ input.outputRowFormat,
+ input.inputSerdeClass,
+ input.outputSerdeClass,
+ input.inputSerdeProps,
+ input.outputSerdeProps,
+ input.recordReaderClass,
+ input.recordWriterClass,
+ input.schemaLess)
+ }
+}
+
+/**
+ * A wrapper class for Hive script input and output schema properties.
+ */
+case class HiveScriptIOSchema (
+ inputRowFormat: Seq[(String, String)],
+ outputRowFormat: Seq[(String, String)],
+ inputSerdeClass: Option[String],
+ outputSerdeClass: Option[String],
+ inputSerdeProps: Seq[(String, String)],
+ outputSerdeProps: Seq[(String, String)],
+ recordReaderClass: Option[String],
+ recordWriterClass: Option[String],
+ schemaLess: Boolean)
+ extends HiveInspectors {
+
+ private val defaultFormat = Map(
+ ("TOK_TABLEROWFORMATFIELD", "\t"),
+ ("TOK_TABLEROWFORMATLINES", "\n")
+ )
+
+ val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
+ val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
+
+
+ def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = {
+ inputSerdeClass.map { serdeClass =>
+ val (columns, columnTypes) = parseAttrs(input)
+ val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps)
+ val fieldObjectInspectors = columnTypes.map(toInspector)
+ val objectInspector = ObjectInspectorFactory
+ .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava)
+ .asInstanceOf[ObjectInspector]
+ (serde, objectInspector)
+ }
+ }
+
+ def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = {
+ outputSerdeClass.map { serdeClass =>
+ val (columns, columnTypes) = parseAttrs(output)
+ val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps)
+ val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector]
+ (serde, structObjectInspector)
+ }
+ }
+
+ private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = {
+ val columns = attrs.zipWithIndex.map(e => s"${e._1.prettyName}_${e._2}")
+ val columnTypes = attrs.map(_.dataType)
+ (columns, columnTypes)
+ }
+
+ private def initSerDe(
+ serdeClassName: String,
+ columns: Seq[String],
+ columnTypes: Seq[DataType],
+ serdeProps: Seq[(String, String)]): AbstractSerDe = {
+
+ val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe]
+
+ val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",")
+
+ var propsMap = serdeProps.toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(","))
+ propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)
+
+ val properties = new Properties()
+ properties.putAll(propsMap.asJava)
+ serde.initialize(null, properties)
+
+ serde
+ }
+
+ def recordReader(
+ inputStream: InputStream,
+ conf: Configuration): Option[RecordReader] = {
+ recordReaderClass.map { klass =>
+ val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader]
+ val props = new Properties()
+ props.putAll(outputSerdeProps.toMap.asJava)
+ instance.initialize(inputStream, conf, props)
+ instance
+ }
+ }
+
+ def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = {
+ recordWriterClass.map { klass =>
+ val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter]
+ instance.initialize(outputStream, conf)
+ instance
+ }
+ }
+}
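
The inputRowFormatMap/outputRowFormatMap defaulting above means that any delimiter the query does not specify falls back to Hive's defaults: tab-separated fields and newline-terminated rows. A standalone sketch of that behaviour (plain Scala, not Spark code; the "|" override stands in for a query using a custom field delimiter, e.g. ROW FORMAT DELIMITED FIELDS TERMINATED BY '|'):

    val defaultFormat = Map(
      "TOK_TABLEROWFORMATFIELD" -> "\t",
      "TOK_TABLEROWFORMATLINES" -> "\n")

    // A query that only overrides the field delimiter.
    val userRowFormat = Seq("TOK_TABLEROWFORMATFIELD" -> "|")

    // A Map is a K => V function, so it can serve directly as the default.
    val rowFormatMap = userRowFormat.toMap.withDefault(defaultFormat)

    assert(rowFormatMap("TOK_TABLEROWFORMATFIELD") == "|")   // user-specified
    assert(rowFormatMap("TOK_TABLEROWFORMATLINES") == "\n")  // falls back to the default
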
http://git-wip-us.apache.org/repos/asf/spark/blob/4494cd97/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
deleted file mode 100644
index 0c93210..0000000
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ /dev/null
@@ -1,356 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import java.text.NumberFormat
-import java.util.{Date, Locale}
-
-import scala.collection.JavaConverters._
-
-import org.apache.hadoop.fs.Path
-import org.apache.hadoop.hive.common.FileUtils
-import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities}
-import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat}
-import org.apache.hadoop.hive.ql.plan.TableDesc
-import org.apache.hadoop.hive.serde2.Serializer
-import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorUtils, StructObjectInspector}
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
-import org.apache.hadoop.io.Writable
-import org.apache.hadoop.mapred._
-import org.apache.hadoop.mapreduce.TaskType
-
-import org.apache.spark._
-import org.apache.spark.internal.Logging
-import org.apache.spark.internal.io.SparkHadoopWriterUtils
-import org.apache.spark.mapred.SparkHadoopMapRedUtil
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.UnsafeKVExternalSorter
-import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
-import org.apache.spark.sql.types._
-import org.apache.spark.util.SerializableJobConf
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
-
-/**
- * Internal helper class that saves an RDD using a Hive OutputFormat.
- * It is based on `SparkHadoopWriter`.
- */
-private[hive] class SparkHiveWriterContainer(
- @transient private val jobConf: JobConf,
- fileSinkConf: FileSinkDesc,
- inputSchema: Seq[Attribute])
- extends Logging
- with HiveInspectors
- with Serializable {
-
- private val now = new Date()
- private val tableDesc: TableDesc = fileSinkConf.getTableInfo
- // Add table properties from storage handler to jobConf, so any custom storage
- // handler settings can be set to jobConf
- if (tableDesc != null) {
- HiveTableUtil.configureJobPropertiesForStorageHandler(tableDesc, jobConf, false)
- Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf)
- }
- protected val conf = new SerializableJobConf(jobConf)
-
- private var jobID = 0
- private var splitID = 0
- private var attemptID = 0
- private var jID: SerializableWritable[JobID] = null
- private var taID: SerializableWritable[TaskAttemptID] = null
-
- @transient private var writer: FileSinkOperator.RecordWriter = null
- @transient protected lazy val committer = conf.value.getOutputCommitter
- @transient protected lazy val jobContext = new JobContextImpl(conf.value, jID.value)
- @transient private lazy val taskContext = new TaskAttemptContextImpl(conf.value, taID.value)
- @transient private lazy val outputFormat =
- conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef, Writable]]
-
- def driverSideSetup() {
- setIDs(0, 0, 0)
- setConfParams()
- committer.setupJob(jobContext)
- }
-
- def executorSideSetup(jobId: Int, splitId: Int, attemptId: Int) {
- setIDs(jobId, splitId, attemptId)
- setConfParams()
- committer.setupTask(taskContext)
- initWriters()
- }
-
- protected def getOutputName: String = {
- val numberFormat = NumberFormat.getInstance(Locale.US)
- numberFormat.setMinimumIntegerDigits(5)
- numberFormat.setGroupingUsed(false)
- val extension = Utilities.getFileExtension(conf.value, fileSinkConf.getCompressed, outputFormat)
- "part-" + numberFormat.format(splitID) + extension
- }
-
- def close() {
- // Seems the boolean value passed into close does not matter.
- if (writer != null) {
- writer.close(false)
- commit()
- }
- }
-
- def commitJob() {
- committer.commitJob(jobContext)
- }
-
- protected def initWriters() {
- // NOTE this method is executed at the executor side.
- // For Hive tables without partitions or with only static partitions, only 1 writer is needed.
- writer = HiveFileFormatUtils.getHiveRecordWriter(
- conf.value,
- fileSinkConf.getTableInfo,
- conf.value.getOutputValueClass.asInstanceOf[Class[Writable]],
- fileSinkConf,
- FileOutputFormat.getTaskOutputPath(conf.value, getOutputName),
- Reporter.NULL)
- }
-
- protected def commit() {
- SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID)
- }
-
- def abortTask(): Unit = {
- if (committer != null) {
- committer.abortTask(taskContext)
- }
- logError(s"Task attempt $taskContext aborted.")
- }
-
- private def setIDs(jobId: Int, splitId: Int, attemptId: Int) {
- jobID = jobId
- splitID = splitId
- attemptID = attemptId
-
- jID = new SerializableWritable[JobID](SparkHadoopWriterUtils.createJobID(now, jobId))
- taID = new SerializableWritable[TaskAttemptID](
- new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID))
- }
-
- private def setConfParams() {
- conf.value.set("mapred.job.id", jID.value.toString)
- conf.value.set("mapred.tip.id", taID.value.getTaskID.toString)
- conf.value.set("mapred.task.id", taID.value.toString)
- conf.value.setBoolean("mapred.task.is.map", true)
- conf.value.setInt("mapred.task.partition", splitID)
- }
-
- def newSerializer(tableDesc: TableDesc): Serializer = {
- val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer]
- serializer.initialize(null, tableDesc.getProperties)
- serializer
- }
-
- protected def prepareForWrite() = {
- val serializer = newSerializer(fileSinkConf.getTableInfo)
- val standardOI = ObjectInspectorUtils
- .getStandardObjectInspector(
- fileSinkConf.getTableInfo.getDeserializer.getObjectInspector,
- ObjectInspectorCopyOption.JAVA)
- .asInstanceOf[StructObjectInspector]
-
- val fieldOIs = standardOI.getAllStructFieldRefs.asScala.map(_.getFieldObjectInspector).toArray
- val dataTypes = inputSchema.map(_.dataType)
- val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt) }
- val outputData = new Array[Any](fieldOIs.length)
- (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData)
- }
-
- // this function is executed on executor side
- def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = {
- val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite()
- executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
-
- iterator.foreach { row =>
- var i = 0
- while (i < fieldOIs.length) {
- outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
- i += 1
- }
- writer.write(serializer.serialize(outputData, standardOI))
- }
-
- close()
- }
-}
-
-private[hive] object SparkHiveWriterContainer {
- def createPathFromString(path: String, conf: JobConf): Path = {
- if (path == null) {
- throw new IllegalArgumentException("Output path is null")
- }
- val outputPath = new Path(path)
- val fs = outputPath.getFileSystem(conf)
- if (outputPath == null || fs == null) {
- throw new IllegalArgumentException("Incorrectly formatted output path")
- }
- outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
- }
-}
-
-private[spark] object SparkHiveDynamicPartitionWriterContainer {
- val SUCCESSFUL_JOB_OUTPUT_DIR_MARKER = "mapreduce.fileoutputcommitter.marksuccessfuljobs"
-}
-
-private[spark] class SparkHiveDynamicPartitionWriterContainer(
- jobConf: JobConf,
- fileSinkConf: FileSinkDesc,
- dynamicPartColNames: Array[String],
- inputSchema: Seq[Attribute])
- extends SparkHiveWriterContainer(jobConf, fileSinkConf, inputSchema) {
-
- import SparkHiveDynamicPartitionWriterContainer._
-
- private val defaultPartName = jobConf.get(
- ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultStrVal)
-
- override protected def initWriters(): Unit = {
- // do nothing
- }
-
- override def close(): Unit = {
- // do nothing
- }
-
- override def commitJob(): Unit = {
- // This is a hack to avoid writing _SUCCESS mark file. In lower versions of Hadoop (e.g. 1.0.4),
- // semantics of FileSystem.globStatus() is different from higher versions (e.g. 2.4.1) and will
- // include _SUCCESS file when glob'ing for dynamic partition data files.
- //
- // Better solution is to add a step similar to what Hive FileSinkOperator.jobCloseOp does:
- // calling something like Utilities.mvFileToFinalPath to cleanup the output directory and then
- // load it with loadDynamicPartitions/loadPartition/loadTable.
- val oldMarker = conf.value.getBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, true)
- conf.value.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, false)
- super.commitJob()
- conf.value.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker)
- }
-
- // this function is executed on executor side
- override def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = {
- val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite()
- executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
-
- val partitionOutput = inputSchema.takeRight(dynamicPartColNames.length)
- val dataOutput = inputSchema.take(fieldOIs.length)
- // Returns the partition key given an input row
- val getPartitionKey = UnsafeProjection.create(partitionOutput, inputSchema)
- // Returns the data columns to be written given an input row
- val getOutputRow = UnsafeProjection.create(dataOutput, inputSchema)
-
- val fun: AnyRef = (pathString: String) => FileUtils.escapePathName(pathString, defaultPartName)
- // Expressions that given a partition key build a string like: col1=val/col2=val/...
- val partitionStringExpression = partitionOutput.zipWithIndex.flatMap { case (c, i) =>
- val escaped =
- ScalaUDF(fun, StringType, Seq(Cast(c, StringType)), Seq(StringType))
- val str = If(IsNull(c), Literal(defaultPartName), escaped)
- val partitionName = Literal(dynamicPartColNames(i) + "=") :: str :: Nil
- if (i == 0) partitionName else Literal(Path.SEPARATOR_CHAR.toString) :: partitionName
- }
-
- // Returns the partition path given a partition key.
- val getPartitionString =
- UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionOutput)
-
- // If anything below fails, we should abort the task.
- try {
- val sorter: UnsafeKVExternalSorter = new UnsafeKVExternalSorter(
- StructType.fromAttributes(partitionOutput),
- StructType.fromAttributes(dataOutput),
- SparkEnv.get.blockManager,
- SparkEnv.get.serializerManager,
- TaskContext.get().taskMemoryManager().pageSizeBytes,
- SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
- UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
-
- while (iterator.hasNext) {
- val inputRow = iterator.next()
- val currentKey = getPartitionKey(inputRow)
- sorter.insertKV(currentKey, getOutputRow(inputRow))
- }
-
- logInfo(s"Sorting complete. Writing out partition files one at a time.")
- val sortedIterator = sorter.sortedIterator()
- var currentKey: InternalRow = null
- var currentWriter: FileSinkOperator.RecordWriter = null
- try {
- while (sortedIterator.next()) {
- if (currentKey != sortedIterator.getKey) {
- if (currentWriter != null) {
- currentWriter.close(false)
- }
- currentKey = sortedIterator.getKey.copy()
- logDebug(s"Writing partition: $currentKey")
- currentWriter = newOutputWriter(currentKey)
- }
-
- var i = 0
- while (i < fieldOIs.length) {
- outputData(i) = if (sortedIterator.getValue.isNullAt(i)) {
- null
- } else {
- wrappers(i)(sortedIterator.getValue.get(i, dataTypes(i)))
- }
- i += 1
- }
- currentWriter.write(serializer.serialize(outputData, standardOI))
- }
- } finally {
- if (currentWriter != null) {
- currentWriter.close(false)
- }
- }
- commit()
- } catch {
- case cause: Throwable =>
- logError("Aborting task.", cause)
- abortTask()
- throw new SparkException("Task failed while writing rows.", cause)
- }
- /** Open and returns a new OutputWriter given a partition key. */
- def newOutputWriter(key: InternalRow): FileSinkOperator.RecordWriter = {
- val partitionPath = getPartitionString(key).getString(0)
- val newFileSinkDesc = new FileSinkDesc(
- fileSinkConf.getDirName + partitionPath,
- fileSinkConf.getTableInfo,
- fileSinkConf.getCompressed)
- newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec)
- newFileSinkDesc.setCompressType(fileSinkConf.getCompressType)
-
- // use the path like ${hive_tmp}/_temporary/${attemptId}/
- // to avoid write to the same file when `spark.speculation=true`
- val path = FileOutputFormat.getTaskOutputPath(
- conf.value,
- partitionPath.stripPrefix("/") + "/" + getOutputName)
-
- HiveFileFormatUtils.getHiveRecordWriter(
- conf.value,
- fileSinkConf.getTableInfo,
- conf.value.getOutputValueClass.asInstanceOf[Class[Writable]],
- newFileSinkDesc,
- path,
- Reporter.NULL)
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/4494cd97/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index 5cb8519..28b5bfd 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -565,8 +565,8 @@ class VersionsSuite extends SparkFunSuite with SQLTestUtils with TestHiveSinglet
val filePaths = dir.map(_.getName).toList
folders.flatMap(listFiles) ++: filePaths
}
- val expectedFiles = ".part-00000.crc" :: "part-00000" :: Nil
- assert(listFiles(tmpDir).sorted == expectedFiles)
+ // expect 2 files left: `.part-00000-random-uuid.crc` and `part-00000-random-uuid`
+ assert(listFiles(tmpDir).length == 2)
}
}
}