This is an automated email from the ASF dual-hosted git repository.

altay pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 77b813a6e6e [BEAM-12554] Create new instances of FileSink in sink_fn (#17708)
77b813a6e6e is described below

commit 77b813a6e6e37d5557d57fb1d5e71353bfce6ddc
Author: Yi Hu <[email protected]>
AuthorDate: Tue Jun 7 12:10:39 2022 -0400

    [BEAM-12554] Create new instances of FileSink in sink_fn (#17708)
    
    * [BEAM-12554] Create new instances of FileSink in sink_fn
    
    * add unit test for WriteToFiles dynamic destination
    
    * add test to both type signature and type instance as sink param
---
 sdks/python/apache_beam/io/fileio.py      | 12 ++++++-----
 sdks/python/apache_beam/io/fileio_test.py | 36 +++++++++++++++++++++++++++++++
 2 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py
index 6b88600a97b..d1839e9de0f 100644
--- a/sdks/python/apache_beam/io/fileio.py
+++ b/sdks/python/apache_beam/io/fileio.py
@@ -504,7 +504,8 @@ class WriteToFiles(beam.PTransform):
         given their final names. By default, the temporary directory will be
          within the temp_location of your pipeline.
       sink (callable, FileSink): The sink to use to write into a file. It should
-        implement the methods of a ``FileSink``. If none is provided, a
+        implement the methods of a ``FileSink``. Pass a class signature or an
+        instance of FileSink to this parameter. If none is provided, a
         ``TextSink`` is used.
       shards (int): The number of shards per destination and trigger firing.
       max_writers_per_bundle (int): The number of writers that can be open
@@ -525,8 +526,11 @@ class WriteToFiles(beam.PTransform):
   @staticmethod
   def _get_sink_fn(input_sink):
     # type: (...) -> Callable[[Any], FileSink]
-    if isinstance(input_sink, FileSink):
-      return lambda x: input_sink
+    if isinstance(input_sink, type) and issubclass(input_sink, FileSink):
+      return lambda x: input_sink()
+    elif isinstance(input_sink, FileSink):
+      kls = input_sink.__class__
+      return lambda x: kls()
     elif callable(input_sink):
       return input_sink
     else:
@@ -791,7 +795,6 @@ class _WriteUnshardedRecordsFn(beam.DoFn):
   def _get_or_create_writer_and_sink(self, destination, window):
     """Returns a tuple of writer, sink."""
     writer_key = (destination, window)
-
     if writer_key in self._writers_and_sinks:
       return self._writers_and_sinks.get(writer_key)
     elif len(self._writers_and_sinks) >= self.max_num_writers_per_bundle:
@@ -807,7 +810,6 @@ class _WriteUnshardedRecordsFn(beam.DoFn):
           create_metadata_fn=sink.create_metadata)
 
       sink.open(writer)
-
       self._writers_and_sinks[writer_key] = (writer, sink)
       self._file_names[writer_key] = full_file_name
       return self._writers_and_sinks[writer_key]
diff --git a/sdks/python/apache_beam/io/fileio_test.py b/sdks/python/apache_beam/io/fileio_test.py
index f21fb8d1796..ab4dba2366c 100644
--- a/sdks/python/apache_beam/io/fileio_test.py
+++ b/sdks/python/apache_beam/io/fileio_test.py
@@ -459,6 +459,42 @@ class WriteFilesTest(_TestCaseWithTempDirCleanUp):
 
       assert_that(result, equal_to([row for row in self.SIMPLE_COLLECTION]))
 
+  def test_write_to_dynamic_destination(self):
+
+    sink_params = [
+        fileio.TextSink, # pass a type signature
+        fileio.TextSink() # pass a FileSink object
+    ]
+
+    for sink in sink_params:
+      dir = self._new_tempdir()
+
+      with TestPipeline() as p:
+        _ = (
+            p
+            | "Create" >> beam.Create(range(100))
+            | beam.Map(lambda x: str(x))
+            | fileio.WriteToFiles(
+                path=dir,
+                destination=lambda n: "odd" if int(n) % 2 else "even",
+                sink=sink,
+                file_naming=fileio.destination_prefix_naming("test")))
+
+      with TestPipeline() as p:
+        result = (
+            p
+            | fileio.MatchFiles(FileSystems.join(dir, '*'))
+            | fileio.ReadMatches()
+            | beam.Map(
+                lambda f: (
+                    os.path.basename(f.metadata.path).split('-')[0],
+                    sorted(map(int, f.read_utf8().strip().split('\n'))))))
+
+        assert_that(
+            result,
+            equal_to([('odd', list(range(1, 100, 2))),
+                      ('even', list(range(0, 100, 2)))]))
+
   def test_write_to_different_file_types_some_spilling(self):
 
     dir = self._new_tempdir()

Reply via email to