Repository: beam Updated Branches: refs/heads/python-sdk 69d8f2bf1 -> a25515171
Create TFRecordIO, which provides source/sink for TFRecords, the dedicated record format for Tensorflow. For more about TFRecords, refer to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/api_docs/python/python_io.md Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/88833ba5 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/88833ba5 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/88833ba5 Branch: refs/heads/python-sdk Commit: 88833ba52bf0a3ac6668bcaa73ca383771d5e1b7 Parents: 69d8f2b Author: Younghee Kwon <[email protected]> Authored: Fri Jan 6 18:05:56 2017 -0800 Committer: Robert Bradshaw <[email protected]> Committed: Mon Jan 9 13:13:45 2017 -0800 ---------------------------------------------------------------------- sdks/python/apache_beam/io/__init__.py | 1 + sdks/python/apache_beam/io/tfrecordio.py | 271 +++++++++++++++ sdks/python/apache_beam/io/tfrecordio_test.py | 365 +++++++++++++++++++++ sdks/python/setup.py | 1 + 4 files changed, 638 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/88833ba5/sdks/python/apache_beam/io/__init__.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/__init__.py b/sdks/python/apache_beam/io/__init__.py index 4ce9872..13ce36f 100644 --- a/sdks/python/apache_beam/io/__init__.py +++ b/sdks/python/apache_beam/io/__init__.py @@ -27,4 +27,5 @@ from apache_beam.io.iobase import Write from apache_beam.io.iobase import Writer from apache_beam.io.pubsub import * from apache_beam.io.textio import * +from apache_beam.io.tfrecordio import * from apache_beam.io.range_trackers import * http://git-wip-us.apache.org/repos/asf/beam/blob/88833ba5/sdks/python/apache_beam/io/tfrecordio.py ---------------------------------------------------------------------- diff --git 
a/sdks/python/apache_beam/io/tfrecordio.py b/sdks/python/apache_beam/io/tfrecordio.py new file mode 100644 index 0000000..be9f839 --- /dev/null +++ b/sdks/python/apache_beam/io/tfrecordio.py @@ -0,0 +1,271 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""TFRecord sources and sinks.""" + +from __future__ import absolute_import + +import logging +import struct + +from apache_beam import coders +from apache_beam.io import filebasedsource +from apache_beam.io import fileio +from apache_beam.io.iobase import Read +from apache_beam.io.iobase import Write +from apache_beam.transforms import PTransform +import crcmod + +__all__ = ['ReadFromTFRecord', 'WriteToTFRecord'] + + +def _default_crc32c_fn(value): + """Calculates crc32c by either snappy or crcmod based on installation.""" + + if not _default_crc32c_fn.fn: + try: + import snappy # pylint: disable=import-error + _default_crc32c_fn.fn = snappy._crc32c # pylint: disable=protected-access + except ImportError: + logging.warning('Couldn\'t find python-snappy so the implementation of ' + '_TFRecordUtil._masked_crc32c is not as fast as it could ' + 'be.') + _default_crc32c_fn.fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c') + return _default_crc32c_fn.fn(value) 
+_default_crc32c_fn.fn = None + + +class _TFRecordUtil(object): + """Provides basic TFRecord encoding/decoding with consistency checks. + + For detailed TFRecord format description see: + https://www.tensorflow.org/versions/master/api_docs/python/python_io.html#tfrecords-format-details + + Note that masks and length are represented in LittleEndian order. + """ + + @classmethod + def _masked_crc32c(cls, value, crc32c_fn=_default_crc32c_fn): + """Compute a masked crc32c checksum for a value. + + Args: + value: A string for which we compute the crc. + crc32c_fn: A function that can compute a crc32c. + This is a performance hook that also helps with testing. Callers are + not expected to make use of it directly. + Returns: + Masked crc32c checksum. + """ + + crc = crc32c_fn(value) + return (((crc >> 15) | (crc << 17)) + 0xa282ead8) & 0xffffffff + + @staticmethod + def encoded_num_bytes(record): + """Return the number of bytes consumed by a record in its encoded form.""" + # 16 = 8 (Length) + 4 (crc of length) + 4 (crc of data) + return len(record) + 16 + + @classmethod + def write_record(cls, file_handle, value): + """Encode a value as a TFRecord. + + Args: + file_handle: The file to write to. + value: A string content of the record. + """ + encoded_length = struct.pack('<Q', len(value)) + file_handle.write('{}{}{}{}'.format( + encoded_length, + struct.pack('<I', cls._masked_crc32c(encoded_length)), # + value, + struct.pack('<I', cls._masked_crc32c(value)))) + + @classmethod + def read_record(cls, file_handle): + """Read a record from a TFRecords file. + + Args: + file_handle: The file to read from. + Returns: + None if EOF is reached; the paylod of the record otherwise. + Raises: + ValueError: If file appears to not be a valid TFRecords file. + """ + buf_length_expected = 12 + buf = file_handle.read(buf_length_expected) + if not buf: + return None # EOF Reached. + + # Validate all length related payloads. 
+ if len(buf) != buf_length_expected: + raise ValueError('Not a valid TFRecord. Fewer than %d bytes: %s' % + (buf_length_expected, buf.encode('hex'))) + length, length_mask_expected = struct.unpack('<QI', buf) + length_mask_actual = cls._masked_crc32c(buf[:8]) + if length_mask_actual != length_mask_expected: + raise ValueError('Not a valid TFRecord. Mismatch of length mask: %s' % + buf.encode('hex')) + + # Validate all data related payloads. + buf_length_expected = length + 4 + buf = file_handle.read(buf_length_expected) + if len(buf) != buf_length_expected: + raise ValueError('Not a valid TFRecord. Fewer than %d bytes: %s' % + (buf_length_expected, buf.encode('hex'))) + data, data_mask_expected = struct.unpack('<%dsI' % length, buf) + data_mask_actual = cls._masked_crc32c(data) + if data_mask_actual != data_mask_expected: + raise ValueError('Not a valid TFRecord. Mismatch of data mask: %s' % + buf.encode('hex')) + + # All validation checks passed. + return data + + +class _TFRecordSource(filebasedsource.FileBasedSource): + """A File source for reading files of TFRecords. + + For detailed TFRecords format description see: + https://www.tensorflow.org/versions/master/api_docs/python/python_io.html#tfrecords-format-details + """ + + def __init__(self, + file_pattern, + coder, + compression_type): + """Initialize a TFRecordSource. 
See ReadFromTFRecord for details.""" + super(_TFRecordSource, self).__init__( + file_pattern=file_pattern, + compression_type=compression_type, + splittable=False) + self._coder = coder + + def read_records(self, file_name, offset_range_tracker): + if offset_range_tracker.start_position(): + raise ValueError('Start position not 0:%s' % + offset_range_tracker.start_position()) + + current_offset = offset_range_tracker.start_position() + with self.open_file(file_name) as file_handle: + while True: + if not offset_range_tracker.try_claim(current_offset): + raise RuntimeError('Unable to claim position: %s' % current_offset) + record = _TFRecordUtil.read_record(file_handle) + if record is None: + return # Reached EOF + else: + current_offset += _TFRecordUtil.encoded_num_bytes(record) + yield self._coder.decode(record) + + +class ReadFromTFRecord(PTransform): + """Transform for reading TFRecord sources.""" + + def __init__(self, + file_pattern, + coder=coders.BytesCoder(), + compression_type=fileio.CompressionTypes.AUTO, + **kwargs): + """Initialize a ReadFromTFRecord transform. + + Args: + file_pattern: A file glob pattern to read TFRecords from. + coder: Coder used to decode each record. + compression_type: Used to handle compressed input files. Default value + is CompressionTypes.AUTO, in which case the file_path's extension will + be used to detect the compression. + **kwargs: optional args dictionary. These are passed through to parent + constructor. + + Returns: + A ReadFromTFRecord transform object. + """ + super(ReadFromTFRecord, self).__init__(**kwargs) + self._args = (file_pattern, coder, compression_type) + + def expand(self, pvalue): + return pvalue.pipeline | Read(_TFRecordSource(*self._args)) + + +class _TFRecordSink(fileio.FileSink): + """Sink for writing TFRecords files. 
+ + For detailed TFRecord format description see: + https://www.tensorflow.org/versions/master/api_docs/python/python_io.html#tfrecords-format-details + """ + + def __init__(self, file_path_prefix, coder, file_name_suffix, num_shards, + shard_name_template, compression_type): + """Initialize a TFRecordSink. See WriteToTFRecord for details.""" + + super(_TFRecordSink, self).__init__( + file_path_prefix=file_path_prefix, + coder=coder, + file_name_suffix=file_name_suffix, + num_shards=num_shards, + shard_name_template=shard_name_template, + mime_type='application/octet-stream', + compression_type=compression_type) + + def write_encoded_record(self, file_handle, value): + _TFRecordUtil.write_record(file_handle, value) + + +class WriteToTFRecord(PTransform): + """Transform for writing to TFRecord sinks.""" + + def __init__(self, + file_path_prefix, + coder=coders.BytesCoder(), + file_name_suffix='', + num_shards=0, + shard_name_template=fileio.DEFAULT_SHARD_NAME_TEMPLATE, + compression_type=fileio.CompressionTypes.AUTO, + **kwargs): + """Initialize WriteToTFRecord transform. + + Args: + file_path_prefix: The file path to write to. The files written will begin + with this prefix, followed by a shard identifier (see num_shards), and + end in a common extension, if given by file_name_suffix. + coder: Coder used to encode each record. + file_name_suffix: Suffix for the files written. + num_shards: The number of files (shards) used for output. If not set, the + default value will be used. + shard_name_template: A template string containing placeholders for + the shard number and shard count. Currently only '' and + '-SSSSS-of-NNNNN' are patterns allowed. + When constructing a filename for a particular shard number, the + upper-case letters 'S' and 'N' are replaced with the 0-padded shard + number and shard count respectively. This argument can be '' in which + case it behaves as if num_shards was set to 1 and only one file will be + generated. 
The default pattern is '-SSSSS-of-NNNNN'. + compression_type: Used to handle compressed output files. Typical value + is CompressionTypes.AUTO, in which case the file_path's extension will + be used to detect the compression. + **kwargs: Optional args dictionary. These are passed through to parent + constructor. + + Returns: + A WriteToTFRecord transform object. + """ + super(WriteToTFRecord, self).__init__(**kwargs) + self._args = (file_path_prefix, coder, file_name_suffix, num_shards, + shard_name_template, compression_type) + + def expand(self, pcoll): + return pcoll | Write(_TFRecordSink(*self._args)) http://git-wip-us.apache.org/repos/asf/beam/blob/88833ba5/sdks/python/apache_beam/io/tfrecordio_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py new file mode 100644 index 0000000..ee287b3 --- /dev/null +++ b/sdks/python/apache_beam/io/tfrecordio_test.py @@ -0,0 +1,365 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#

