liferoad commented on code in PR #27544:
URL: https://github.com/apache/beam/pull/27544#discussion_r1268576746


##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -369,22 +428,35 @@ def process_data(
     # whether a scalar value or list or np array is passed as input,
     #  we will convert scalar values to list values and TFT will ouput
     # numpy array all the time.
-    raw_data |= beam.ParDo(ConvertScalarValuesToListValues())
+
+    keyed_raw_data = (raw_data | beam.ParDo(ComputeAndAttachHashKey()))
+
+    feature_set = [feature.name for feature in 
raw_data_metadata.schema.feature]
+    columns_not_in_schema_with_hash = (
+        keyed_raw_data
+        | beam.ParDo(GetMissingColumnsPColl(feature_set)))
+
+    keyed_raw_data = keyed_raw_data | beam.ParDo(
+        ConvertScalarValuesToListValues())
+
+    raw_data_list = (keyed_raw_data | beam.ParDo(MakeHashKeyAsColumn()))
 
     with tft_beam.Context(temp_dir=self.artifact_location):
-      data = (raw_data, raw_data_metadata)
+      data = (raw_data_list, raw_data_metadata)
       if self.artifact_mode == ArtifactMode.PRODUCE:
         transform_fn = (
             data
             | "AnalyzeDataset" >> 
tft_beam.AnalyzeDataset(self.process_data_fn))

Review Comment:
   Will the fixed labels for the PTransforms used in this file cause issues if 
users use this multiple times in a Beam pipeline?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to