This is an automated email from the ASF dual-hosted git repository.

anandinguva pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c004cc7fa14 Remove CoGBK in MLTransform's TFTProcessHandler (#30146)
c004cc7fa14 is described below

commit c004cc7fa1425d14ef2b0c134784f80a42e555a7
Author: Anand Inguva <[email protected]>
AuthorDate: Tue Feb 13 19:19:34 2024 +0000

    Remove CoGBK in MLTransform's TFTProcessHandler (#30146)
    
    * Add _Encode and _DecodeDict
    
    * Replace the CoGBK and utils with Encode and Decode utils
    
    * Use PicklerCoder for encoding and decoding elements
    
    * Remove comments
    
    * update coder
    
    * Address comments
    
    * Remove return comment
    
    * Update sdks/python/apache_beam/ml/transforms/handlers.py
    
    Co-authored-by: tvalentyn <[email protected]>
    
    * Make _DataCoder internal
    
    ---------
    
    Co-authored-by: tvalentyn <[email protected]>
---
 sdks/python/apache_beam/ml/transforms/handlers.py | 163 ++++++++--------------
 1 file changed, 57 insertions(+), 106 deletions(-)

diff --git a/sdks/python/apache_beam/ml/transforms/handlers.py b/sdks/python/apache_beam/ml/transforms/handlers.py
index 3c37ddef1ed..5bcd0d16576 100644
--- a/sdks/python/apache_beam/ml/transforms/handlers.py
+++ b/sdks/python/apache_beam/ml/transforms/handlers.py
@@ -17,9 +17,10 @@
 # pytype: skip-file
 
 import collections
+import copy
 import os
 import typing
-import uuid
+from typing import Any
 from typing import Dict
 from typing import List
 from typing import Optional
@@ -31,6 +32,7 @@ import numpy as np
 import apache_beam as beam
 import tensorflow as tf
 import tensorflow_transform.beam as tft_beam
+from apache_beam import coders
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.ml.transforms.base import ArtifactMode
 from apache_beam.ml.transforms.base import ProcessHandler
@@ -50,7 +52,7 @@ __all__ = [
     'TFTProcessHandler',
 ]
 
-_ID_COLUMN = 'tmp_uuid'  # Name for a temporary column.
+_TEMP_KEY = 'CODED_SAMPLE'  # key for the encoded sample
 
 RAW_DATA_METADATA_DIR = 'raw_data_metadata'
 SCHEMA_FILE = 'schema.pbtxt'
