Repository: incubator-beam
Updated Branches:
  refs/heads/python-sdk 7d988e3bb -> 7e744e445
Implement avro sink.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/1090ca39
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/1090ca39
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/1090ca39

Branch: refs/heads/python-sdk
Commit: 1090ca3911faef84bbd61c6b41668a695d2fe8dc
Parents: 8bc965b
Author: Robert Bradshaw <rober...@gmail.com>
Authored: Sat Sep 24 02:25:05 2016 -0700
Committer: Robert Bradshaw <rober...@google.com>
Committed: Mon Sep 26 12:17:35 2016 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/io/avroio.py      | 100 ++++++++++++++++++++++++-
 sdks/python/apache_beam/io/avroio_test.py |  41 ++++++----
 2 files changed, 127 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1090ca39/sdks/python/apache_beam/io/avroio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio.py b/sdks/python/apache_beam/io/avroio.py
index 82b30be..fdf8dae 100644
--- a/sdks/python/apache_beam/io/avroio.py
+++ b/sdks/python/apache_beam/io/avroio.py
@@ -20,15 +20,18 @@
 import os
 import StringIO
 import zlib
 
+import avro
 from avro import datafile
 from avro import io as avroio
 from avro import schema
 
+import apache_beam as beam
 from apache_beam.io import filebasedsource
+from apache_beam.io import fileio
 from apache_beam.io.iobase import Read
 from apache_beam.transforms import PTransform
 
-__all__ = ['ReadFromAvro']
+__all__ = ['ReadFromAvro', 'WriteToAvro']
 
 
 class ReadFromAvro(PTransform):
@@ -242,3 +245,98 @@
                           sync_marker)
       for record in block.records():
         yield record
+
+
+_avro_codecs = {
+    fileio.CompressionTypes.UNCOMPRESSED: 'null',
+    fileio.CompressionTypes.ZLIB: 'deflate',
+    # fileio.CompressionTypes.SNAPPY: 'snappy',
+}
+
+
+class WriteToAvro(beam.transforms.PTransform):
+  """A ``PTransform`` for writing avro files."""
+
+  def __init__(self,
+               file_path_prefix,
+               schema,
+               file_name_suffix='',
+               num_shards=0,
+               shard_name_template=None,
+               mime_type='application/x-avro',
+               compression_type=fileio.CompressionTypes.ZLIB):
+    """Initialize a WriteToAvro transform.
+
+    Args:
+      file_path_prefix: The file path to write to. The files written will
+        begin with this prefix, followed by a shard identifier (see
+        num_shards), and end in a common extension, if given by
+        file_name_suffix. In most cases, only this argument is specified and
+        num_shards, shard_name_template, and file_name_suffix use default
+        values.
+      schema: The schema to use, as returned by avro.schema.parse.
+      file_name_suffix: Suffix for the files written.
+      num_shards: The number of files (shards) used for output. If not set,
+        the service will decide on the optimal number of shards.
+        Constraining the number of shards is likely to reduce the
+        performance of a pipeline. Setting this value is not recommended
+        unless you require a specific number of output files.
+      shard_name_template: A template string containing placeholders for the
+        shard number and shard count. Currently only '' and
+        '-SSSSS-of-NNNNN' are patterns accepted by the service.
+        When constructing a filename for a particular shard number, the
+        upper-case letters 'S' and 'N' are replaced with the 0-padded shard
+        number and shard count respectively. This argument can be '' in
+        which case it behaves as if num_shards was set to 1 and only one
+        file will be generated. The default pattern used is
+        '-SSSSS-of-NNNNN'.
+      mime_type: The MIME type to use for the produced files, if the
+        filesystem supports specifying MIME types.
+      compression_type: Used to handle compressed output files. Defaults to
+        CompressionTypes.ZLIB.
+
+    Returns:
+      A WriteToAvro transform usable for writing.
+    """
+    if compression_type not in _avro_codecs:
+      raise ValueError(
+          'Compression type %s not supported by avro.' % compression_type)
+    self.args = (file_path_prefix, schema, file_name_suffix, num_shards,
+                 shard_name_template, mime_type, compression_type)
+
+  def apply(self, pcoll):
+    # pylint: disable=expression-not-assigned
+    pcoll | beam.io.iobase.Write(_AvroSink(*self.args))
+
+
+class _AvroSink(fileio.FileSink):
+  """A sink to avro files."""
+
+  def __init__(self,
+               file_path_prefix,
+               schema,
+               file_name_suffix,
+               num_shards,
+               shard_name_template,
+               mime_type,
+               compression_type):
+    super(_AvroSink, self).__init__(
+        file_path_prefix,
+        file_name_suffix=file_name_suffix,
+        num_shards=num_shards,
+        shard_name_template=shard_name_template,
+        coder=None,
+        mime_type=mime_type,
+        # Compression happens at the block level, not the file level.
+        compression_type=fileio.CompressionTypes.UNCOMPRESSED)
+    self.schema = schema
+    self.avro_compression_type = compression_type
+
+  def open(self, temp_path):
+    file_handle = super(_AvroSink, self).open(temp_path)
+    return avro.datafile.DataFileWriter(
+        file_handle, avro.io.DatumWriter(), self.schema,
+        _avro_codecs[self.avro_compression_type])
+
+  def write_record(self, writer, value):
+    writer.append(value)
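
For reference, a minimal pipeline using the new sink might look like the sketch
below (not part of this commit). The record schema, the input dicts, and the
output prefix '/tmp/avro_output/users' are illustrative assumptions; WriteToAvro
and ReadFromAvro are the transforms defined above.

    import avro.schema
    import apache_beam as beam
    from apache_beam.io import avroio

    # Hypothetical schema matching the records written below.
    USER_SCHEMA = avro.schema.parse('''
      {"namespace": "example.avro", "type": "record", "name": "User",
       "fields": [{"name": "name", "type": "string"},
                  {"name": "favorite_number", "type": ["int", "null"]}]}
    ''')

    with beam.Pipeline('DirectPipelineRunner') as p:
      # pylint: disable=expression-not-assigned
      (p
       | beam.Create([{'name': 'Ann', 'favorite_number': 1},
                      {'name': 'Bob', 'favorite_number': 2}])
       | avroio.WriteToAvro('/tmp/avro_output/users', USER_SCHEMA,
                            file_name_suffix='.avro'))

With the default compression_type of CompressionTypes.ZLIB, each shard is
written with the avro 'deflate' codec per the _avro_codecs mapping above;
CompressionTypes.UNCOMPRESSED maps to the 'null' codec.
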
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1090ca39/sdks/python/apache_beam/io/avroio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py
index e0c211f..dbaf6f3 100644
--- a/sdks/python/apache_beam/io/avroio_test.py
+++ b/sdks/python/apache_beam/io/avroio_test.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import json
 import logging
 import os
 import tempfile
@@ -33,7 +34,7 @@ from apache_beam.io.avroio import _AvroSource as AvroSource
 import avro.datafile
 from avro.datafile import DataFileWriter
 from avro.io import DatumWriter
-import avro.schema as avro_schema
+import avro.schema
 
 
 class TestAvro(unittest.TestCase):
@@ -67,25 +68,27 @@
               'favorite_number': 6,
               'favorite_color': 'Green'}]
 
+  SCHEMA = avro.schema.parse('''
+    {"namespace": "example.avro",
+     "type": "record",
+     "name": "User",
+     "fields": [
+       {"name": "name", "type": "string"},
+       {"name": "favorite_number", "type": ["int", "null"]},
+       {"name": "favorite_color", "type": ["string", "null"]}
+     ]
+    }
+    ''')
+
   def _write_data(self,
                   directory=None,
                   prefix=tempfile.template,
                   codec='null',
                   count=len(RECORDS)):
-    schema = ('{\"namespace\": \"example.avro\",'
-              '\"type\": \"record\",'
-              '\"name\": \"User\",'
-              '\"fields\": ['
-              '{\"name\": \"name\", \"type\": \"string\"},'
-              '{\"name\": \"favorite_number\", \"type\": [\"int\", \"null\"]},'
-              '{\"name\": \"favorite_color\", \"type\": [\"string\", \"null\"]}'
-              ']}')
-
-    schema = avro_schema.parse(schema)
     with tempfile.NamedTemporaryFile(
         delete=False, dir=directory, prefix=prefix) as f:
-      writer = DataFileWriter(f, DatumWriter(), schema, codec=codec)
+      writer = DataFileWriter(f, DatumWriter(), self.SCHEMA, codec=codec)
       len_records = len(self.RECORDS)
       for i in range(count):
         writer.append(self.RECORDS[i % len_records])
@@ -227,11 +230,23 @@
         source_test_utils.readFromSource(source, None, None)
       self.assertEqual(0, exn.exception.message.find('Unexpected sync marker'))
 
-  def test_pipeline(self):
+  def test_source_transform(self):
     path = self._write_data()
     with beam.Pipeline('DirectPipelineRunner') as p:
       assert_that(p | avroio.ReadFromAvro(path), equal_to(self.RECORDS))
 
+  def test_sink_transform(self):
+    with tempfile.NamedTemporaryFile() as dst:
+      path = dst.name
+      with beam.Pipeline('DirectPipelineRunner') as p:
+        # pylint: disable=expression-not-assigned
+        p | beam.Create(self.RECORDS) | avroio.WriteToAvro(path, self.SCHEMA)
+      with beam.Pipeline('DirectPipelineRunner') as p:
+        # json used for stable sortability
+        readback = p | avroio.ReadFromAvro(path + '*') | beam.Map(json.dumps)
+        assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
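
Since the sink emits standard Avro container files, output shards can also be
inspected outside Beam with the avro library directly. A small sketch, assuming
shards exist under the hypothetical '/tmp/avro_output/users' prefix from the
earlier example:

    import glob

    from avro.datafile import DataFileReader
    from avro.io import DatumReader

    # Each output shard (e.g. users-00000-of-00001.avro) is a self-contained
    # Avro container file; iterating over DataFileReader yields the records.
    for shard in glob.glob('/tmp/avro_output/users*'):
      with open(shard, 'rb') as f:
        reader = DataFileReader(f, DatumReader())
        for record in reader:
          print(record)
        reader.close()
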