import binascii
import cStringIO
import glob
import gzip
import logging
import pickle
import random
import tempfile
import unittest

import apache_beam as beam
from apache_beam import coders
from apache_beam.io import fileio
from apache_beam.io.tfrecordio import _TFRecordSink
from apache_beam.io.tfrecordio import _TFRecordSource
from apache_beam.io.tfrecordio import _TFRecordUtil
from apache_beam.io.tfrecordio import ReadFromTFRecord
from apache_beam.io.tfrecordio import WriteToTFRecord
from apache_beam.runners import DirectRunner
import crcmod


# tensorflow is an optional test dependency; tf-dependent tests are skipped
# when it is absent.
try:
  import tensorflow as tf  # pylint: disable=import-error
except ImportError:
  tf = None  # pylint: disable=invalid-name
  logging.warning('Tensorflow is not installed, so skipping some tests.')

# Created by running following code in python:
# >>> import tensorflow as tf
# >>> import base64
# >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord')
# >>> writer.write('foo')
# >>> writer.close()
# >>> with open('/tmp/python_foo.tfrecord', 'rb') as f:
# ...   data = base64.b64encode(f.read())
# ...
# >>> print data
FOO_RECORD_BASE64 = 'AwAAAAAAAACwmUkOZm9vYYq+/g=='

