Posted to commits@spark.apache.org by rx...@apache.org on 2016/12/02 05:40:52 UTC
spark git commit: [SPARK-18658][SQL] Write text records directly to a FileOutputStream
Repository: spark
Updated Branches:
refs/heads/master d3c90b74e -> c82f16c15
[SPARK-18658][SQL] Write text records directly to a FileOutputStream
## What changes were proposed in this pull request?
This replaces uses of `TextOutputFormat` with an `OutputStream`, which writes either directly to the filesystem or through a compressor (if one is configured). This avoids intermediate buffering.
The inverse of this (reading directly from a stream) is necessary for streaming large JSON records when `wholeFile` is enabled, so I wanted to keep the read and write paths symmetric.
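For illustration, here is a minimal sketch of the new write path (`DirectTextWriter` is a made-up name, not part of the patch; `CodecStreams` is the helper added below):

    import java.io.OutputStream
    import org.apache.hadoop.fs.Path
    import org.apache.hadoop.mapreduce.TaskAttemptContext
    import org.apache.spark.sql.execution.datasources.CodecStreams

    // Open the target file once; CodecStreams transparently wraps the stream
    // in a compressor when the job is configured for compressed output.
    class DirectTextWriter(path: String, context: TaskAttemptContext) {
      private val out: OutputStream =
        CodecStreams.createOutputStream(context, new Path(path))

      def writeRecord(bytes: Array[Byte]): Unit = {
        out.write(bytes)  // no intermediate Text buffer or String building
        out.write('\n')
      }

      def close(): Unit = out.close()
    }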
## How was this patch tested?
Existing unit tests.
Author: Nathan Howell <nh...@godaddy.com>
Closes #16089 from NathanHowell/SPARK-18658.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c82f16c1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c82f16c1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c82f16c1
Branch: refs/heads/master
Commit: c82f16c15e0d4bfc54fb890a667d9164a088b5c6
Parents: d3c90b7
Author: Nathan Howell <nh...@godaddy.com>
Authored: Thu Dec 1 21:40:49 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Dec 1 21:40:49 2016 -0800
----------------------------------------------------------------------
.../apache/spark/unsafe/types/UTF8String.java | 19 ++++
.../spark/unsafe/types/UTF8StringSuite.java | 109 +++++++++++++++++++
.../spark/ml/source/libsvm/LibSVMRelation.scala | 28 ++---
.../sql/catalyst/json/JacksonGenerator.scala | 4 +
.../execution/datasources/CodecStreams.scala | 74 +++++++++++++
.../execution/datasources/csv/CSVParser.scala | 19 ++--
.../execution/datasources/csv/CSVRelation.scala | 43 ++------
.../datasources/json/JsonFileFormat.scala | 31 ++----
.../datasources/text/TextFileFormat.scala | 42 ++-----
.../spark/sql/sources/SimpleTextRelation.scala | 27 +----
10 files changed, 252 insertions(+), 144 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index e09a6b7..0255f53 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -147,6 +147,25 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
buffer.position(pos + numBytes);
}
+ public void writeTo(OutputStream out) throws IOException {
+ if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) {
+ final byte[] bytes = (byte[]) base;
+
+ // the offset includes an object header... this is only needed for unsafe copies
+ final long arrayOffset = offset - BYTE_ARRAY_OFFSET;
+
+ // verify that the offset and length point somewhere inside the byte array
+ // and that the offset can safely be truncated to a 32-bit integer
+ if ((long) bytes.length < arrayOffset + numBytes) {
+ throw new ArrayIndexOutOfBoundsException();
+ }
+
+ out.write(bytes, (int) arrayOffset, numBytes);
+ } else {
+ out.write(getBytes());
+ }
+ }
+
/**
* Returns the number of bytes for a code point with the first byte as `b`
* @param b The first byte of a code point
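As a minimal usage sketch of the new method (hypothetical snippet, not from the patch; the suite below covers it more thoroughly):

    import java.io.ByteArrayOutputStream
    import org.apache.spark.unsafe.types.UTF8String

    val out = new ByteArrayOutputStream()
    // writes straight from the string's backing byte array when possible
    UTF8String.fromString("hello").writeTo(out)
    assert(out.toString("UTF-8") == "hello")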
http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 7f03686..04f6845 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -17,15 +17,22 @@
package org.apache.spark.unsafe.types;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
+import java.util.HashSet;
import com.google.common.collect.ImmutableMap;
+import org.apache.spark.unsafe.Platform;
import org.junit.Test;
import static org.junit.Assert.*;
+import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
import static org.apache.spark.unsafe.types.UTF8String.*;
public class UTF8StringSuite {
@@ -499,4 +506,106 @@ public class UTF8StringSuite {
assertEquals(fromString("123").soundex(), fromString("123"));
assertEquals(fromString("\u4e16\u754c\u5343\u4e16").soundex(), fromString("\u4e16\u754c\u5343\u4e16"));
}
+
+ @Test
+ public void writeToOutputStreamUnderflow() throws IOException {
+ // offset underflow is apparently supported?
+ final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+ final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8);
+
+ for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) {
+ UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)
+ .writeTo(outputStream);
+ final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length);
+ assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString());
+ outputStream.reset();
+ }
+ }
+
+ @Test
+ public void writeToOutputStreamSlice() throws IOException {
+ final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+ final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8);
+
+ for (int i = 0; i < test.length; ++i) {
+ for (int j = 0; j < test.length - i; ++j) {
+ UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET + i, j)
+ .writeTo(outputStream);
+
+ assertArrayEquals(Arrays.copyOfRange(test, i, i + j), outputStream.toByteArray());
+ outputStream.reset();
+ }
+ }
+ }
+
+ @Test
+ public void writeToOutputStreamOverflow() throws IOException {
+ final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+ final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8);
+
+ final HashSet<Long> offsets = new HashSet<>();
+ for (int i = 0; i < 16; ++i) {
+ // touch more points around MAX_VALUE
+ offsets.add((long) Integer.MAX_VALUE - i);
+ // subtract off BYTE_ARRAY_OFFSET to avoid wrapping around to a negative value,
+ // which will hit the slower copy path instead of the optimized one
+ offsets.add(Long.MAX_VALUE - BYTE_ARRAY_OFFSET - i);
+ }
+
+ for (long i = 1; i > 0L; i <<= 1) {
+ for (long j = 0; j < 32L; ++j) {
+ offsets.add(i + j);
+ }
+ }
+
+ for (final long offset : offsets) {
+ try {
+ fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length)
+ .writeTo(outputStream);
+
+ throw new IllegalStateException(Long.toString(offset));
+ } catch (ArrayIndexOutOfBoundsException e) {
+ // ignore
+ } finally {
+ outputStream.reset();
+ }
+ }
+ }
+
+ @Test
+ public void writeToOutputStream() throws IOException {
+ final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+ EMPTY_UTF8.writeTo(outputStream);
+ assertEquals("", outputStream.toString("UTF-8"));
+ outputStream.reset();
+
+ fromString("\u6570\u636e\u7816\u5f88\u91cd").writeTo(outputStream);
+ assertEquals(
+ "\u6570\u636e\u7816\u5f88\u91cd",
+ outputStream.toString("UTF-8"));
+ outputStream.reset();
+ }
+
+ @Test
+ public void writeToOutputStreamIntArray() throws IOException {
+ // verify that writes work on objects that are not byte arrays
+ final ByteBuffer buffer = StandardCharsets.UTF_8.encode("\u5927\u5343\u4e16\u754c");
+ buffer.position(0);
+ buffer.order(ByteOrder.LITTLE_ENDIAN);
+
+ final int length = buffer.limit();
+ assertEquals(12, length);
+
+ final int ints = length / 4;
+ final int[] array = new int[ints];
+
+ for (int i = 0; i < ints; ++i) {
+ array[i] = buffer.getInt();
+ }
+
+ final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+ fromAddress(array, Platform.INT_ARRAY_OFFSET, length)
+ .writeTo(outputStream);
+ assertEquals("\u5927\u5343\u4e16\u754c", outputStream.toString("UTF-8"));
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index cb3ca1b..b5aa7ce 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -21,9 +21,7 @@ import java.io.IOException
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.io.{NullWritable, Text}
-import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
-import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.spark.TaskContext
import org.apache.spark.ml.feature.LabeledPoint
@@ -35,7 +33,6 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.text.TextOutputWriter
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
@@ -46,30 +43,21 @@ private[libsvm] class LibSVMOutputWriter(
context: TaskAttemptContext)
extends OutputWriter {
- private[this] val buffer = new Text()
-
- private val recordWriter: RecordWriter[NullWritable, Text] = {
- new TextOutputFormat[NullWritable, Text]() {
- override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
- new Path(path)
- }
- }.getRecordWriter(context)
- }
+ private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
override def write(row: Row): Unit = {
val label = row.get(0)
val vector = row.get(1).asInstanceOf[Vector]
- val sb = new StringBuilder(label.toString)
+ writer.write(label.toString)
vector.foreachActive { case (i, v) =>
- sb += ' '
- sb ++= s"${i + 1}:$v"
+ writer.write(s" ${i + 1}:$v")
}
- buffer.set(sb.mkString)
- recordWriter.write(NullWritable.get(), buffer)
+
+ writer.write('\n')
}
override def close(): Unit = {
- recordWriter.close(context)
+ writer.close()
}
}
@@ -136,7 +124,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
}
override def getFileExtension(context: TaskAttemptContext): String = {
- ".libsvm" + TextOutputWriter.getCompressionExtension(context)
+ ".libsvm" + CodecStreams.getCompressionExtension(context)
}
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
index 4b548e0..bf8e3c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
@@ -194,4 +194,8 @@ private[sql] class JacksonGenerator(
writeFields(row, schema, rootFieldWriters)
}
}
+
+ def writeLineEnding(): Unit = {
+ gen.writeRaw('\n')
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala
new file mode 100644
index 0000000..900263a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.io.{OutputStream, OutputStreamWriter}
+import java.nio.charset.{Charset, StandardCharsets}
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.compress._
+import org.apache.hadoop.mapreduce.JobContext
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
+import org.apache.hadoop.util.ReflectionUtils
+
+object CodecStreams {
+ private def getCompressionCodec(
+ context: JobContext,
+ file: Option[Path] = None): Option[CompressionCodec] = {
+ if (FileOutputFormat.getCompressOutput(context)) {
+ val compressorClass = FileOutputFormat.getOutputCompressorClass(
+ context,
+ classOf[GzipCodec])
+
+ Some(ReflectionUtils.newInstance(compressorClass, context.getConfiguration))
+ } else {
+ file.flatMap { path =>
+ val compressionCodecs = new CompressionCodecFactory(context.getConfiguration)
+ Option(compressionCodecs.getCodec(path))
+ }
+ }
+ }
+
+ /**
+ * Create a new file and open it for writing.
+ * If compression is enabled in the [[JobContext]] the stream will write compressed data to disk.
+ * An exception will be thrown if the file already exists.
+ */
+ def createOutputStream(context: JobContext, file: Path): OutputStream = {
+ val fs = file.getFileSystem(context.getConfiguration)
+ val outputStream: OutputStream = fs.create(file, false)
+
+ getCompressionCodec(context, Some(file))
+ .map(codec => codec.createOutputStream(outputStream))
+ .getOrElse(outputStream)
+ }
+
+ def createOutputStreamWriter(
+ context: JobContext,
+ file: Path,
+ charset: Charset = StandardCharsets.UTF_8): OutputStreamWriter = {
+ new OutputStreamWriter(createOutputStream(context, file), charset)
+ }
+
+ /** Returns the compression codec extension to be used in a file name, e.g. ".gz". */
+ def getCompressionExtension(context: JobContext): String = {
+ getCompressionCodec(context)
+ .map(_.getDefaultExtension)
+ .getOrElse("")
+ }
+}
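The output writers in the remaining diffs all consume this helper the same way; condensed, the shared pattern is roughly:

    val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
    writer.write("record contents")  // UTF-8 by default
    writer.write('\n')
    writer.close()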
http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
index 332f5c8..6239508 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql.execution.datasources.csv
-import java.io.{CharArrayWriter, StringReader}
+import java.io.{CharArrayWriter, OutputStream, StringReader}
+import java.nio.charset.StandardCharsets
import com.univocity.parsers.csv._
@@ -64,7 +65,10 @@ private[csv] class CsvReader(params: CSVOptions) {
* @param params Parameters object for configuration
* @param headers headers for columns
*/
-private[csv] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging {
+private[csv] class LineCsvWriter(
+ params: CSVOptions,
+ headers: Seq[String],
+ output: OutputStream) extends Logging {
private val writerSettings = new CsvWriterSettings
private val format = writerSettings.getFormat
@@ -80,21 +84,14 @@ private[csv] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) exten
writerSettings.setHeaders(headers: _*)
writerSettings.setQuoteEscapingEnabled(params.escapeQuotes)
- private val buffer = new CharArrayWriter()
- private val writer = new CsvWriter(buffer, writerSettings)
+ private val writer = new CsvWriter(output, StandardCharsets.UTF_8, writerSettings)
def writeRow(row: Seq[String], includeHeader: Boolean): Unit = {
if (includeHeader) {
writer.writeHeaders()
}
- writer.writeRow(row.toArray: _*)
- }
- def flush(): String = {
- writer.flush()
- val lines = buffer.toString.stripLineEnd
- buffer.reset()
- lines
+ writer.writeRow(row: _*)
}
def close(): Unit = {
http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index a47b414..52de11d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -20,10 +20,7 @@ package org.apache.spark.sql.execution.datasources.csv
import scala.util.control.NonFatal
import org.apache.hadoop.fs.Path
-import org.apache.hadoop.io.{NullWritable, Text}
-import org.apache.hadoop.mapreduce.RecordWriter
import org.apache.hadoop.mapreduce.TaskAttemptContext
-import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
@@ -31,8 +28,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile}
-import org.apache.spark.sql.execution.datasources.text.TextOutputWriter
+import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.types._
object CSVRelation extends Logging {
@@ -179,7 +175,7 @@ private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWrit
}
override def getFileExtension(context: TaskAttemptContext): String = {
- ".csv" + TextOutputWriter.getCompressionExtension(context)
+ ".csv" + CodecStreams.getCompressionExtension(context)
}
}
@@ -189,9 +185,6 @@ private[csv] class CsvOutputWriter(
context: TaskAttemptContext,
params: CSVOptions) extends OutputWriter with Logging {
- // create the Generator without separator inserted between 2 records
- private[this] val text = new Text()
-
// A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`.
// When the value is null, this converter should not be called.
private type ValueConverter = (InternalRow, Int) => String
@@ -200,17 +193,9 @@ private[csv] class CsvOutputWriter(
private val valueConverters: Array[ValueConverter] =
dataSchema.map(_.dataType).map(makeConverter).toArray
- private val recordWriter: RecordWriter[NullWritable, Text] = {
- new TextOutputFormat[NullWritable, Text]() {
- override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
- new Path(path)
- }
- }.getRecordWriter(context)
- }
-
- private val FLUSH_BATCH_SIZE = 1024L
- private var records: Long = 0L
- private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq)
+ private var printHeader: Boolean = params.headerFlag
+ private val writer = CodecStreams.createOutputStream(context, new Path(path))
+ private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq, writer)
private def rowToString(row: InternalRow): Seq[String] = {
var i = 0
@@ -245,24 +230,12 @@ private[csv] class CsvOutputWriter(
override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
override protected[sql] def writeInternal(row: InternalRow): Unit = {
- csvWriter.writeRow(rowToString(row), records == 0L && params.headerFlag)
- records += 1
- if (records % FLUSH_BATCH_SIZE == 0) {
- flush()
- }
- }
-
- private def flush(): Unit = {
- val lines = csvWriter.flush()
- if (lines.nonEmpty) {
- text.set(lines)
- recordWriter.write(NullWritable.get(), text)
- }
+ csvWriter.writeRow(rowToString(row), printHeader)
+ printHeader = false
}
override def close(): Unit = {
- flush()
csvWriter.close()
- recordWriter.close(context)
+ writer.close()
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 0e38aef..c957914 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -17,15 +17,12 @@
package org.apache.spark.sql.execution.datasources.json
-import java.io.CharArrayWriter
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
+import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.{JobConf, TextInputFormat}
-import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
-import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
@@ -35,7 +32,6 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions}
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.text.TextOutputWriter
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
@@ -90,7 +86,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
}
override def getFileExtension(context: TaskAttemptContext): String = {
- ".json" + TextOutputWriter.getCompressionExtension(context)
+ ".json" + CodecStreams.getCompressionExtension(context)
}
}
}
@@ -163,33 +159,20 @@ private[json] class JsonOutputWriter(
context: TaskAttemptContext)
extends OutputWriter with Logging {
- private[this] val writer = new CharArrayWriter()
+ private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
+
// create the Generator without separator inserted between 2 records
private[this] val gen = new JacksonGenerator(dataSchema, writer, options)
- private[this] val result = new Text()
-
- private val recordWriter: RecordWriter[NullWritable, Text] = {
- new TextOutputFormat[NullWritable, Text]() {
- override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
- new Path(path)
- }
- }.getRecordWriter(context)
- }
override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
override protected[sql] def writeInternal(row: InternalRow): Unit = {
gen.write(row)
- gen.flush()
-
- result.set(writer.toString)
- writer.reset()
-
- recordWriter.write(NullWritable.get(), result)
+ gen.writeLineEnding()
}
override def close(): Unit = {
gen.close()
- recordWriter.close(context)
+ writer.close()
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index 8e04396..178160c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -19,11 +19,7 @@ package org.apache.spark.sql.execution.datasources.text
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.io.{NullWritable, Text}
-import org.apache.hadoop.io.compress.GzipCodec
-import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
-import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
-import org.apache.hadoop.util.ReflectionUtils
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.spark.TaskContext
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
@@ -82,7 +78,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
}
override def getFileExtension(context: TaskAttemptContext): String = {
- ".txt" + TextOutputWriter.getCompressionExtension(context)
+ ".txt" + CodecStreams.getCompressionExtension(context)
}
}
}
@@ -132,39 +128,19 @@ class TextOutputWriter(
context: TaskAttemptContext)
extends OutputWriter {
- private[this] val buffer = new Text()
-
- private val recordWriter: RecordWriter[NullWritable, Text] = {
- new TextOutputFormat[NullWritable, Text]() {
- override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
- new Path(path)
- }
- }.getRecordWriter(context)
- }
+ private val writer = CodecStreams.createOutputStream(context, new Path(path))
override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
override protected[sql] def writeInternal(row: InternalRow): Unit = {
- val utf8string = row.getUTF8String(0)
- buffer.set(utf8string.getBytes)
- recordWriter.write(NullWritable.get(), buffer)
+ if (!row.isNullAt(0)) {
+ val utf8string = row.getUTF8String(0)
+ utf8string.writeTo(writer)
+ }
+ writer.write('\n')
}
override def close(): Unit = {
- recordWriter.close(context)
- }
-}
-
-
-object TextOutputWriter {
- /** Returns the compression codec extension to be used in a file name, e.g. ".gzip"). */
- def getCompressionExtension(context: TaskAttemptContext): String = {
- // Set the compression extension, similar to code in TextOutputFormat.getDefaultWorkFile
- if (FileOutputFormat.getCompressOutput(context)) {
- val codecClass = FileOutputFormat.getOutputCompressorClass(context, classOf[GzipCodec])
- ReflectionUtils.newInstance(codecClass, context.getConfiguration).getDefaultExtension
- } else {
- ""
- }
+ writer.close()
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index cecfd99..5fdf615 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -17,14 +17,9 @@
package org.apache.spark.sql.sources
-import java.text.NumberFormat
-import java.util.Locale
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.io.{NullWritable, Text}
-import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
-import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.spark.sql.{sources, Row, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
@@ -125,29 +120,19 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
class SimpleTextOutputWriter(path: String, context: TaskAttemptContext)
extends OutputWriter {
- private val recordWriter: RecordWriter[NullWritable, Text] =
- new AppendingTextOutputFormat(path).getRecordWriter(context)
+ private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
override def write(row: Row): Unit = {
val serialized = row.toSeq.map { v =>
if (v == null) "" else v.toString
}.mkString(",")
- recordWriter.write(null, new Text(serialized))
- }
- override def close(): Unit = {
- recordWriter.close(context)
+ writer.write(serialized)
+ writer.write('\n')
}
-}
-class AppendingTextOutputFormat(path: String) extends TextOutputFormat[NullWritable, Text] {
-
- val numberFormat = NumberFormat.getInstance(Locale.US)
- numberFormat.setMinimumIntegerDigits(5)
- numberFormat.setGroupingUsed(false)
-
- override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
- new Path(path)
+ override def close(): Unit = {
+ writer.close()
}
}