This is an automated email from the ASF dual-hosted git repository.
tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 6a573e4 BEAM-13189 Python TextIO: add escapechar feature. (#15901)
6a573e4 is described below
commit 6a573e431a2b4e69fdd6a861c6f54517bbfa3175
Author: Eugene Nikolaiev <[email protected]>
AuthorDate: Thu Nov 11 10:32:33 2021 +0200
BEAM-13189 Python TextIO: add escapechar feature. (#15901)
---
CHANGES.md | 1 +
sdks/python/apache_beam/io/textio.py | 70 ++++++++++++---
sdks/python/apache_beam/io/textio_test.py | 139 +++++++++++++++++++++++++++++-
3 files changed, 198 insertions(+), 12 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index eab4aec..a25c1e8 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -68,6 +68,7 @@
* X feature added (Java/Python)
([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
* Add custom delimiters to Python TextIO reads
([BEAM-12730](https://issues.apache.org/jira/browse/BEAM-12730)).
+* Add escapechar parameter to Python TextIO reads
([BEAM-13189](https://issues.apache.org/jira/browse/BEAM-13189)).
* Splittable reading is enabled by default while reading data with ParquetIO
([BEAM-12070](https://issues.apache.org/jira/browse/BEAM-12070)).
* DoFn Execution Time metrics added to Go
([BEAM-13001](https://issues.apache.org/jira/browse/BEAM-13001)).
* Cross-bundle side input caching is now available in the Go SDK for runners
that support the feature by setting the EnableSideInputCache hook
([BEAM-11097](https://issues.apache.org/jira/browse/BEAM-11097)).
diff --git a/sdks/python/apache_beam/io/textio.py
b/sdks/python/apache_beam/io/textio.py
index 7f9ea6e..f53f9b3 100644
--- a/sdks/python/apache_beam/io/textio.py
+++ b/sdks/python/apache_beam/io/textio.py
@@ -100,7 +100,8 @@ class _TextSource(filebasedsource.FileBasedSource):
validate=True,
skip_header_lines=0,
header_processor_fns=(None, None),
- delimiter=None):
+ delimiter=None,
+ escapechar=None):
"""Initialize a _TextSource
Args:
@@ -116,6 +117,8 @@ class _TextSource(filebasedsource.FileBasedSource):
delimiter (bytes) Optional: delimiter to split records.
Must not self-overlap, because self-overlapping delimiters cause
ambiguous parsing.
+ escapechar (bytes) Optional: a single byte used to escape the record
+ delimiter; it can also escape itself.
Raises:
ValueError: if skip_lines is negative.
@@ -147,6 +150,11 @@ class _TextSource(filebasedsource.FileBasedSource):
if self._is_self_overlapping(delimiter):
raise ValueError('Delimiter must not self-overlap.')
self._delimiter = delimiter
+ if escapechar is not None:
+ if not (isinstance(escapechar, bytes) and len(escapechar) == 1):
+ raise ValueError(
+ "escapechar must be bytes of size 1: '%s'" % escapechar)
+ self._escapechar = escapechar
def display_data(self):
parent_dd = super().display_data()
@@ -176,7 +184,7 @@ class _TextSource(filebasedsource.FileBasedSource):
start_offset = max(start_offset, position_after_processing_header_lines)
if start_offset > position_after_processing_header_lines:
# Seeking to one delimiter length before the start index and ignoring
- # the current line. If start_position is at beginning if the line, that
+ # the current line. If start_position is at beginning of the line, that
# line belongs to the current bundle, hence ignoring that is incorrect.
# Seeking to one delimiter before prevents that.
@@ -185,6 +193,16 @@ class _TextSource(filebasedsource.FileBasedSource):
else:
required_position = start_offset - 1
+ if self._escapechar is not None:
+ # Need more bytes to check if the delimiter is escaped.
+ # Seek until the first escapechar if any.
+ while required_position > 0:
+ file_to_read.seek(required_position - 1)
+ if file_to_read.read(1) == self._escapechar:
+ required_position -= 1
+ else:
+ break
+
file_to_read.seek(required_position)
read_buffer.reset()
sep_bounds = self._find_separator_bounds(file_to_read, read_buffer)
@@ -277,11 +295,22 @@ class _TextSource(filebasedsource.FileBasedSource):
if next_delim >= 0:
if (self._delimiter is None and
read_buffer.data[next_delim - 1:next_delim] == b'\r'):
- # Accept both '\r\n' and '\n' as a default delimiter.
- return (next_delim - 1, next_delim + 1)
+ if self._escapechar is not None and self._is_escaped(read_buffer,
+ next_delim - 1):
+ # Accept '\n' as a default delimiter, because '\r' is escaped.
+ return (next_delim, next_delim + 1)
+ else:
+ # Accept both '\r\n' and '\n' as a default delimiter.
+ return (next_delim - 1, next_delim + 1)
else:
- # Found a delimiter. Accepting that as the next delimiter.
- return (next_delim, next_delim + delimiter_len)
+ if self._escapechar is not None and self._is_escaped(read_buffer,
+ next_delim):
+ # Skip an escaped delimiter. The next delimiter may start right after
+ # it (e.g. the second '\n' in b'\\\n\n' is a real delimiter), so the
+ # search must resume exactly at the end of the escaped one.
+ current_pos = next_delim + delimiter_len
+ continue
+ else:
+ # Found a delimiter. Accepting that as the next delimiter.
+ return (next_delim, next_delim + delimiter_len)
elif self._delimiter is not None:
# Corner case: custom delimiter is truncated at the end of the buffer.
@@ -362,6 +391,17 @@ class _TextSource(filebasedsource.FileBasedSource):
return True
return False
+ def _is_escaped(self, read_buffer, position):
+ # Returns True if the byte at `position` in read_buffer.data is preceded
+ # by an odd-length run of escapechar bytes (i.e. it is escaped), and False
+ # if the run is empty or of even length (the escapes escape each other).
+ # Only read_buffer.data is inspected: the backwards scan stops at index 0.
+ escape_count = 0
+ for current_pos in reversed(range(0, position)):
+ if read_buffer.data[current_pos:current_pos + 1] != self._escapechar:
+ break
+ escape_count += 1
+ return escape_count % 2 == 1
+
class _TextSourceWithFilename(_TextSource):
def read_records(self, file_name, range_tracker):
@@ -467,7 +507,8 @@ def _create_text_source(
strip_trailing_newlines=None,
coder=None,
skip_header_lines=None,
- delimiter=None):
+ delimiter=None,
+ escapechar=None):
return _TextSource(
file_pattern=file_pattern,
min_bundle_size=min_bundle_size,
@@ -476,7 +517,8 @@ def _create_text_source(
coder=coder,
validate=False,
skip_header_lines=skip_header_lines,
- delimiter=delimiter)
+ delimiter=delimiter,
+ escapechar=escapechar)
class ReadAllFromText(PTransform):
@@ -508,6 +550,7 @@ class ReadAllFromText(PTransform):
skip_header_lines=0,
with_filename=False,
delimiter=None,
+ escapechar=None,
**kwargs):
"""Initialize the ``ReadAllFromText`` transform.
@@ -535,6 +578,8 @@ class ReadAllFromText(PTransform):
delimiter (bytes) Optional: delimiter to split records.
Must not self-overlap, because self-overlapping delimiters cause
ambiguous parsing.
+ escapechar (bytes) Optional: a single byte used to escape the record
+ delimiter; it can also escape itself.
"""
super().__init__(**kwargs)
source_from_file = partial(
@@ -544,7 +589,8 @@ class ReadAllFromText(PTransform):
strip_trailing_newlines=strip_trailing_newlines,
coder=coder,
skip_header_lines=skip_header_lines,
- delimiter=delimiter)
+ delimiter=delimiter,
+ escapechar=escapechar)
self._desired_bundle_size = desired_bundle_size
self._min_bundle_size = min_bundle_size
self._compression_type = compression_type
@@ -585,6 +631,7 @@ class ReadFromText(PTransform):
validate=True,
skip_header_lines=0,
delimiter=None,
+ escapechar=None,
**kwargs):
"""Initialize the :class:`ReadFromText` transform.
@@ -611,6 +658,8 @@ class ReadFromText(PTransform):
delimiter (bytes) Optional: delimiter to split records.
Must not self-overlap, because self-overlapping delimiters cause
ambiguous parsing.
+ escapechar (bytes) Optional: a single byte used to escape the record
+ delimiter; it can also escape itself.
"""
super().__init__(**kwargs)
@@ -622,7 +671,8 @@ class ReadFromText(PTransform):
coder,
validate=validate,
skip_header_lines=skip_header_lines,
- delimiter=delimiter)
+ delimiter=delimiter,
+ escapechar=escapechar)
def expand(self, pvalue):
return pvalue.pipeline | Read(self._source)
diff --git a/sdks/python/apache_beam/io/textio_test.py
b/sdks/python/apache_beam/io/textio_test.py
index ae53234..f6e0dfb 100644
--- a/sdks/python/apache_beam/io/textio_test.py
+++ b/sdks/python/apache_beam/io/textio_test.py
@@ -163,14 +163,16 @@ class TextSourceTest(unittest.TestCase):
expected_data,
buffer_size=DEFAULT_NUM_RECORDS,
compression=CompressionTypes.UNCOMPRESSED,
- delimiter=None):
+ delimiter=None,
+ escapechar=None):
# Since each record usually takes more than 1 byte, default buffer size is
# smaller than the total size of the file. This is done to
# increase test coverage for cases that hit the buffer boundary.
-
kwargs = {}
if delimiter:
kwargs['delimiter'] = delimiter
+ if escapechar:
+ kwargs['escapechar'] = escapechar
source = TextSource(
file_or_pattern,
0,
@@ -1228,6 +1230,139 @@ class TextSourceTest(unittest.TestCase):
assert len(expected_data) == 3
self._run_read_test(file_name, expected_data, buffer_size=6)
+ def test_read_escaped_lf(self):
+ file_name, expected_data = write_data(
+ self.DEFAULT_NUM_RECORDS, eol=EOL.LF, line_value=b'li\\\nne')
+ assert len(expected_data) == self.DEFAULT_NUM_RECORDS
+ self._run_read_test(file_name, expected_data, escapechar=b'\\')
+
+ def test_read_escaped_crlf(self):
+ file_name, expected_data = write_data(
+ TextSource.DEFAULT_READ_BUFFER_SIZE,
+ eol=EOL.CRLF,
+ line_value=b'li\\\r\\\nne')
+ assert len(expected_data) == TextSource.DEFAULT_READ_BUFFER_SIZE
+ self._run_read_test(file_name, expected_data, escapechar=b'\\')
+
+ def test_read_escaped_cr_before_not_escaped_lf(self):
+ file_name, expected_data_temp = write_data(
+ self.DEFAULT_NUM_RECORDS, eol=EOL.CRLF, line_value=b'li\\\r\nne')
+ expected_data = []
+ for line in expected_data_temp:
+ expected_data += line.split("\n")
+ assert len(expected_data) == self.DEFAULT_NUM_RECORDS * 2
+ self._run_read_test(file_name, expected_data, escapechar=b'\\')
+
+ def test_read_escaped_custom_delimiter_crlf(self):
+ file_name, expected_data = write_data(
+ self.DEFAULT_NUM_RECORDS, eol=EOL.CRLF, line_value=b'li\\\r\nne')
+ assert len(expected_data) == self.DEFAULT_NUM_RECORDS
+ self._run_read_test(
+ file_name, expected_data, delimiter=b'\r\n', escapechar=b'\\')
+
+ def test_read_escaped_custom_delimiter(self):
+ file_name, expected_data = write_data(
+ TextSource.DEFAULT_READ_BUFFER_SIZE,
+ eol=EOL.CUSTOM_DELIMITER,
+ custom_delimiter=b'*|',
+ line_value=b'li\\*|ne')
+ assert len(expected_data) == TextSource.DEFAULT_READ_BUFFER_SIZE
+ self._run_read_test(
+ file_name, expected_data, delimiter=b'*|', escapechar=b'\\')
+
+ def test_read_escaped_lf_at_buffer_edge(self):
+ file_name, expected_data = write_data(3, eol=EOL.LF,
line_value=b'line\\\n')
+ assert len(expected_data) == 3
+ self._run_read_test(
+ file_name, expected_data, buffer_size=5, escapechar=b'\\')
+
+ def test_read_escaped_crlf_split_by_buffer(self):
+ file_name, expected_data = write_data(
+ 3, eol=EOL.CRLF, line_value=b'line\\\r\n')
+ assert len(expected_data) == 3
+ self._run_read_test(
+ file_name,
+ expected_data,
+ buffer_size=6,
+ delimiter=b'\r\n',
+ escapechar=b'\\')
+
+ def test_read_escaped_lf_after_splitting(self):
+ file_name, expected_data = write_data(3, line_value=b'line\\\n')
+ assert len(expected_data) == 3
+ source = TextSource(
+ file_name,
+ 0,
+ CompressionTypes.UNCOMPRESSED,
+ True,
+ coders.StrUtf8Coder(),
+ escapechar=b'\\')
+ splits = list(source.split(desired_bundle_size=6))
+
+ reference_source_info = (source, None, None)
+ sources_info = ([(split.source, split.start_position, split.stop_position)
+ for split in splits])
+ source_test_utils.assert_sources_equal_reference_source(
+ reference_source_info, sources_info)
+
+ def test_read_escaped_lf_after_splitting_many(self):
+ file_name, expected_data = write_data(
+ 3, line_value=b'\\\\\\\\\\\n') # 5 escapes
+ assert len(expected_data) == 3
+ source = TextSource(
+ file_name,
+ 0,
+ CompressionTypes.UNCOMPRESSED,
+ True,
+ coders.StrUtf8Coder(),
+ escapechar=b'\\')
+ splits = list(source.split(desired_bundle_size=6))
+
+ reference_source_info = (source, None, None)
+ sources_info = ([(split.source, split.start_position, split.stop_position)
+ for split in splits])
+ source_test_utils.assert_sources_equal_reference_source(
+ reference_source_info, sources_info)
+
+ def test_read_escaped_escapechar_after_splitting(self):
+ file_name, expected_data = write_data(3, line_value=b'line\\\\*|')
+ assert len(expected_data) == 3
+ source = TextSource(
+ file_name,
+ 0,
+ CompressionTypes.UNCOMPRESSED,
+ True,
+ coders.StrUtf8Coder(),
+ delimiter=b'*|',
+ escapechar=b'\\')
+ splits = list(source.split(desired_bundle_size=8))
+
+ reference_source_info = (source, None, None)
+ sources_info = ([(split.source, split.start_position, split.stop_position)
+ for split in splits])
+ source_test_utils.assert_sources_equal_reference_source(
+ reference_source_info, sources_info)
+
+ def test_read_escaped_escapechar_after_splitting_many(self):
+ file_name, expected_data = write_data(
+ 3, line_value=b'\\\\\\\\\\\\*|') # 6 escapes
+ assert len(expected_data) == 3
+ source = TextSource(
+ file_name,
+ 0,
+ CompressionTypes.UNCOMPRESSED,
+ True,
+ coders.StrUtf8Coder(),
+ delimiter=b'*|',
+ escapechar=b'\\')
+ splits = list(source.split(desired_bundle_size=8))
+
+ reference_source_info = (source, None, None)
+ sources_info = ([(split.source, split.start_position, split.stop_position)
+ for split in splits])
+ source_test_utils.assert_sources_equal_reference_source(
+ reference_source_info, sources_info)
+
class TextSinkTest(unittest.TestCase):
def setUp(self):