# Same as above but containing two records ['foo', 'bar']
FOO_BAR_RECORD_BASE64 = 'AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg='


class TestTFRecordUtil(unittest.TestCase):
  """Unit tests for the low-level TFRecord encode/decode helpers."""

  def setUp(self):
    # A single valid TFRecord containing the payload 'foo'.
    self.record = binascii.a2b_base64(FOO_RECORD_BASE64)

  def _as_file_handle(self, contents):
    # Wrap raw bytes in an in-memory, file-like object positioned at 0.
    result = cStringIO.StringIO()
    result.write(contents)
    result.reset()
    return result

  def _increment_value_at_index(self, value, index):
    # Corrupt one byte of a record to trigger checksum-mismatch paths.
    l = list(value)
    l[index] = chr(ord(l[index]) + 1)
    return ''.join(l)

  def _test_error(self, record, error_text):
    # Assert that decoding `record` raises ValueError mentioning `error_text`.
    with self.assertRaises(ValueError) as context:
      _TFRecordUtil.read_record(self._as_file_handle(record))
    self.assertIn(error_text, context.exception.message)

  def test_masked_crc32c(self):
    # Golden masked-crc values (default crc32c implementation).
    self.assertEqual(0xfd7fffa, _TFRecordUtil._masked_crc32c('\x00' * 32))
    self.assertEqual(0xf909b029, _TFRecordUtil._masked_crc32c('\xff' * 32))
    self.assertEqual(0xfebe8a61, _TFRecordUtil._masked_crc32c('foo'))
    self.assertEqual(
        0xe4999b0,
        _TFRecordUtil._masked_crc32c('\x03\x00\x00\x00\x00\x00\x00\x00'))

  def test_masked_crc32c_crcmod(self):
    # Same golden values, explicitly forcing the crcmod fallback.
    crc32c_fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c')
    self.assertEqual(
        0xfd7fffa,
        _TFRecordUtil._masked_crc32c(
            '\x00' * 32, crc32c_fn=crc32c_fn))
    self.assertEqual(
        0xf909b029,
        _TFRecordUtil._masked_crc32c(
            '\xff' * 32, crc32c_fn=crc32c_fn))
    self.assertEqual(
        0xfebe8a61, _TFRecordUtil._masked_crc32c(
            'foo', crc32c_fn=crc32c_fn))
    self.assertEqual(
        0xe4999b0,
        _TFRecordUtil._masked_crc32c(
            '\x03\x00\x00\x00\x00\x00\x00\x00', crc32c_fn=crc32c_fn))

  def test_write_record(self):
    file_handle = cStringIO.StringIO()
    _TFRecordUtil.write_record(file_handle, 'foo')
    self.assertEqual(self.record, file_handle.getvalue())

  def test_read_record(self):
    actual = _TFRecordUtil.read_record(self._as_file_handle(self.record))
    self.assertEqual('foo', actual)

  def test_read_record_invalid_record(self):
    self._test_error('bar', 'Not a valid TFRecord. Fewer than 12 bytes')

  def test_read_record_invalid_length_mask(self):
    # Index 9 is inside the 4-byte length crc.
    record = self._increment_value_at_index(self.record, 9)
    self._test_error(record, 'Mismatch of length mask')

  def test_read_record_invalid_data_mask(self):
    # Index 16 is inside the 4-byte data crc.
    record = self._increment_value_at_index(self.record, 16)
    self._test_error(record, 'Mismatch of data mask')

  def test_compatibility_read_write(self):
    # Round-trip: anything written by write_record is readable by read_record.
    for record in ['', 'blah', 'another blah']:
      file_handle = cStringIO.StringIO()
      _TFRecordUtil.write_record(file_handle, record)
      file_handle.reset()
      actual = _TFRecordUtil.read_record(file_handle)
      self.assertEqual(record, actual)


