You are viewing a plain text version of this content. (The link to the canonical version was not preserved in this plain-text rendering.)
Posted to commits@spark.apache.org by hv...@apache.org on 2016/12/21 22:50:40 UTC

spark git commit: [SPARK-18775][SQL] Limit the max number of records written per file

Repository: spark
Updated Branches:
  refs/heads/master 078c71c2d -> 354e93618


[SPARK-18775][SQL] Limit the max number of records written per file

## What changes were proposed in this pull request?
Currently, Spark writes a single file out per task, which sometimes leads to very large files. It would be useful to have an option to limit the maximum number of records written per file in a task, to avoid excessively large files.

This patch introduces a new write option `maxRecordsPerFile` (defaulting to the session-wide setting `spark.sql.files.maxRecordsPerFile`) that limits the maximum number of records written to a single file. A non-positive value means there is no limit (the same behavior as not setting this option). To let a single task write several files, output file names now carry a file-counter suffix (e.g. `-c000`, or `.c000` after a bucket id).

## How was this patch tested?
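Added test cases in PartitionedWriteSuite for both dynamic partition inserts and non-dynamic partition inserts, and a new BucketingUtilsSuite that checks bucket ids are still parsed correctly from file names carrying the new file-counter suffix.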

Author: Reynold Xin <rx...@databricks.com>

Closes #16204 from rxin/SPARK-18775.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/354e9361
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/354e9361
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/354e9361

Branch: refs/heads/master
Commit: 354e936187708a404c0349e3d8815a47953123ec
Parents: 078c71c
Author: Reynold Xin <rx...@databricks.com>
Authored: Wed Dec 21 23:50:35 2016 +0100
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Wed Dec 21 23:50:35 2016 +0100

----------------------------------------------------------------------
 .../datasources/FileFormatWriter.scala          | 109 ++++++++++++++-----
 .../org/apache/spark/sql/internal/SQLConf.scala |  26 +++--
 .../datasources/BucketingUtilsSuite.scala       |  46 ++++++++
 .../sql/sources/PartitionedWriteSuite.scala     |  37 +++++++
 4 files changed, 179 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/354e9361/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index d560ad5..1eb4541 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -31,13 +31,12 @@ import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
 import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -47,6 +46,13 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 /** A helper object for writing FileFormat data out to a location. */
 object FileFormatWriter extends Logging {
 
+  /**
+   * Max number of files a single task writes out due to the per-file record limit. In most
+   * cases the number of files written should be very small. This is just a safeguard against
+   * really bad settings, e.g. maxRecordsPerFile = 1.
+   */
+  private val MAX_FILE_COUNTER = 1000 * 1000
+
   /** Describes how output files should be placed in the filesystem. */
   case class OutputSpec(
     outputPath: String, customPartitionLocations: Map[TablePartitionSpec, String])
@@ -61,7 +67,8 @@ object FileFormatWriter extends Logging {
       val nonPartitionColumns: Seq[Attribute],
       val bucketSpec: Option[BucketSpec],
       val path: String,
-      val customPartitionLocations: Map[TablePartitionSpec, String])
+      val customPartitionLocations: Map[TablePartitionSpec, String],
+      val maxRecordsPerFile: Long)
     extends Serializable {
 
     assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns),
@@ -116,7 +123,10 @@ object FileFormatWriter extends Logging {
       nonPartitionColumns = dataColumns,
       bucketSpec = bucketSpec,
       path = outputSpec.outputPath,
-      customPartitionLocations = outputSpec.customPartitionLocations)
+      customPartitionLocations = outputSpec.customPartitionLocations,
+      maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
+        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
+    )
 
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
       // This call shouldn't be put into the `try` block below because it only initializes and
@@ -225,32 +235,49 @@ object FileFormatWriter extends Logging {
       taskAttemptContext: TaskAttemptContext,
       committer: FileCommitProtocol) extends ExecuteWriteTask {
 
-    private[this] var outputWriter: OutputWriter = {
+    private[this] var currentWriter: OutputWriter = _
+
+    private def newOutputWriter(fileCounter: Int): Unit = {
+      val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext)
       val tmpFilePath = committer.newTaskTempFile(
         taskAttemptContext,
         None,
-        description.outputWriterFactory.getFileExtension(taskAttemptContext))
+        f"-c$fileCounter%03d" + ext)
 
-      val outputWriter = description.outputWriterFactory.newInstance(
+      currentWriter = description.outputWriterFactory.newInstance(
         path = tmpFilePath,
         dataSchema = description.nonPartitionColumns.toStructType,
         context = taskAttemptContext)
-      outputWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
-      outputWriter
+      currentWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
+      var fileCounter = 0
+      var recordsInFile: Long = 0L
+      newOutputWriter(fileCounter)
       while (iter.hasNext) {
+        if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) {
+          fileCounter += 1
+          assert(fileCounter < MAX_FILE_COUNTER,
+            s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+
+          recordsInFile = 0
+          releaseResources()
+          newOutputWriter(fileCounter)
+        }
+
         val internalRow = iter.next()
-        outputWriter.writeInternal(internalRow)
+        currentWriter.writeInternal(internalRow)
+        recordsInFile += 1
       }
+      releaseResources()
       Set.empty
     }
 
     override def releaseResources(): Unit = {
-      if (outputWriter != null) {
-        outputWriter.close()
-        outputWriter = null
+      if (currentWriter != null) {
+        currentWriter.close()
+        currentWriter = null
       }
     }
   }
@@ -300,8 +327,15 @@ object FileFormatWriter extends Logging {
      * Open and returns a new OutputWriter given a partition key and optional bucket id.
      * If bucket id is specified, we will append it to the end of the file name, but before the
      * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
+     *
+     * @param key values of the partition key fields for the current row
+     * @param partString a function that projects the partition values into a string
+     * @param fileCounter the number of files that have been written in the past for this specific
+     *                    partition. This is used to limit the max number of records written for a
+     *                    single file. The value should start from 0.
      */
-    private def newOutputWriter(key: InternalRow, partString: UnsafeProjection): OutputWriter = {
+    private def newOutputWriter(
+        key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
       val partDir =
         if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))
 
@@ -311,7 +345,10 @@ object FileFormatWriter extends Logging {
       } else {
         ""
       }
-      val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext)
+
+      // This must be in a form that matches our bucketing format. See BucketingUtils.
+      val ext = f"$bucketId.c$fileCounter%03d" +
+        description.outputWriterFactory.getFileExtension(taskAttemptContext)
 
       val customPath = partDir match {
         case Some(dir) =>
@@ -324,12 +361,12 @@ object FileFormatWriter extends Logging {
       } else {
         committer.newTaskTempFile(taskAttemptContext, partDir, ext)
       }
-      val newWriter = description.outputWriterFactory.newInstance(
+
+      currentWriter = description.outputWriterFactory.newInstance(
         path = path,
         dataSchema = description.nonPartitionColumns.toStructType,
         context = taskAttemptContext)
-      newWriter.initConverter(description.nonPartitionColumns.toStructType)
-      newWriter
+      currentWriter.initConverter(description.nonPartitionColumns.toStructType)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
@@ -349,7 +386,7 @@ object FileFormatWriter extends Logging {
         description.nonPartitionColumns, description.allColumns)
 
       // Returns the partition path given a partition key.
-      val getPartitionString = UnsafeProjection.create(
+      val getPartitionStringFunc = UnsafeProjection.create(
         Seq(Concat(partitionStringExpression)), description.partitionColumns)
 
       // Sorts the data before write, so that we only need one writer at the same time.
@@ -366,7 +403,6 @@ object FileFormatWriter extends Logging {
         val currentRow = iter.next()
         sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
       }
-      logInfo(s"Sorting complete. Writing out partition files one at a time.")
 
       val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
         identity
@@ -379,30 +415,43 @@ object FileFormatWriter extends Logging {
       val sortedIterator = sorter.sortedIterator()
 
       // If anything below fails, we should abort the task.
+      var recordsInFile: Long = 0L
+      var fileCounter = 0
       var currentKey: UnsafeRow = null
       val updatedPartitions = mutable.Set[String]()
       while (sortedIterator.next()) {
         val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
         if (currentKey != nextKey) {
-          if (currentWriter != null) {
-            currentWriter.close()
-            currentWriter = null
-          }
+          // See a new key - write to a new partition (new file).
           currentKey = nextKey.copy()
           logDebug(s"Writing partition: $currentKey")
 
-          currentWriter = newOutputWriter(currentKey, getPartitionString)
-          val partitionPath = getPartitionString(currentKey).getString(0)
+          recordsInFile = 0
+          fileCounter = 0
+
+          releaseResources()
+          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+          val partitionPath = getPartitionStringFunc(currentKey).getString(0)
           if (partitionPath.nonEmpty) {
             updatedPartitions.add(partitionPath)
           }
+        } else if (description.maxRecordsPerFile > 0 &&
+            recordsInFile >= description.maxRecordsPerFile) {
+          // Exceeded the threshold in terms of the number of records per file.
+          // Create a new file by increasing the file counter.
+          recordsInFile = 0
+          fileCounter += 1
+          assert(fileCounter < MAX_FILE_COUNTER,
+            s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+
+          releaseResources()
+          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
         }
+
         currentWriter.writeInternal(sortedIterator.getValue)
+        recordsInFile += 1
       }
-      if (currentWriter != null) {
-        currentWriter.close()
-        currentWriter = null
-      }
+      releaseResources()
       updatedPartitions.toSet
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/354e9361/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4d25f54..cce1626 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -466,6 +466,19 @@ object SQLConf {
     .longConf
     .createWithDefault(4 * 1024 * 1024)
 
+  val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles")
+    .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
+      "encountering corrupted or non-existing and contents that have been read will still be " +
+      "returned.")
+    .booleanConf
+    .createWithDefault(false)
+
+  val MAX_RECORDS_PER_FILE = SQLConfigBuilder("spark.sql.files.maxRecordsPerFile")
+    .doc("Maximum number of records to write out to a single file. " +
+      "If this value is zero or negative, there is no limit.")
+    .longConf
+    .createWithDefault(0)
+
   val EXCHANGE_REUSE_ENABLED = SQLConfigBuilder("spark.sql.exchange.reuse")
     .internal()
     .doc("When true, the planner will try to find out duplicated exchanges and re-use them.")
@@ -629,13 +642,6 @@ object SQLConf {
       .doubleConf
       .createWithDefault(0.05)
 
-  val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles")
-    .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
-      "encountering corrupted or non-existing and contents that have been read will still be " +
-      "returned.")
-    .booleanConf
-    .createWithDefault(false)
-
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -700,6 +706,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES)
 
+  def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)
+
+  def maxRecordsPerFile: Long = getConf(MAX_RECORDS_PER_FILE)
+
   def useCompression: Boolean = getConf(COMPRESS_CACHED)
 
   def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
@@ -821,8 +831,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def warehousePath: String = new Path(getConf(StaticSQLConf.WAREHOUSE_PATH)).toString
 
-  def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)
-
   override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
 
   override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)

http://git-wip-us.apache.org/repos/asf/spark/blob/354e9361/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala
new file mode 100644
index 0000000..9d892bb
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.SparkFunSuite
+
+class BucketingUtilsSuite extends SparkFunSuite {
+
+  test("generate bucket id") {
+    assert(BucketingUtils.bucketIdToString(0) == "_00000")
+    assert(BucketingUtils.bucketIdToString(10) == "_00010")
+    assert(BucketingUtils.bucketIdToString(999999) == "_999999")
+  }
+
+  test("match bucket ids") {
+    def testCase(filename: String, expected: Option[Int]): Unit = withClue(s"name: $filename") {
+      assert(BucketingUtils.getBucketId(filename) == expected)
+    }
+
+    testCase("a_1", Some(1))
+    testCase("a_1.txt", Some(1))
+    testCase("a_9999999", Some(9999999))
+    testCase("a_9999999.txt", Some(9999999))
+    testCase("a_1.c2.txt", Some(1))
+    testCase("a_1.", Some(1))
+
+    testCase("a_1:txt", None)
+    testCase("a_1-c2.txt", None)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/354e9361/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
index a2decad..953604e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.sources
 
+import java.io.File
+
 import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -61,4 +63,39 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
       assert(spark.read.parquet(path).schema.map(_.name) == Seq("j", "i"))
     }
   }
+
+  test("maxRecordsPerFile setting in non-partitioned write path") {
+    withTempDir { f =>
+      spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
+        .write.option("maxRecordsPerFile", 1).mode("overwrite").parquet(f.getAbsolutePath)
+      assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4)
+
+      spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
+        .write.option("maxRecordsPerFile", 2).mode("overwrite").parquet(f.getAbsolutePath)
+      assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2)
+
+      spark.range(start = 0, end = 4, step = 1, numPartitions = 1)
+        .write.option("maxRecordsPerFile", -1).mode("overwrite").parquet(f.getAbsolutePath)
+      assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1)
+    }
+  }
+
+  test("maxRecordsPerFile setting in dynamic partition writes") {
+    withTempDir { f =>
+      spark.range(start = 0, end = 4, step = 1, numPartitions = 1).selectExpr("id", "id id1")
+        .write
+        .partitionBy("id")
+        .option("maxRecordsPerFile", 1)
+        .mode("overwrite")
+        .parquet(f.getAbsolutePath)
+      assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4)
+    }
+  }
+
+  /** Lists files recursively. */
+  private def recursiveList(f: File): Array[File] = {
+    require(f.isDirectory)
+    val current = f.listFiles
+    current ++ current.filter(_.isDirectory).flatMap(recursiveList)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org