Posted to commits@spark.apache.org by hv...@apache.org on 2016/12/21 22:50:40 UTC
spark git commit: [SPARK-18775][SQL] Limit the max number of records written per file
Repository: spark
Updated Branches:
refs/heads/master 078c71c2d -> 354e93618
[SPARK-18775][SQL] Limit the max number of records written per file
## What changes were proposed in this pull request?
Currently, Spark writes a single file out per task, sometimes leading to very large files. It would be great to have an option to limit the max number of records written per file in a task, to avoid humongous files.
This patch introduces a new write config option `maxRecordsPerFile` (defaulting to the session-wide setting `spark.sql.files.maxRecordsPerFile`) that limits the max number of records written to a single file. A non-positive value means there is no limit (the same behavior as not setting the option).
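A minimal usage sketch (the output path below is hypothetical; the option and conf names come from this patch):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").getOrCreate()

// Per-write option: no output file from this write will contain
// more than 1000 records.
spark.range(0, 10000)
  .write
  .option("maxRecordsPerFile", 1000)
  .mode("overwrite")
  .parquet("/tmp/limited")  // hypothetical path

// Session-wide default, used when the write option is absent.
spark.conf.set("spark.sql.files.maxRecordsPerFile", 1000L)
```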
## How was this patch tested?
Added test cases in PartitionedWriteSuite for both dynamic partition insert and non-dynamic partition insert.
Author: Reynold Xin <rx...@databricks.com>
Closes #16204 from rxin/SPARK-18775.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/354e9361
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/354e9361
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/354e9361
Branch: refs/heads/master
Commit: 354e936187708a404c0349e3d8815a47953123ec
Parents: 078c71c
Author: Reynold Xin <rx...@databricks.com>
Authored: Wed Dec 21 23:50:35 2016 +0100
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Wed Dec 21 23:50:35 2016 +0100
----------------------------------------------------------------------
.../datasources/FileFormatWriter.scala | 109 ++++++++++++++-----
.../org/apache/spark/sql/internal/SQLConf.scala | 26 +++--
.../datasources/BucketingUtilsSuite.scala | 46 ++++++++
.../sql/sources/PartitionedWriteSuite.scala | 37 +++++++
4 files changed, 179 insertions(+), 39 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/354e9361/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index d560ad5..1eb4541 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -31,13 +31,12 @@ import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -47,6 +46,13 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/** A helper object for writing FileFormat data out to a location. */
object FileFormatWriter extends Logging {
+ /**
+ * Max number of files a single task writes out due to the record-per-file limit. In most
+ * cases the number of files written should be very small. This is just a safeguard to
+ * protect against some really bad settings, e.g. maxRecordsPerFile = 1.
+ */
+ private val MAX_FILE_COUNTER = 1000 * 1000
+
/** Describes how output files should be placed in the filesystem. */
case class OutputSpec(
outputPath: String, customPartitionLocations: Map[TablePartitionSpec, String])
@@ -61,7 +67,8 @@ object FileFormatWriter extends Logging {
val nonPartitionColumns: Seq[Attribute],
val bucketSpec: Option[BucketSpec],
val path: String,
- val customPartitionLocations: Map[TablePartitionSpec, String])
+ val customPartitionLocations: Map[TablePartitionSpec, String],
+ val maxRecordsPerFile: Long)
extends Serializable {
assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns),
@@ -116,7 +123,10 @@ object FileFormatWriter extends Logging {
nonPartitionColumns = dataColumns,
bucketSpec = bucketSpec,
path = outputSpec.outputPath,
- customPartitionLocations = outputSpec.customPartitionLocations)
+ customPartitionLocations = outputSpec.customPartitionLocations,
+ maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
+ .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
+ )
SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
// This call shouldn't be put into the `try` block below because it only initializes and
@@ -225,32 +235,49 @@ object FileFormatWriter extends Logging {
taskAttemptContext: TaskAttemptContext,
committer: FileCommitProtocol) extends ExecuteWriteTask {
- private[this] var outputWriter: OutputWriter = {
+ private[this] var currentWriter: OutputWriter = _
+
+ private def newOutputWriter(fileCounter: Int): Unit = {
+ val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext)
val tmpFilePath = committer.newTaskTempFile(
taskAttemptContext,
None,
- description.outputWriterFactory.getFileExtension(taskAttemptContext))
+ f"-c$fileCounter%03d" + ext)
- val outputWriter = description.outputWriterFactory.newInstance(
+ currentWriter = description.outputWriterFactory.newInstance(
path = tmpFilePath,
dataSchema = description.nonPartitionColumns.toStructType,
context = taskAttemptContext)
- outputWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
- outputWriter
+ currentWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
}
override def execute(iter: Iterator[InternalRow]): Set[String] = {
+ var fileCounter = 0
+ var recordsInFile: Long = 0L
+ newOutputWriter(fileCounter)
while (iter.hasNext) {
+ if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) {
+ fileCounter += 1
+ assert(fileCounter < MAX_FILE_COUNTER,
+ s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+
+ recordsInFile = 0
+ releaseResources()
+ newOutputWriter(fileCounter)
+ }
+
val internalRow = iter.next()
- outputWriter.writeInternal(internalRow)
+ currentWriter.writeInternal(internalRow)
+ recordsInFile += 1
}
+ releaseResources()
Set.empty
}
override def releaseResources(): Unit = {
- if (outputWriter != null) {
- outputWriter.close()
- outputWriter = null
+ if (currentWriter != null) {
+ currentWriter.close()
+ currentWriter = null
}
}
}
@@ -300,8 +327,15 @@ object FileFormatWriter extends Logging {
* Opens a new OutputWriter given a partition key and an optional bucket id.
* If bucket id is specified, we will append it to the end of the file name, but before the
* file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
+ *
+ * @param key values for fields consisting of partition keys for the current row
+ * @param partString a function that projects the partition values into a string
+ * @param fileCounter the number of files that have been written in the past for this specific
+ * partition. This is used to limit the max number of records written for a
+ * single file. The value should start from 0.
*/
- private def newOutputWriter(key: InternalRow, partString: UnsafeProjection): OutputWriter = {
+ private def newOutputWriter(
+ key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
val partDir =
if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))
@@ -311,7 +345,10 @@ object FileFormatWriter extends Logging {
} else {
""
}
- val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext)
+
+ // This must be in a form that matches our bucketing format. See BucketingUtils.
+ val ext = f"$bucketId.c$fileCounter%03d" +
+ description.outputWriterFactory.getFileExtension(taskAttemptContext)
val customPath = partDir match {
case Some(dir) =>
@@ -324,12 +361,12 @@ object FileFormatWriter extends Logging {
} else {
committer.newTaskTempFile(taskAttemptContext, partDir, ext)
}
- val newWriter = description.outputWriterFactory.newInstance(
+
+ currentWriter = description.outputWriterFactory.newInstance(
path = path,
dataSchema = description.nonPartitionColumns.toStructType,
context = taskAttemptContext)
- newWriter.initConverter(description.nonPartitionColumns.toStructType)
- newWriter
+ currentWriter.initConverter(description.nonPartitionColumns.toStructType)
}
override def execute(iter: Iterator[InternalRow]): Set[String] = {
@@ -349,7 +386,7 @@ object FileFormatWriter extends Logging {
description.nonPartitionColumns, description.allColumns)
// Returns the partition path given a partition key.
- val getPartitionString = UnsafeProjection.create(
+ val getPartitionStringFunc = UnsafeProjection.create(
Seq(Concat(partitionStringExpression)), description.partitionColumns)
// Sorts the data before write, so that we only need one writer at the same time.
@@ -366,7 +403,6 @@ object FileFormatWriter extends Logging {
val currentRow = iter.next()
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
}
- logInfo(s"Sorting complete. Writing out partition files one at a time.")
val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
identity
@@ -379,30 +415,43 @@ object FileFormatWriter extends Logging {
val sortedIterator = sorter.sortedIterator()
// If anything below fails, we should abort the task.
+ var recordsInFile: Long = 0L
+ var fileCounter = 0
var currentKey: UnsafeRow = null
val updatedPartitions = mutable.Set[String]()
while (sortedIterator.next()) {
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
if (currentKey != nextKey) {
- if (currentWriter != null) {
- currentWriter.close()
- currentWriter = null
- }
+ // See a new key - write to a new partition (new file).
currentKey = nextKey.copy()
logDebug(s"Writing partition: $currentKey")
- currentWriter = newOutputWriter(currentKey, getPartitionString)
- val partitionPath = getPartitionString(currentKey).getString(0)
+ recordsInFile = 0
+ fileCounter = 0
+
+ releaseResources()
+ newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+ val partitionPath = getPartitionStringFunc(currentKey).getString(0)
if (partitionPath.nonEmpty) {
updatedPartitions.add(partitionPath)
}
+ } else if (description.maxRecordsPerFile > 0 &&
+ recordsInFile >= description.maxRecordsPerFile) {
+ // Exceeded the threshold in terms of the number of records per file.
+ // Create a new file by increasing the file counter.
+ recordsInFile = 0
+ fileCounter += 1
+ assert(fileCounter < MAX_FILE_COUNTER,
+ s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+
+ releaseResources()
+ newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
}
+
currentWriter.writeInternal(sortedIterator.getValue)
+ recordsInFile += 1
}
- if (currentWriter != null) {
- currentWriter.close()
- currentWriter = null
- }
+ releaseResources()
updatedPartitions.toSet
}
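Both write paths above share the same roll-over pattern: count records, and once the limit is reached, close the current writer and open a new file with an incremented counter. A simplified, self-contained sketch of that pattern (`Writer` and `openFile` are illustrative stand-ins, not Spark APIs):

```scala
trait Writer { def write(record: String): Unit; def close(): Unit }

def writeAll(records: Iterator[String],
             maxRecordsPerFile: Long,
             openFile: Int => Writer): Unit = {
  val MAX_FILE_COUNTER = 1000 * 1000  // safeguard, as in the patch
  var fileCounter = 0
  var recordsInFile = 0L
  var current = openFile(fileCounter)
  try {
    while (records.hasNext) {
      if (maxRecordsPerFile > 0 && recordsInFile >= maxRecordsPerFile) {
        fileCounter += 1
        assert(fileCounter < MAX_FILE_COUNTER,
          s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
        recordsInFile = 0
        current.close()                 // release resources before rolling over
        current = openFile(fileCounter) // e.g. new suffix f"-c$fileCounter%03d"
      }
      current.write(records.next())
      recordsInFile += 1
    }
  } finally current.close()
}
```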
http://git-wip-us.apache.org/repos/asf/spark/blob/354e9361/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4d25f54..cce1626 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -466,6 +466,19 @@ object SQLConf {
.longConf
.createWithDefault(4 * 1024 * 1024)
+ val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles")
+ .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
+ "encountering corrupted or non-existing and contents that have been read will still be " +
+ "returned.")
+ .booleanConf
+ .createWithDefault(false)
+
+ val MAX_RECORDS_PER_FILE = SQLConfigBuilder("spark.sql.files.maxRecordsPerFile")
+ .doc("Maximum number of records to write out to a single file. " +
+ "If this value is zero or negative, there is no limit.")
+ .longConf
+ .createWithDefault(0)
+
val EXCHANGE_REUSE_ENABLED = SQLConfigBuilder("spark.sql.exchange.reuse")
.internal()
.doc("When true, the planner will try to find out duplicated exchanges and re-use them.")
@@ -629,13 +642,6 @@ object SQLConf {
.doubleConf
.createWithDefault(0.05)
- val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles")
- .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
- "encountering corrupted or non-existing and contents that have been read will still be " +
- "returned.")
- .booleanConf
- .createWithDefault(false)
-
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -700,6 +706,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES)
+ def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)
+
+ def maxRecordsPerFile: Long = getConf(MAX_RECORDS_PER_FILE)
+
def useCompression: Boolean = getConf(COMPRESS_CACHED)
def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
@@ -821,8 +831,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def warehousePath: String = new Path(getConf(StaticSQLConf.WAREHOUSE_PATH)).toString
- def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)
-
override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
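The lookup in FileFormatWriter above gives the per-write option precedence over this session conf; a short sketch of that resolution order:

```scala
// The DataFrameWriter option wins; otherwise the session default applies.
def resolveMaxRecordsPerFile(options: Map[String, String], sessionDefault: Long): Long =
  options.get("maxRecordsPerFile").map(_.toLong).getOrElse(sessionDefault)

assert(resolveMaxRecordsPerFile(Map("maxRecordsPerFile" -> "100"), 0L) == 100L)
assert(resolveMaxRecordsPerFile(Map.empty, 500L) == 500L)
```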
http://git-wip-us.apache.org/repos/asf/spark/blob/354e9361/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala
new file mode 100644
index 0000000..9d892bb
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.SparkFunSuite
+
+class BucketingUtilsSuite extends SparkFunSuite {
+
+ test("generate bucket id") {
+ assert(BucketingUtils.bucketIdToString(0) == "_00000")
+ assert(BucketingUtils.bucketIdToString(10) == "_00010")
+ assert(BucketingUtils.bucketIdToString(999999) == "_999999")
+ }
+
+ test("match bucket ids") {
+ def testCase(filename: String, expected: Option[Int]): Unit = withClue(s"name: $filename") {
+ assert(BucketingUtils.getBucketId(filename) == expected)
+ }
+
+ testCase("a_1", Some(1))
+ testCase("a_1.txt", Some(1))
+ testCase("a_9999999", Some(9999999))
+ testCase("a_9999999.txt", Some(9999999))
+ testCase("a_1.c2.txt", Some(1))
+ testCase("a_1.", Some(1))
+
+ testCase("a_1:txt", None)
+ testCase("a_1-c2.txt", None)
+ }
+
+}
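The cases above imply the matching rule: the bucket id is the run of digits after the last underscore, optionally followed by a dot-prefixed suffix such as the new `.c000` file counter. A regex sketch consistent with these tests (an illustration, not necessarily the exact BucketingUtils implementation):

```scala
// Illustrative only; consistent with the test cases above, not necessarily
// the exact BucketingUtils code.
val bucketedFileName = """.*_(\d+)(?:\..*)?$""".r

def getBucketId(fileName: String): Option[Int] = fileName match {
  case bucketedFileName(bucketId) => Some(bucketId.toInt)
  case _ => None
}

assert(getBucketId("a_1.c2.txt") == Some(1)) // dot-prefixed counter is ignored
assert(getBucketId("a_1-c2.txt").isEmpty)    // dash separator does not match
```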
http://git-wip-us.apache.org/repos/asf/spark/blob/354e9361/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
index a2decad..953604e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.sources
+import java.io.File
+
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
@@ -61,4 +63,39 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
assert(spark.read.parquet(path).schema.map(_.name) == Seq("j", "i"))
}
}
+
+ test("maxRecordsPerFile setting in non-partitioned write path") {
+ withTempDir { f =>
+ spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
+ .write.option("maxRecordsPerFile", 1).mode("overwrite").parquet(f.getAbsolutePath)
+ assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4)
+
+ spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
+ .write.option("maxRecordsPerFile", 2).mode("overwrite").parquet(f.getAbsolutePath)
+ assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2)
+
+ spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
+ .write.option("maxRecordsPerFile", -1).mode("overwrite").parquet(f.getAbsolutePath)
+ assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1)
+ }
+ }
+
+ test("maxRecordsPerFile setting in dynamic partition writes") {
+ withTempDir { f =>
+ spark.range(start = 0, end = 4, step = 1, numPartitions = 1).selectExpr("id", "id id1")
+ .write
+ .partitionBy("id")
+ .option("maxRecordsPerFile", 1)
+ .mode("overwrite")
+ .parquet(f.getAbsolutePath)
+ assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4)
+ }
+ }
+
+ /** Lists files recursively. */
+ private def recursiveList(f: File): Array[File] = {
+ require(f.isDirectory)
+ val current = f.listFiles
+ current ++ current.filter(_.isDirectory).flatMap(recursiveList)
+ }
}