class TestTFRecordSink(unittest.TestCase):
  """Tests writing TFRecord files directly through the sink."""

  def _write_lines(self, sink, path, lines):
    f = sink.open(path)
    for l in lines:
      sink.write_record(f, l)
    sink.close(f)

  def test_write_record_single(self):
    path = tempfile.NamedTemporaryFile().name
    record = binascii.a2b_base64(FOO_RECORD_BASE64)
    sink = _TFRecordSink(
        path,
        coder=coders.BytesCoder(),
        file_name_suffix='',
        num_shards=0,
        shard_name_template=None,
        compression_type=fileio.CompressionTypes.UNCOMPRESSED)
    self._write_lines(sink, path, ['foo'])

    with open(path, 'r') as f:
      self.assertEqual(f.read(), record)

  def test_write_record_multiple(self):
    path = tempfile.NamedTemporaryFile().name
    record = binascii.a2b_base64(FOO_BAR_RECORD_BASE64)
    sink = _TFRecordSink(
        path,
        coder=coders.BytesCoder(),
        file_name_suffix='',
        num_shards=0,
        shard_name_template=None,
        compression_type=fileio.CompressionTypes.UNCOMPRESSED)
    self._write_lines(sink, path, ['foo', 'bar'])

    with open(path, 'r') as f:
      self.assertEqual(f.read(), record)