@@ -83,12 +85,41 @@ tft_process_handler_input_type = typing.Union[typing.NamedTuple,
 tft_process_handler_output_type = typing.Union[beam.Row, Dict[str, np.ndarray]]
 
 
+class _DataCoder:
+  def __init__(
+      self,
+      exclude_columns,
+      coder=coders.registry.get_coder(Any),
+  ):
+    """
+    Encodes/decodes items of a dictionary into a single element.
+    Args:
+      exclude_columns: list of columns to exclude from the encoding.
+    """
+    self.coder = coder
+    self.exclude_columns = exclude_columns
+
+  def encode(self, element):
+    data_to_encode = element.copy()
+    element_to_return = element.copy()
+    for key in self.exclude_columns:
+      if key in data_to_encode:
+        del data_to_encode[key]
+    element_to_return[_TEMP_KEY] = self.coder.encode(data_to_encode)
+    return element_to_return
+
+  def decode(self, element):
+    clone = copy.copy(element)
+    clone.update(self.coder.decode(clone[_TEMP_KEY].item()))
+    del clone[_TEMP_KEY]
+    return clone
+
+
 class _ConvertScalarValuesToListValues(beam.DoFn):
   def process(
       self,
       element,
   ):
-    id, element = element
     new_dict = {}
     for key, value in element.items():
       if isinstance(value,
@@ -96,7 +127,7 @@ class _ConvertScalarValuesToListValues(beam.DoFn):
         new_dict[key] = [value]
       else:
         new_dict[key] = value
-    yield (id, new_dict)
+    yield new_dict
 
 
 class _ConvertNamedTupleToDict(
@@ -124,79 +155,6 @@ class _ConvertNamedTupleToDict(
       return pcoll | beam.Map(lambda x: x._asdict())
 
 
-class _ComputeAndAttachUniqueID(beam.DoFn):
-  """
-  Computes and attaches a unique id to each element in the PCollection.
-  """
-  def process(self, element):
-    # UUID1 includes machine-specific bits and has a counter. As long as not too
-    # many are generated at the same time, they should be unique.
-    # UUID4 generation should be unique in practice as long as underlying random
-    # number generation is not compromised.
-    # A combintation of both should avoid the anecdotal pitfalls where
-    # replacing one with the other has helped some users.
-    # UUID collision will result in data loss, but we can detect that and fail.
-
-    # TODO(https://github.com/apache/beam/issues/29593): Evaluate MLTransform
-    # implementation without CoGBK.
-    unique_key = uuid.uuid1().bytes + uuid.uuid4().bytes
-    yield (unique_key, element)
-
-
-class _GetMissingColumns(beam.DoFn):
-  """
-  Returns data containing only the columns that are not
-  present in the schema. This is needed since TFT only outputs
-  columns that are transformed by any of the data processing transforms.
-  """
-  def __init__(self, existing_columns):
-    self.existing_columns = existing_columns
-
-  def process(self, element):
-    id, row_dict = element
-    new_dict = {
-        k: v
-        for k, v in row_dict.items() if k not in self.existing_columns
-    }
-    yield (id, new_dict)
-
-
-class _MakeIdAsColumn(beam.DoFn):
-  """
-  Extracts the id from the element and adds it as a column instead.
-  """
-  def process(self, element):
-    id, element = element
-    element[_ID_COLUMN] = id
-    yield element
-
-
-class _ExtractIdAndKeyPColl(beam.DoFn):
-  """
-  Extracts the id and return id and element as a tuple.
-  """
-  def process(self, element):
-    id = element[_ID_COLUMN][0]
-    del element[_ID_COLUMN]
-    yield (id, element)
-
-
-class _MergeDicts(beam.DoFn):
-  """
-  Merges processed and unprocessed columns from CoGBK result into a single row.
-  """
-  def process(self, element):
-    unused_row_id, row_dicts_tuple = element
-    new_dict = {}
-    for d in row_dicts_tuple:
-      # After CoGBK, dicts with processed and unprocessed portions of each row
-      # are wrapped in 1-element lists, since all rows have a unique id.
-      # Assertion could fail due to UUID collision.
-      assert len(d) == 1, f"Expected 1 element, got: {len(d)}."
-      new_dict.update(d[0])
-    yield new_dict
-
-
 class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
                                        tft_process_handler_output_type]):
   def __init__(
@@ -325,7 +283,7 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
   def get_raw_data_metadata(
       self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata:
     raw_data_feature_spec = self.get_raw_data_feature_spec(input_types)
-    raw_data_feature_spec[_ID_COLUMN] = tf.io.VarLenFeature(dtype=tf.string)
+    raw_data_feature_spec[_TEMP_KEY] = tf.io.VarLenFeature(dtype=tf.string)
     return self.convert_raw_data_feature_spec_to_dataset_metadata(
         raw_data_feature_spec)
 
@@ -403,7 +361,6 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
     artifact_location, which was previously used to store the produced
     artifacts.
     """
-
     if self.artifact_mode == ArtifactMode.PRODUCE:
       # If we are computing artifacts, we should fail for windows other than
       # default windowing since for example, for a fixed window, each window can
@@ -447,24 +404,29 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
       raw_data_metadata = metadata_io.read_metadata(
           os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
 
-    keyed_raw_data = (raw_data | beam.ParDo(_ComputeAndAttachUniqueID()))
-
     feature_set = [feature.name for feature in raw_data_metadata.schema.feature]
-    keyed_columns_not_in_schema = (
-        keyed_raw_data
-        | beam.ParDo(_GetMissingColumns(feature_set)))
+
+    # TFT ignores columns in the input data that aren't explicitly defined
+    # in the schema. This is because TFT operations
+    # are designed to work with a predetermined schema.
+    # To preserve these extra columns without disrupting TFT processing,
+    # they are temporarily encoded as bytes and added to the PCollection with
+    # a unique identifier
+    data_coder = _DataCoder(exclude_columns=feature_set)
+    data_with_encoded_columns = (
+        raw_data
+        | "EncodeUnmodifiedColumns" >>
+        beam.Map(lambda elem: data_coder.encode(elem)))
 
     # To maintain consistency by outputting numpy array all the time,
     # whether a scalar value or list or np array is passed as input,
-    #  we will convert scalar values to list values and TFT will ouput
+    # we will convert scalar values to list values and TFT will ouput
     # numpy array all the time.
-    keyed_raw_data = keyed_raw_data | beam.ParDo(
+    data_list = data_with_encoded_columns | beam.ParDo(
         _ConvertScalarValuesToListValues())
 
-    raw_data_list = (keyed_raw_data | beam.ParDo(_MakeIdAsColumn()))
-
     with tft_beam.Context(temp_dir=self.artifact_location):
-      data = (raw_data_list, raw_data_metadata)
+      data = (data_list, raw_data_metadata)
       if self.artifact_mode == ArtifactMode.PRODUCE:
         transform_fn = (
             data
@@ -474,7 +436,7 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
         self.write_transform_artifacts(transform_fn, self.artifact_location)
       else:
         transform_fn = (
-            raw_data_list.pipeline
+            data_list.pipeline
             | "ReadTransformFn" >> tft_beam.ReadTransformFn(
                 self.artifact_location))
       (transformed_dataset, transformed_metadata) = (
@@ -492,26 +454,15 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
       # So we will use a RowTypeConstraint to create a schema'd PCollection.
       # this is needed since new columns are included in the
       # transformed_dataset.
-      del self.transformed_schema[_ID_COLUMN]
+      del self.transformed_schema[_TEMP_KEY]
       row_type = RowTypeConstraint.from_fields(
           list(self.transformed_schema.items()))
 
-      # If a non schema PCollection is passed, and one of the input columns
-      # is not transformed by any of the transforms, then the output will
-      # not have that column. So we will join the missing columns from the
-      # raw_data to the transformed_dataset.
-      keyed_transformed_dataset = (
-          transformed_dataset | beam.ParDo(_ExtractIdAndKeyPColl()))
-
-      # The grouping is needed here since tensorflow transform only outputs
-      # columns that are transformed by any of the transforms. So we will
-      # join the missing columns from the raw_data to the transformed_dataset
-      # using the id.
+      # Decode the extra columns that were encoded as bytes.
       transformed_dataset = (
-          (keyed_transformed_dataset, keyed_columns_not_in_schema)
-          | beam.CoGroupByKey()
-          | beam.ParDo(_MergeDicts()))
-
+          transformed_dataset
+          |
+          "DecodeUnmodifiedColumns" >> beam.Map(lambda x: data_coder.decode(x)))
       # The schema only contains the columns that are transformed.
       transformed_dataset = (
           transformed_dataset

Reply via email to