peridotml commented on code in PR #23030:
URL: https://github.com/apache/beam/pull/23030#discussion_r967867516


##########
sdks/python/apache_beam/io/parquetio.py:
##########
@@ -83,6 +86,67 @@ def process(self, table, with_filename=False):
         yield row
 
 
+class _RowDictionariesToArrowTable(DoFn):
+  """ A DoFn that consumes python dictionarys and yields a pyarrow table."""
+  def __init__(
+      self,
+      schema,
+      row_group_buffer_size=64 * 1024 * 1024,
+      record_batch_size=1000):
+    self._schema = schema
+    self._row_group_buffer_size = row_group_buffer_size
+    self._buffer = [[] for _ in range(len(schema.names))]
+    self._buffer_size = record_batch_size
+    self._record_batches = []
+    self._record_batches_byte_size = 0
+
+  def process(self, row):
+    if len(self._buffer[0]) >= self._buffer_size:
+      self._flush_buffer()
+
+    if self._record_batches_byte_size >= self._row_group_buffer_size:
+      table = self._create_table()
+      yield table
+
+    # reorder the data in columnar format.
+    for i, n in enumerate(self._schema.names):
+      self._buffer[i].append(row[n])
+
+  def finish_bundle(self):
+    if len(self._buffer[0]) > 0:
+      self._flush_buffer()
+    if self._record_batches_byte_size > 0:
+      table = self._create_table()
+      yield window.GlobalWindows.windowed_value_at_end_of_window(table)
+
+  def display_data(self):
+    res = super().display_data()
+    res['row_group_buffer_size'] = str(self._row_group_buffer_size)
+    res['buffer_size'] = str(self._buffer_size)
+
+    return res
+
+  def _create_table(self):
+    table = pa.Table.from_batches(self._record_batches, schema=self._schema)
+    self._record_batches = []
+    self._record_batches_byte_size = 0
+    return table
+
+  def _flush_buffer(self):
+    arrays = [[] for _ in range(len(self._schema.names))]
+    for x, y in enumerate(self._buffer):
+      arrays[x] = pa.array(y, type=self._schema.types[x])
+      self._buffer[x] = []
+    rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema)
+    self._record_batches.append(rb)
+    size = 0
+    for x in arrays:

Review Comment:
   Both `pa.RecordBatch` and `pa.Table` have an `nbytes` attribute. It could 
be simpler and more performant to use that attribute instead of computing the 
size manually in this loop. I couldn't find documentation confirming whether 
it was available before `v1` of pyarrow.



##########
sdks/python/apache_beam/io/parquetio.py:
##########
@@ -83,6 +86,67 @@ def process(self, table, with_filename=False):
         yield row
 
 
+class _RowDictionariesToArrowTable(DoFn):
+  """ A DoFn that consumes python dictionarys and yields a pyarrow table."""
+  def __init__(
+      self,
+      schema,
+      row_group_buffer_size=64 * 1024 * 1024,
+      record_batch_size=1000):
+    self._schema = schema
+    self._row_group_buffer_size = row_group_buffer_size
+    self._buffer = [[] for _ in range(len(schema.names))]
+    self._buffer_size = record_batch_size
+    self._record_batches = []
+    self._record_batches_byte_size = 0
+
+  def process(self, row):
+    if len(self._buffer[0]) >= self._buffer_size:
+      self._flush_buffer()
+
+    if self._record_batches_byte_size >= self._row_group_buffer_size:
+      table = self._create_table()
+      yield table
+
+    # reorder the data in columnar format.
+    for i, n in enumerate(self._schema.names):
+      self._buffer[i].append(row[n])
+
+  def finish_bundle(self):
+    if len(self._buffer[0]) > 0:
+      self._flush_buffer()
+    if self._record_batches_byte_size > 0:
+      table = self._create_table()
+      yield window.GlobalWindows.windowed_value_at_end_of_window(table)

Review Comment:
   It seems like we need to be careful with windows when buffering elements 
inside a DoFn. 
   
   I don't actually know if this works; I looked at `BatchElements` and 
copied part of it. I assumed you would let me know if there is an established 
strategy for this. I am not experienced with streaming.



##########
sdks/python/apache_beam/io/parquetio.py:
##########
@@ -83,6 +86,67 @@ def process(self, table, with_filename=False):
         yield row
 
 
+class _RowDictionariesToArrowTable(DoFn):
+  """ A DoFn that consumes python dictionarys and yields a pyarrow table."""
+  def __init__(
+      self,
+      schema,
+      row_group_buffer_size=64 * 1024 * 1024,
+      record_batch_size=1000):
+    self._schema = schema
+    self._row_group_buffer_size = row_group_buffer_size
+    self._buffer = [[] for _ in range(len(schema.names))]
+    self._buffer_size = record_batch_size
+    self._record_batches = []
+    self._record_batches_byte_size = 0
+
+  def process(self, row):
+    if len(self._buffer[0]) >= self._buffer_size:
+      self._flush_buffer()
+
+    if self._record_batches_byte_size >= self._row_group_buffer_size:

Review Comment:
   If we switched to using `nbytes`, this size check could be moved into the 
sink, if that makes more sense.



##########
sdks/python/apache_beam/io/parquetio.py:
##########
@@ -83,6 +86,67 @@ def process(self, table, with_filename=False):
         yield row
 
 
+class _RowDictionariesToArrowTable(DoFn):

Review Comment:
   I used `pa.Table` for consistency with reading batched Parquet files, but 
I wanted to point out that `pa.RecordBatch` would also work.
   
   I don't have a strong preference. I would use it either way. 



##########
sdks/python/apache_beam/io/parquetio.py:
##########
@@ -565,22 +665,9 @@ def open(self, temp_path):
         use_compliant_nested_type=self._use_compliant_nested_type)
 
   def write_record(self, writer, value):
-    if len(self._buffer[0]) >= self._buffer_size:
-      self._flush_buffer()
-
-    if self._record_batches_byte_size >= self._row_group_buffer_size:
-      self._write_batches(writer)
-
-    # reorder the data in columnar format.
-    for i, n in enumerate(self._schema.names):
-      self._buffer[i].append(value[n])
+    writer.write_table(value)

Review Comment:
   I have used this approach before, inferring the schema from the first 
batch that comes through. This simplified my model inference scripts and 
avoided loading TensorFlow saved models just to get the output signature and 
convert it to pyarrow types.
   
   It worked locally and on Dataflow. Thought I would throw it out there.