@unittest.skipIf(tf is None, 'tensorflow not installed.')
class TestWriteToTFRecord(TestTFRecordSink):
  """End-to-end write tests verified by reading back with tensorflow."""

  def test_write_record_gzip(self):
    with beam.Pipeline(DirectRunner()) as p:
      file_path_prefix = tempfile.NamedTemporaryFile().name
      input_data = ['foo', 'bar']
      _ = p | beam.Create(input_data) | WriteToTFRecord(
          file_path_prefix, compression_type=fileio.CompressionTypes.GZIP)

    actual = []
    file_name = glob.glob(file_path_prefix + '-*')[0]
    for r in tf.python_io.tf_record_iterator(
        file_name, options=tf.python_io.TFRecordOptions(
            tf.python_io.TFRecordCompressionType.GZIP)):
      actual.append(r)
    self.assertEqual(actual, input_data)

  def test_write_record_auto(self):
    # The '.gz' suffix should trigger gzip compression via AUTO detection.
    with beam.Pipeline(DirectRunner()) as p:
      file_path_prefix = tempfile.NamedTemporaryFile().name
      input_data = ['foo', 'bar']
      _ = p | beam.Create(input_data) | WriteToTFRecord(
          file_path_prefix, file_name_suffix='.gz')

    actual = []
    file_name = glob.glob(file_path_prefix + '-*.gz')[0]
    for r in tf.python_io.tf_record_iterator(
        file_name, options=tf.python_io.TFRecordOptions(
            tf.python_io.TFRecordCompressionType.GZIP)):
      actual.append(r)
    self.assertEqual(actual, input_data)


class TestTFRecordSource(unittest.TestCase):
  """Tests reading TFRecord files through the source."""

  def _write_file(self, path, base64_records):
    record = binascii.a2b_base64(base64_records)
    with open(path, 'wb') as f:
      f.write(record)

  def _write_file_gzip(self, path, base64_records):
    record = binascii.a2b_base64(base64_records)
    with gzip.GzipFile(path, 'wb') as f:
      f.write(record)

  def test_process_single(self):
    path = tempfile.NamedTemporaryFile().name
    self._write_file(path, FOO_RECORD_BASE64)
    with beam.Pipeline(DirectRunner()) as p:
      result = (p
                | beam.Read(
                    _TFRecordSource(
                        path,
                        coder=coders.BytesCoder(),
                        compression_type=fileio.CompressionTypes.AUTO)))
      beam.assert_that(result, beam.equal_to(['foo']))

  def test_process_multiple(self):
    path = tempfile.NamedTemporaryFile().name
    self._write_file(path, FOO_BAR_RECORD_BASE64)
    with beam.Pipeline(DirectRunner()) as p:
      result = (p
                | beam.Read(
                    _TFRecordSource(
                        path,
                        coder=coders.BytesCoder(),
                        compression_type=fileio.CompressionTypes.AUTO)))
      beam.assert_that(result, beam.equal_to(['foo', 'bar']))

  def test_process_gzip(self):
    path = tempfile.NamedTemporaryFile().name
    self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
    with beam.Pipeline(DirectRunner()) as p:
      result = (p
                | beam.Read(
                    _TFRecordSource(
                        path,
                        coder=coders.BytesCoder(),
                        compression_type=fileio.CompressionTypes.GZIP)))
      beam.assert_that(result, beam.equal_to(['foo', 'bar']))

  def test_process_auto(self):
    # mkstemp keeps the '.gz' suffix so AUTO compression detection applies.
    path = tempfile.mkstemp(suffix='.gz')[1]
    self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
    with beam.Pipeline(DirectRunner()) as p:
      result = (p
                | beam.Read(
                    _TFRecordSource(
                        path,
                        coder=coders.BytesCoder(),
                        compression_type=fileio.CompressionTypes.AUTO)))
      beam.assert_that(result, beam.equal_to(['foo', 'bar']))


class TestReadFromTFRecordSource(TestTFRecordSource):
  """Reuses the source tests and adds ReadFromTFRecord transform variants."""

  def test_process_gzip(self):
    path = tempfile.NamedTemporaryFile().name
    self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
    with beam.Pipeline(DirectRunner()) as p:
      result = (p
                | ReadFromTFRecord(
                    path, compression_type=fileio.CompressionTypes.GZIP))
      beam.assert_that(result, beam.equal_to(['foo', 'bar']))

  def test_process_gzip_auto(self):
    path = tempfile.mkstemp(suffix='.gz')[1]
    self._write_file_gzip(path, FOO_BAR_RECORD_BASE64)
    with beam.Pipeline(DirectRunner()) as p:
      result = (p
                | ReadFromTFRecord(
                    path, compression_type=fileio.CompressionTypes.AUTO))
      beam.assert_that(result, beam.equal_to(['foo', 'bar']))


