tvalentyn commented on code in PR #30146:
URL: https://github.com/apache/beam/pull/30146#discussion_r1476889525


##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -447,22 +409,22 @@ def expand(
       raw_data_metadata = metadata_io.read_metadata(
           os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
 
-    keyed_raw_data = (raw_data | beam.ParDo(_ComputeAndAttachUniqueID()))
+    keyed_raw_data = (raw_data)  #  | beam.ParDo(_ComputeAndAttachUniqueID()))
 
     feature_set = [feature.name for feature in raw_data_metadata.schema.feature]
-    keyed_columns_not_in_schema = (
-        keyed_raw_data
-        | beam.ParDo(_GetMissingColumns(feature_set)))
+    self.data_coder.set_unused_columns(exclude_columns=feature_set)
 
     # To maintain consistency by outputting numpy array all the time,
     # whether a scalar value or list or np array is passed as input,
     # we will convert scalar values to list values and TFT will output
     # numpy array all the time.

Review Comment:
   Misplaced comment. It's also worth explaining, in a separate comment, why we do the encoding manipulation.



##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -215,6 +176,7 @@ def __init__(
     self.artifact_mode = artifact_mode
     if artifact_mode not in ['produce', 'consume']:
       raise ValueError('artifact_mode must be either `produce` or `consume`.')
+    self.data_coder = DataCoder()

Review Comment:
   Is there a reason we create it here rather than in `expand`?



##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -447,22 +409,22 @@ def expand(
       raw_data_metadata = metadata_io.read_metadata(
           os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
 
-    keyed_raw_data = (raw_data | beam.ParDo(_ComputeAndAttachUniqueID()))
+    keyed_raw_data = (raw_data)  #  | beam.ParDo(_ComputeAndAttachUniqueID()))
 
     feature_set = [feature.name for feature in raw_data_metadata.schema.feature]
-    keyed_columns_not_in_schema = (
-        keyed_raw_data
-        | beam.ParDo(_GetMissingColumns(feature_set)))
+    self.data_coder.set_unused_columns(exclude_columns=feature_set)
 
     # To maintain consistency by outputting numpy array all the time,
     # whether a scalar value or list or np array is passed as input,
     # we will convert scalar values to list values and TFT will output
     # numpy array all the time.
+    raw_data_list = (
+        keyed_raw_data
+        | beam.Map(lambda elem: self.data_coder.encode(elem)))
+
     keyed_raw_data = keyed_raw_data | beam.ParDo(

Review Comment:
   `keyed_raw_data` is not used after this line. Did you mean to use `raw_data_list`?



##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -447,22 +409,22 @@ def expand(
       raw_data_metadata = metadata_io.read_metadata(
           os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
 
-    keyed_raw_data = (raw_data | beam.ParDo(_ComputeAndAttachUniqueID()))
+    keyed_raw_data = (raw_data)  #  | beam.ParDo(_ComputeAndAttachUniqueID()))

Review Comment:
   Leftover comment. Also, we no longer add keys, so `keyed_` might not be the best name.



##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -447,22 +409,22 @@ def expand(
       raw_data_metadata = metadata_io.read_metadata(
           os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
 
-    keyed_raw_data = (raw_data | beam.ParDo(_ComputeAndAttachUniqueID()))
+    keyed_raw_data = (raw_data)  #  | beam.ParDo(_ComputeAndAttachUniqueID()))
 
     feature_set = [feature.name for feature in raw_data_metadata.schema.feature]
-    keyed_columns_not_in_schema = (
-        keyed_raw_data
-        | beam.ParDo(_GetMissingColumns(feature_set)))
+    self.data_coder.set_unused_columns(exclude_columns=feature_set)
 
     # To maintain consistency by outputting numpy array all the time,
     # whether a scalar value or list or np array is passed as input,
     # we will convert scalar values to list values and TFT will output
     # numpy array all the time.
+    raw_data_list = (

Review Comment:
   For my understanding: why is this called `raw_data_list`? It's modified, so I don't think it's raw, and what does `_list` refer to here?



##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -492,26 +454,13 @@ def expand(
       # So we will use a RowTypeConstraint to create a schema'd PCollection.
       # this is needed since new columns are included in the
       # transformed_dataset.
-      del self.transformed_schema[_ID_COLUMN]
+      del self.transformed_schema[_TEMP_KEY]
       row_type = RowTypeConstraint.from_fields(
           list(self.transformed_schema.items()))
 
-      # If a non schema PCollection is passed, and one of the input columns
-      # is not transformed by any of the transforms, then the output will
-      # not have that column. So we will join the missing columns from the
-      # raw_data to the transformed_dataset.
-      keyed_transformed_dataset = (
-          transformed_dataset | beam.ParDo(_ExtractIdAndKeyPColl()))
-
-      # The grouping is needed here since tensorflow transform only outputs
-      # columns that are transformed by any of the transforms. So we will
-      # join the missing columns from the raw_data to the transformed_dataset
-      # using the id.
       transformed_dataset = (
-          (keyed_transformed_dataset, keyed_columns_not_in_schema)
-          | beam.CoGroupByKey()
-          | beam.ParDo(_MergeDicts()))
-
+          transformed_dataset
+          | "DecodeDict" >> beam.Map(lambda x: self.data_coder.decode(x)))

Review Comment:
   How about naming these stages:
   ```
   "EncodeUnmodifiedColumns" >> beam.Map
   "DecodeUnmodifiedColumns" >> beam.Map
   ```



##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -83,20 +85,52 @@
 tft_process_handler_output_type = typing.Union[beam.Row, Dict[str, np.ndarray]]
 
 
+class DataCoder:
+  def __init__(self, exclude_columns=None):
+    """
+    Uses PickleCoder to encode/decode the dictionaries.
+    Args:
+      exclude_columns: list of columns to exclude from the encoding.
+    """
+    self.coder = coders.registry.get_coder(Any)
+    self.exclude_columns = exclude_columns
+
+  def set_unused_columns(self, exclude_columns):
+    self.exclude_columns = exclude_columns
+
+  def encode(self, element):
+    if not self.exclude_columns:

Review Comment:
   Interesting. Is it possible for `exclude_columns` to be empty? I'd imagine the opposite case is more likely: all columns are being processed, so there is nothing to encode/decode.



-- 
This is an automated message from the Apache Git Service.