damccorm commented on code in PR #27544:
URL: https://github.com/apache/beam/pull/27544#discussion_r1269850397


##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -119,6 +121,82 @@ def expand(
       return pcoll | beam.Map(lambda x: x._asdict())
 
 
+class ComputeAndAttachHashKey(beam.DoFn):
+  """
+  Computues and attaches a hash key to the element.

Review Comment:
   ```suggestion
     Computes and attaches a hash key to the element.
   ```



##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -119,6 +121,82 @@ def expand(
       return pcoll | beam.Map(lambda x: x._asdict())
 
 
+class ComputeAndAttachHashKey(beam.DoFn):
+  """
+  Computues and attaches a hash key to the element.
+  Only for internal use. No backwards compatibility guarantees.
+  """
+  def process(self, element):
+    hash_object = hashlib.sha256()
+    for _, value in element.items():
+      # handle the case where value is a list or numpy array
+      if isinstance(value, (list, np.ndarray)):
+        hash_object.update(str(list(value)).encode())
+      else:  # assume value is a primitive that can be turned into str
+        hash_object.update(str(value).encode())
+    yield (hash_object.hexdigest(), element)
+
+
+class GetMissingColumnsPColl(beam.DoFn):
+  """
+  Returns data containing only the columns that are not
+  present in the schema. This is needed since TFT only outputs
+  columns that are transformed by any of the data processing transforms.
+
+  Only for internal use. No backwards compatibility guarantees.
+  """
+  def __init__(self, existing_columns):
+    self.existing_columns = existing_columns
+
+  def process(self, element):
+    new_dict = {}
+    hash_key, element = element
+    for key, value in element.items():
+      if key not in self.existing_columns:
+        new_dict[key] = value
+    yield (hash_key, new_dict)
+
+
+class MakeHashKeyAsColumn(beam.DoFn):
+  """
+  Extracts the hash key from the element and adds it as a column.
+
+  Only for internal use. No backwards compatibility guarantees.
+  """
+  def process(self, element):
+    hash_key, element = element
+    element['hash_key'] = hash_key
+    yield element
+
+
+class ExtractHashAndKeyPColl(beam.DoFn):
+  """
+  Extracts the hash key and return hashkey and element as a tuple.
+
+  Only for internal use. No backwards compatibility guarantees.
+  """
+  def process(self, element):
+    hashkey = element['hash_key']
+    if isinstance(hashkey, np.ndarray):

Review Comment:
   Why do we need this check? Shouldn't we know exactly how the hash key is defined?



##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -369,22 +428,35 @@ def process_data(
     # whether a scalar value or list or np array is passed as input,
     #  we will convert scalar values to list values and TFT will ouput
     # numpy array all the time.
-    raw_data |= beam.ParDo(ConvertScalarValuesToListValues())
+
+    keyed_raw_data = (raw_data | beam.ParDo(ComputeAndAttachHashKey()))
+
+    feature_set = [feature.name for feature in raw_data_metadata.schema.feature]
+    columns_not_in_schema_with_hash = (
+        keyed_raw_data
+        | beam.ParDo(GetMissingColumnsPColl(feature_set)))
+
+    keyed_raw_data = keyed_raw_data | beam.ParDo(
+        ConvertScalarValuesToListValues())
+
+    raw_data_list = (keyed_raw_data | beam.ParDo(MakeHashKeyAsColumn()))
 
     with tft_beam.Context(temp_dir=self.artifact_location):
-      data = (raw_data, raw_data_metadata)
+      data = (raw_data_list, raw_data_metadata)
       if self.artifact_mode == ArtifactMode.PRODUCE:
         transform_fn = (
             data
             | "AnalyzeDataset" >> tft_beam.AnalyzeDataset(self.process_data_fn))

Review Comment:
   This should be fine: since it gets called as part of expand, the label will be prefixed by the enclosing transform's name. We do the same thing with RunInference.



##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -365,26 +446,38 @@ def process_data(
       raw_data_metadata = metadata_io.read_metadata(
           os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
 
+    keyed_raw_data = (raw_data | beam.ParDo(ComputeAndAttachHashKey()))
+
+    feature_set = [feature.name for feature in raw_data_metadata.schema.feature]
+    columns_not_in_schema_with_hash = (
+        keyed_raw_data
+        | beam.ParDo(GetMissingColumnsPColl(feature_set)))
+
     # To maintain consistency by outputting numpy array all the time,
     # whether a scalar value or list or np array is passed as input,
     #  we will convert scalar values to list values and TFT will ouput
     # numpy array all the time.
-    raw_data |= beam.ParDo(ConvertScalarValuesToListValues())
+    keyed_raw_data = keyed_raw_data | beam.ParDo(
+        ConvertScalarValuesToListValues())
+
+    raw_data_list = (keyed_raw_data | beam.ParDo(MakeHashKeyAsColumn()))
 
     with tft_beam.Context(temp_dir=self.artifact_location):
-      data = (raw_data, raw_data_metadata)
+      data = (raw_data_list, raw_data_metadata)
       if self.artifact_mode == ArtifactMode.PRODUCE:
         transform_fn = (
             data
             | "AnalyzeDataset" >> tft_beam.AnalyzeDataset(self.process_data_fn))
+        # TODO: Remove the 'hash_key' column from the transformed
+        # dataset schema.

Review Comment:
   We do this below, right?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to