Repository: spark
Updated Branches:
  refs/heads/master d3c90b74e -> c82f16c15


[SPARK-18658][SQL] Write text records directly to a FileOutputStream

## What changes were proposed in this pull request?

This replaces uses of `TextOutputFormat` with an `OutputStream`, which will either write directly to the filesystem or indirectly via a compressor (if so configured). This avoids intermediate buffering.

The inverse of this (reading directly from a stream) is necessary for streaming large JSON records (when `wholeFile` is enabled) so I wanted to keep the read and write paths symmetric.

## How was this patch tested?

Existing unit tests.

Author: Nathan Howell <nhow...@godaddy.com>

Closes #16089 from NathanHowell/SPARK-18658.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c82f16c1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c82f16c1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c82f16c1

Branch: refs/heads/master
Commit: c82f16c15e0d4bfc54fb890a667d9164a088b5c6
Parents: d3c90b7
Author: Nathan Howell <nhow...@godaddy.com>
Authored: Thu Dec 1 21:40:49 2016 -0800
Committer: Reynold Xin <r...@databricks.com>
Committed: Thu Dec 1 21:40:49 2016 -0800

----------------------------------------------------------------------
 .../apache/spark/unsafe/types/UTF8String.java   |  19 ++++
 .../spark/unsafe/types/UTF8StringSuite.java     | 109 +++++++++++++++++++
 .../spark/ml/source/libsvm/LibSVMRelation.scala |  28 ++---
 .../sql/catalyst/json/JacksonGenerator.scala    |   4 +
 .../execution/datasources/CodecStreams.scala    |  74 +++++++++++++
 .../execution/datasources/csv/CSVParser.scala   |  19 ++--
 .../execution/datasources/csv/CSVRelation.scala |  43 ++------
 .../datasources/json/JsonFileFormat.scala       |  31 ++----
 .../datasources/text/TextFileFormat.scala       |  42 ++-----
 .../spark/sql/sources/SimpleTextRelation.scala  |  27 +----
 10 files changed, 252 insertions(+), 144 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index e09a6b7..0255f53 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -147,6 +147,25 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
     buffer.position(pos + numBytes);
   }
 
+  public void writeTo(OutputStream out) throws IOException {
+    if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) {
+      final byte[] bytes = (byte[]) base;
+
+      // the offset includes an object header... this is only needed for unsafe copies
+      final long arrayOffset = offset - BYTE_ARRAY_OFFSET;
+
+      // verify that the offset and length points somewhere inside the byte array
+      // and that the offset can safely be truncated to a 32-bit integer
+      if ((long) bytes.length < arrayOffset + numBytes) {
+        throw new ArrayIndexOutOfBoundsException();
+      }
+
+      out.write(bytes, (int) arrayOffset, numBytes);
+    } else {
+      out.write(getBytes());
+    }
+  }
+
   /**
    * Returns the number of bytes for a code point with the first byte as `b`
    * @param b The first byte of a code point


http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 7f03686..04f6845 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -17,15 +17,22 @@
 
 package org.apache.spark.unsafe.types;
 
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.HashMap;
+import java.util.HashSet;
 
 import com.google.common.collect.ImmutableMap;
+import org.apache.spark.unsafe.Platform;
 import org.junit.Test;
 
 import static org.junit.Assert.*;
+import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
 import static org.apache.spark.unsafe.types.UTF8String.*;
 
 public class UTF8StringSuite {
@@ -499,4 +506,106 @@ public class UTF8StringSuite {
     assertEquals(fromString("123").soundex(), fromString("123"));
     assertEquals(fromString("世界千世").soundex(), fromString("世界千世"));
   }
+
+  @Test
+  public void writeToOutputStreamUnderflow() throws IOException {
+    // offset underflow is apparently supported?
+    final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8);
+
+    for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) {
+      UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)
+        .writeTo(outputStream);
+      final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length);
+      assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString());
+      outputStream.reset();
+    }
+  }
+
+  @Test
+  public void writeToOutputStreamSlice() throws IOException {
+    final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8);
+
+    for (int i = 0; i < test.length; ++i) {
+      for (int j = 0; j < test.length - i; ++j) {
+        UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET + i, j)
+          .writeTo(outputStream);
+
+        assertArrayEquals(Arrays.copyOfRange(test, i, i + j), outputStream.toByteArray());
+        outputStream.reset();
+      }
+    }
+  }
+
+  @Test
+  public void writeToOutputStreamOverflow() throws IOException {
+    final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8);
+
+    final HashSet<Long> offsets = new HashSet<>();
+    for (int i = 0; i < 16; ++i) {
+      // touch more points around MAX_VALUE
+      offsets.add((long) Integer.MAX_VALUE - i);
+      // subtract off BYTE_ARRAY_OFFSET to avoid wrapping around to a negative value,
+      // which will hit the slower copy path instead of the optimized one
+      offsets.add(Long.MAX_VALUE - BYTE_ARRAY_OFFSET - i);
+    }
+
+    for (long i = 1; i > 0L; i <<= 1) {
+      for (long j = 0; j < 32L; ++j) {
+        offsets.add(i + j);
+      }
+    }
+
+    for (final long offset : offsets) {
+      try {
+        fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length)
+          .writeTo(outputStream);
+
+        throw new IllegalStateException(Long.toString(offset));
+      } catch (ArrayIndexOutOfBoundsException e) {
+        // ignore
+      } finally {
+        outputStream.reset();
+      }
+    }
+  }
+
+  @Test
+  public void writeToOutputStream() throws IOException {
+    final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    EMPTY_UTF8.writeTo(outputStream);
+    assertEquals("", outputStream.toString("UTF-8"));
+    outputStream.reset();
+
+    fromString("数据砖很重").writeTo(outputStream);
+    assertEquals(
+      "数据砖很重",
+      outputStream.toString("UTF-8"));
+    outputStream.reset();
+  }
+
+  @Test
+  public void writeToOutputStreamIntArray() throws IOException {
+    // verify that writes work on objects that are not byte arrays
+    final ByteBuffer buffer = StandardCharsets.UTF_8.encode("大千世界");
+    buffer.position(0);
+    buffer.order(ByteOrder.LITTLE_ENDIAN);
+
+    final int length = buffer.limit();
+    assertEquals(12, length);
+
+    final int ints = length / 4;
+    final int[] array = new int[ints];
+
+    for (int i = 0; i < ints; ++i) {
+      array[i] = buffer.getInt();
+    }
+
+    final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    fromAddress(array, Platform.INT_ARRAY_OFFSET, length)
+      .writeTo(outputStream);
+    assertEquals("大千世界", outputStream.toString("UTF-8"));
+  }
 }


http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index cb3ca1b..b5aa7ce 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -21,9 +21,7 @@ import java.io.IOException import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{NullWritable, Text} -import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} -import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.TaskContext import org.apache.spark.ml.feature.LabeledPoint @@ -35,7 +33,6 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.TextOutputWriter import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -46,30 +43,21 @@ private[libsvm] class LibSVMOutputWriter( context: TaskAttemptContext) extends OutputWriter { - private[this] val buffer = new Text() - - private val recordWriter: RecordWriter[NullWritable, Text] = { - new TextOutputFormat[NullWritable, Text]() { - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - new Path(path) - } - }.getRecordWriter(context) - } + private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) override def write(row: Row): Unit = { val label = row.get(0) val vector = row.get(1).asInstanceOf[Vector] - val sb = new StringBuilder(label.toString) + writer.write(label.toString) vector.foreachActive { case (i, v) => - sb += ' ' - sb ++= s"${i + 1}:$v" + writer.write(s" ${i + 1}:$v") } - buffer.set(sb.mkString) - recordWriter.write(NullWritable.get(), buffer) + + writer.write('\n') } override def close(): Unit = { - recordWriter.close(context) + writer.close() } } @@ -136,7 +124,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour } override def getFileExtension(context: TaskAttemptContext): String = { - ".libsvm" + TextOutputWriter.getCompressionExtension(context) + ".libsvm" + CodecStreams.getCompressionExtension(context) } } } http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 4b548e0..bf8e3c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -194,4 +194,8 @@ private[sql] class JacksonGenerator( writeFields(row, schema, rootFieldWriters) } } + + def writeLineEnding(): Unit = { + gen.writeRaw('\n') + } } http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala new file mode 100644 index 0000000..900263a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.{OutputStream, OutputStreamWriter} +import java.nio.charset.{Charset, StandardCharsets} + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.compress._ +import org.apache.hadoop.mapreduce.JobContext +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat +import org.apache.hadoop.util.ReflectionUtils + +object CodecStreams { + private def getCompressionCodec( + context: JobContext, + file: Option[Path] = None): Option[CompressionCodec] = { + if (FileOutputFormat.getCompressOutput(context)) { + val compressorClass = FileOutputFormat.getOutputCompressorClass( + context, + classOf[GzipCodec]) + + Some(ReflectionUtils.newInstance(compressorClass, context.getConfiguration)) + } else { + file.flatMap { path => + val compressionCodecs = new CompressionCodecFactory(context.getConfiguration) + Option(compressionCodecs.getCodec(path)) + } + } + } + + /** + * Create a new file and open it for writing. + * If compression is enabled in the [[JobContext]] the stream will write compressed data to disk. + * An exception will be thrown if the file already exists. + */ + def createOutputStream(context: JobContext, file: Path): OutputStream = { + val fs = file.getFileSystem(context.getConfiguration) + val outputStream: OutputStream = fs.create(file, false) + + getCompressionCodec(context, Some(file)) + .map(codec => codec.createOutputStream(outputStream)) + .getOrElse(outputStream) + } + + def createOutputStreamWriter( + context: JobContext, + file: Path, + charset: Charset = StandardCharsets.UTF_8): OutputStreamWriter = { + new OutputStreamWriter(createOutputStream(context, file), charset) + } + + /** Returns the compression codec extension to be used in a file name, e.g. ".gzip"). 
*/ + def getCompressionExtension(context: JobContext): String = { + getCompressionCodec(context) + .map(_.getDefaultExtension) + .getOrElse("") + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala index 332f5c8..6239508 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution.datasources.csv -import java.io.{CharArrayWriter, StringReader} +import java.io.{CharArrayWriter, OutputStream, StringReader} +import java.nio.charset.StandardCharsets import com.univocity.parsers.csv._ @@ -64,7 +65,10 @@ private[csv] class CsvReader(params: CSVOptions) { * @param params Parameters object for configuration * @param headers headers for columns */ -private[csv] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging { +private[csv] class LineCsvWriter( + params: CSVOptions, + headers: Seq[String], + output: OutputStream) extends Logging { private val writerSettings = new CsvWriterSettings private val format = writerSettings.getFormat @@ -80,21 +84,14 @@ private[csv] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) exten writerSettings.setHeaders(headers: _*) writerSettings.setQuoteEscapingEnabled(params.escapeQuotes) - private val buffer = new CharArrayWriter() - private val writer = new CsvWriter(buffer, writerSettings) + private val writer = new CsvWriter(output, StandardCharsets.UTF_8, writerSettings) def writeRow(row: Seq[String], includeHeader: Boolean): Unit = { if (includeHeader) { writer.writeHeaders() } - writer.writeRow(row.toArray: _*) - } - def flush(): String = { - writer.flush() - val lines = buffer.toString.stripLineEnd - buffer.reset() - lines + writer.writeRow(row: _*) } def close(): Unit = { http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index a47b414..52de11d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -20,10 +20,7 @@ package org.apache.spark.sql.execution.datasources.csv import scala.util.control.NonFatal import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{NullWritable, Text} -import org.apache.hadoop.mapreduce.RecordWriter import org.apache.hadoop.mapreduce.TaskAttemptContext -import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -31,8 +28,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import 
org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile} -import org.apache.spark.sql.execution.datasources.text.TextOutputWriter +import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter, OutputWriterFactory, PartitionedFile} import org.apache.spark.sql.types._ object CSVRelation extends Logging { @@ -179,7 +175,7 @@ private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWrit } override def getFileExtension(context: TaskAttemptContext): String = { - ".csv" + TextOutputWriter.getCompressionExtension(context) + ".csv" + CodecStreams.getCompressionExtension(context) } } @@ -189,9 +185,6 @@ private[csv] class CsvOutputWriter( context: TaskAttemptContext, params: CSVOptions) extends OutputWriter with Logging { - // create the Generator without separator inserted between 2 records - private[this] val text = new Text() - // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`. // When the value is null, this converter should not be called. private type ValueConverter = (InternalRow, Int) => String @@ -200,17 +193,9 @@ private[csv] class CsvOutputWriter( private val valueConverters: Array[ValueConverter] = dataSchema.map(_.dataType).map(makeConverter).toArray - private val recordWriter: RecordWriter[NullWritable, Text] = { - new TextOutputFormat[NullWritable, Text]() { - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - new Path(path) - } - }.getRecordWriter(context) - } - - private val FLUSH_BATCH_SIZE = 1024L - private var records: Long = 0L - private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq) + private var printHeader: Boolean = params.headerFlag + private val writer = CodecStreams.createOutputStream(context, new Path(path)) + private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq, writer) private def rowToString(row: InternalRow): Seq[String] = { var i = 0 @@ -245,24 +230,12 @@ private[csv] class CsvOutputWriter( override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") override protected[sql] def writeInternal(row: InternalRow): Unit = { - csvWriter.writeRow(rowToString(row), records == 0L && params.headerFlag) - records += 1 - if (records % FLUSH_BATCH_SIZE == 0) { - flush() - } - } - - private def flush(): Unit = { - val lines = csvWriter.flush() - if (lines.nonEmpty) { - text.set(lines) - recordWriter.write(NullWritable.get(), text) - } + csvWriter.writeRow(rowToString(row), printHeader) + printHeader = false } override def close(): Unit = { - flush() csvWriter.close() - recordWriter.close(context) + writer.close() } } http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 0e38aef..c957914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.CharArrayWriter - import org.apache.hadoop.conf.Configuration import 
org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{LongWritable, NullWritable, Text} +import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.{JobConf, TextInputFormat} -import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat import org.apache.spark.TaskContext import org.apache.spark.internal.Logging @@ -35,7 +32,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.TextOutputWriter import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -90,7 +86,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { } override def getFileExtension(context: TaskAttemptContext): String = { - ".json" + TextOutputWriter.getCompressionExtension(context) + ".json" + CodecStreams.getCompressionExtension(context) } } } @@ -163,33 +159,20 @@ private[json] class JsonOutputWriter( context: TaskAttemptContext) extends OutputWriter with Logging { - private[this] val writer = new CharArrayWriter() + private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) + // create the Generator without separator inserted between 2 records private[this] val gen = new JacksonGenerator(dataSchema, writer, options) - private[this] val result = new Text() - - private val recordWriter: RecordWriter[NullWritable, Text] = { - new TextOutputFormat[NullWritable, Text]() { - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - new Path(path) - } - }.getRecordWriter(context) - } override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") override protected[sql] def writeInternal(row: InternalRow): Unit = { gen.write(row) - gen.flush() - - result.set(writer.toString) - writer.reset() - - recordWriter.write(NullWritable.get(), result) + gen.writeLineEnding() } override def close(): Unit = { gen.close() - recordWriter.close(context) + writer.close() } } http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 8e04396..178160c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -19,11 +19,7 @@ package org.apache.spark.sql.execution.datasources.text import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{NullWritable, Text} -import org.apache.hadoop.io.compress.GzipCodec -import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} -import 
org.apache.hadoop.util.ReflectionUtils +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.TaskContext import org.apache.spark.sql.{AnalysisException, Row, SparkSession} @@ -82,7 +78,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } override def getFileExtension(context: TaskAttemptContext): String = { - ".txt" + TextOutputWriter.getCompressionExtension(context) + ".txt" + CodecStreams.getCompressionExtension(context) } } } @@ -132,39 +128,19 @@ class TextOutputWriter( context: TaskAttemptContext) extends OutputWriter { - private[this] val buffer = new Text() - - private val recordWriter: RecordWriter[NullWritable, Text] = { - new TextOutputFormat[NullWritable, Text]() { - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - new Path(path) - } - }.getRecordWriter(context) - } + private val writer = CodecStreams.createOutputStream(context, new Path(path)) override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") override protected[sql] def writeInternal(row: InternalRow): Unit = { - val utf8string = row.getUTF8String(0) - buffer.set(utf8string.getBytes) - recordWriter.write(NullWritable.get(), buffer) + if (!row.isNullAt(0)) { + val utf8string = row.getUTF8String(0) + utf8string.writeTo(writer) + } + writer.write('\n') } override def close(): Unit = { - recordWriter.close(context) - } -} - - -object TextOutputWriter { - /** Returns the compression codec extension to be used in a file name, e.g. ".gzip"). */ - def getCompressionExtension(context: TaskAttemptContext): String = { - // Set the compression extension, similar to code in TextOutputFormat.getDefaultWorkFile - if (FileOutputFormat.getCompressOutput(context)) { - val codecClass = FileOutputFormat.getOutputCompressorClass(context, classOf[GzipCodec]) - ReflectionUtils.newInstance(codecClass, context.getConfiguration).getDefaultExtension - } else { - "" - } + writer.close() } } http://git-wip-us.apache.org/repos/asf/spark/blob/c82f16c1/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index cecfd99..5fdf615 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -17,14 +17,9 @@ package org.apache.spark.sql.sources -import java.text.NumberFormat -import java.util.Locale - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{NullWritable, Text} -import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} -import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.sql.{sources, Row, SparkSession} import org.apache.spark.sql.catalyst.{expressions, InternalRow} @@ -125,29 +120,19 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter { - private val recordWriter: RecordWriter[NullWritable, Text] = - new AppendingTextOutputFormat(path).getRecordWriter(context) + private val writer = CodecStreams.createOutputStreamWriter(context, 
new Path(path))
 
   override def write(row: Row): Unit = {
     val serialized = row.toSeq.map { v =>
       if (v == null) "" else v.toString
     }.mkString(",")
-    recordWriter.write(null, new Text(serialized))
-  }
 
-  override def close(): Unit = {
-    recordWriter.close(context)
+    writer.write(serialized)
+    writer.write('\n')
   }
-}
 
-class AppendingTextOutputFormat(path: String) extends TextOutputFormat[NullWritable, Text] {
-
-  val numberFormat = NumberFormat.getInstance(Locale.US)
-  numberFormat.setMinimumIntegerDigits(5)
-  numberFormat.setGroupingUsed(false)
-
-  override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-    new Path(path)
+  override def close(): Unit = {
+    writer.close()
   }
 }
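
The `UTF8String.writeTo(OutputStream)` method added above is what lets `TextOutputWriter` skip the intermediate `Text` record: for a `byte[]`-backed string it writes the backing slice directly, and for other bases it falls back to `getBytes()` (the int-array test exercises that path). A minimal sketch of calling it, using only the factory methods that already exist on `UTF8String`:

    import java.io.ByteArrayOutputStream

    import org.apache.spark.unsafe.types.UTF8String

    // Write a UTF8String straight into a stream, without first converting it to a
    // java.lang.String or copying it into a Hadoop Text object.
    val out = new ByteArrayOutputStream()
    UTF8String.fromString("spark").writeTo(out)
    out.write('\n')
    assert(out.toString("UTF-8") == "spark\n")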
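
The `CodecStreams` helpers are the common piece behind every writer touched by this commit: each `OutputWriter` obtains a single stream up front and writes records straight through it, with compression applied transparently when the job is configured for it. The sketch below shows that usage pattern; the `writeLines` helper and the part-file name are illustrative only, while `createOutputStreamWriter` and `getCompressionExtension` are the methods introduced in the new `CodecStreams.scala`.

    import org.apache.hadoop.fs.Path
    import org.apache.hadoop.mapreduce.TaskAttemptContext

    import org.apache.spark.sql.execution.datasources.CodecStreams

    // Illustrative helper (not part of the commit): write a batch of lines through the
    // new code path. CodecStreams decides, from the job configuration, whether the bytes
    // go straight to the filesystem or through a compression codec, so no record-level
    // buffering is needed here.
    def writeLines(context: TaskAttemptContext, dir: Path, lines: Iterator[String]): Unit = {
      // the extension reflects the configured codec, e.g. ".txt.gz" when gzip output is enabled
      val extension = ".txt" + CodecStreams.getCompressionExtension(context)
      val writer = CodecStreams.createOutputStreamWriter(context, new Path(dir, "part-00000" + extension))
      try {
        lines.foreach { line =>
          writer.write(line)
          writer.write('\n')
        }
      } finally {
        writer.close()
      }
    }

This mirrors what `TextOutputWriter`, `CsvOutputWriter`, `JsonOutputWriter`, `LibSVMOutputWriter` and the test-only `SimpleTextOutputWriter` now do, minus each format's own serialization.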