class TestEnd2EndWriteAndRead(unittest.TestCase):
  """Round-trip tests: write with WriteToTFRecord, read with ReadFromTFRecord."""

  def create_inputs(self):
    # Produce a binary (pickled) payload to exercise arbitrary record bytes.
    input_array = [[random.random() - 0.5 for _ in xrange(15)]
                   for _ in xrange(12)]
    memfile = cStringIO.StringIO()
    pickle.dump(input_array, memfile)
    return memfile.getvalue()

  def test_end2end(self):
    file_path_prefix = tempfile.NamedTemporaryFile().name

    # Generate a TFRecord file.
    with beam.Pipeline(DirectRunner()) as p:
      expected_data = [self.create_inputs() for _ in range(0, 10)]
      _ = p | beam.Create(expected_data) | WriteToTFRecord(file_path_prefix)

    # Read the file back and compare.
    with beam.Pipeline(DirectRunner()) as p:
      actual_data = p | ReadFromTFRecord(file_path_prefix + '-*')
      beam.assert_that(actual_data, beam.equal_to(expected_data))

  def test_end2end_auto_compression(self):
    file_path_prefix = tempfile.NamedTemporaryFile().name

    # Generate a TFRecord file.
    with beam.Pipeline(DirectRunner()) as p:
      expected_data = [self.create_inputs() for _ in range(0, 10)]
      _ = p | beam.Create(expected_data) | WriteToTFRecord(
          file_path_prefix, file_name_suffix='.gz')

    # Read the file back and compare.
    with beam.Pipeline(DirectRunner()) as p:
      actual_data = p | ReadFromTFRecord(file_path_prefix + '-*')
      beam.assert_that(actual_data, beam.equal_to(expected_data))

  def test_end2end_auto_compression_unsharded(self):
    file_path_prefix = tempfile.NamedTemporaryFile().name

    # Generate a TFRecord file.
    with beam.Pipeline(DirectRunner()) as p:
      expected_data = [self.create_inputs() for _ in range(0, 10)]
      _ = p | beam.Create(expected_data) | WriteToTFRecord(
          file_path_prefix + '.gz', shard_name_template='')

    # Read the file back and compare.
    with beam.Pipeline(DirectRunner()) as p:
      actual_data = p | ReadFromTFRecord(file_path_prefix + '.gz')
      beam.assert_that(actual_data, beam.equal_to(expected_data))

  @unittest.skipIf(tf is None, 'tensorflow not installed.')
  def test_end2end_example_proto(self):
    file_path_prefix = tempfile.NamedTemporaryFile().name

    example = tf.train.Example()
    example.features.feature['int'].int64_list.value.extend(range(3))
    example.features.feature['bytes'].bytes_list.value.extend(
        [b'foo', b'bar'])

    with beam.Pipeline(DirectRunner()) as p:
      _ = p | beam.Create([example]) | WriteToTFRecord(
          file_path_prefix, coder=beam.coders.ProtoCoder(example.__class__))

    # Read the file back and compare.
    with beam.Pipeline(DirectRunner()) as p:
      actual_data = (p | ReadFromTFRecord(
          file_path_prefix + '-*',
          coder=beam.coders.ProtoCoder(example.__class__)))
      beam.assert_that(actual_data, beam.equal_to([example]))


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()

# http://git-wip-us.apache.org/repos/asf/beam/blob/88833ba5/sdks/python/setup.py
# diff --git a/sdks/python/setup.py b/sdks/python/setup.py
# index f6357b6..1fd622f 100644
# --- a/sdks/python/setup.py
# +++ b/sdks/python/setup.py
# @@ -85,6 +85,7 @@ else:
#  REQUIRED_PACKAGES = [
#      'avro>=1.7.7,<2.0.0',
# +    'crcmod>=1.7,<2.0',
#      'dill>=0.2.5,<0.3',
#      'google-apitools>=0.5.6,<1.0.0',
#      'googledatastore==6.4.1',
