This is an automated email from the ASF dual-hosted git repository.
anandinguva pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new c004cc7fa14 Remove CoGBK in MLTransform's TFTProcessHandler (#30146)
c004cc7fa14 is described below
commit c004cc7fa1425d14ef2b0c134784f80a42e555a7
Author: Anand Inguva <[email protected]>
AuthorDate: Tue Feb 13 19:19:34 2024 +0000
Remove CoGBK in MLTransform's TFTProcessHandler (#30146)
* Add _Encode and _DecodeDict
* Replace the CoGBK and utils with Encode and Decode utils
* Use PicklerCoder for encoding and decoding elements
* Remove comments
* update coder
* Address comments
* Remove return comment
* Update sdks/python/apache_beam/ml/transforms/handlers.py
Co-authored-by: tvalentyn <[email protected]>
* Make _DataCoder internal
---------
Co-authored-by: tvalentyn <[email protected]>
---
sdks/python/apache_beam/ml/transforms/handlers.py | 163 ++++++++--------------
1 file changed, 57 insertions(+), 106 deletions(-)
diff --git a/sdks/python/apache_beam/ml/transforms/handlers.py b/sdks/python/apache_beam/ml/transforms/handlers.py
index 3c37ddef1ed..5bcd0d16576 100644
--- a/sdks/python/apache_beam/ml/transforms/handlers.py
+++ b/sdks/python/apache_beam/ml/transforms/handlers.py
@@ -17,9 +17,10 @@
# pytype: skip-file
import collections
+import copy
import os
import typing
-import uuid
+from typing import Any
from typing import Dict
from typing import List
from typing import Optional
@@ -31,6 +32,7 @@ import numpy as np
import apache_beam as beam
import tensorflow as tf
import tensorflow_transform.beam as tft_beam
+from apache_beam import coders
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.transforms.base import ArtifactMode
from apache_beam.ml.transforms.base import ProcessHandler
@@ -50,7 +52,7 @@ __all__ = [
'TFTProcessHandler',
]
-_ID_COLUMN = 'tmp_uuid' # Name for a temporary column.
+_TEMP_KEY = 'CODED_SAMPLE' # key for the encoded sample
RAW_DATA_METADATA_DIR = 'raw_data_metadata'
SCHEMA_FILE = 'schema.pbtxt'
@@ -83,12 +85,41 @@ tft_process_handler_input_type =
typing.Union[typing.NamedTuple,
tft_process_handler_output_type = typing.Union[beam.Row, Dict[str, np.ndarray]]
+class _DataCoder:
+ def __init__(
+ self,
+ exclude_columns,
+ coder=coders.registry.get_coder(Any),
+ ):
+ """
+ Encodes/decodes items of a dictionary into a single element.
+ Args:
+ exclude_columns: list of columns to exclude from the encoding.
+ """
+ self.coder = coder
+ self.exclude_columns = exclude_columns
+
+ def encode(self, element):
+ data_to_encode = element.copy()
+ element_to_return = element.copy()
+ for key in self.exclude_columns:
+ if key in data_to_encode:
+ del data_to_encode[key]
+ element_to_return[_TEMP_KEY] = self.coder.encode(data_to_encode)
+ return element_to_return
+
+ def decode(self, element):
+ clone = copy.copy(element)
+ clone.update(self.coder.decode(clone[_TEMP_KEY].item()))
+ del clone[_TEMP_KEY]
+ return clone
+
+
class _ConvertScalarValuesToListValues(beam.DoFn):
def process(
self,
element,
):
- id, element = element
new_dict = {}
for key, value in element.items():
if isinstance(value,
@@ -96,7 +127,7 @@ class _ConvertScalarValuesToListValues(beam.DoFn):
new_dict[key] = [value]
else:
new_dict[key] = value
- yield (id, new_dict)
+ yield new_dict
class _ConvertNamedTupleToDict(
@@ -124,79 +155,6 @@ class _ConvertNamedTupleToDict(
return pcoll | beam.Map(lambda x: x._asdict())
-class _ComputeAndAttachUniqueID(beam.DoFn):
- """
- Computes and attaches a unique id to each element in the PCollection.
- """
- def process(self, element):
- # UUID1 includes machine-specific bits and has a counter. As long as not too
- # many are generated at the same time, they should be unique.
- # UUID4 generation should be unique in practice as long as underlying random
- # number generation is not compromised.
- # A combintation of both should avoid the anecdotal pitfalls where
- # replacing one with the other has helped some users.
- # UUID collision will result in data loss, but we can detect that and fail.
-
- # TODO(https://github.com/apache/beam/issues/29593): Evaluate MLTransform
- # implementation without CoGBK.
- unique_key = uuid.uuid1().bytes + uuid.uuid4().bytes
- yield (unique_key, element)
-
-
-class _GetMissingColumns(beam.DoFn):
- """
- Returns data containing only the columns that are not
- present in the schema. This is needed since TFT only outputs
- columns that are transformed by any of the data processing transforms.
- """
- def __init__(self, existing_columns):
- self.existing_columns = existing_columns
-
- def process(self, element):
- id, row_dict = element
- new_dict = {
- k: v
- for k, v in row_dict.items() if k not in self.existing_columns
- }
- yield (id, new_dict)
-
-
-class _MakeIdAsColumn(beam.DoFn):
- """
- Extracts the id from the element and adds it as a column instead.
- """
- def process(self, element):
- id, element = element
- element[_ID_COLUMN] = id
- yield element
-
-
-class _ExtractIdAndKeyPColl(beam.DoFn):
- """
- Extracts the id and return id and element as a tuple.
- """
- def process(self, element):
- id = element[_ID_COLUMN][0]
- del element[_ID_COLUMN]
- yield (id, element)
-
-
-class _MergeDicts(beam.DoFn):
- """
- Merges processed and unprocessed columns from CoGBK result into a single row.
- """
- def process(self, element):
- unused_row_id, row_dicts_tuple = element
- new_dict = {}
- for d in row_dicts_tuple:
- # After CoGBK, dicts with processed and unprocessed portions of each row
- # are wrapped in 1-element lists, since all rows have a unique id.
- # Assertion could fail due to UUID collision.
- assert len(d) == 1, f"Expected 1 element, got: {len(d)}."
- new_dict.update(d[0])
- yield new_dict
-
-
class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
tft_process_handler_output_type]):
def __init__(
@@ -325,7 +283,7 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
def get_raw_data_metadata(
self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata:
raw_data_feature_spec = self.get_raw_data_feature_spec(input_types)
- raw_data_feature_spec[_ID_COLUMN] = tf.io.VarLenFeature(dtype=tf.string)
+ raw_data_feature_spec[_TEMP_KEY] = tf.io.VarLenFeature(dtype=tf.string)
return self.convert_raw_data_feature_spec_to_dataset_metadata(
raw_data_feature_spec)
@@ -403,7 +361,6 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
artifact_location, which was previously used to store the produced
artifacts.
"""
-
if self.artifact_mode == ArtifactMode.PRODUCE:
# If we are computing artifacts, we should fail for windows other than
# default windowing since for example, for a fixed window, each window can
@@ -447,24 +404,29 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
raw_data_metadata = metadata_io.read_metadata(
os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
- keyed_raw_data = (raw_data | beam.ParDo(_ComputeAndAttachUniqueID()))
-
feature_set = [feature.name for feature in
raw_data_metadata.schema.feature]
- keyed_columns_not_in_schema = (
- keyed_raw_data
- | beam.ParDo(_GetMissingColumns(feature_set)))
+
+ # TFT ignores columns in the input data that aren't explicitly defined
+ # in the schema. This is because TFT operations
+ # are designed to work with a predetermined schema.
+ # To preserve these extra columns without disrupting TFT processing,
+ # they are temporarily encoded as bytes and added to the PCollection with
+ # a unique identifier
+ data_coder = _DataCoder(exclude_columns=feature_set)
+ data_with_encoded_columns = (
+ raw_data
+ | "EncodeUnmodifiedColumns" >>
+ beam.Map(lambda elem: data_coder.encode(elem)))
# To maintain consistency by outputting numpy array all the time,
# whether a scalar value or list or np array is passed as input,
- # we will convert scalar values to list values and TFT will ouput
+ # we will convert scalar values to list values and TFT will ouput
# numpy array all the time.
- keyed_raw_data = keyed_raw_data | beam.ParDo(
+ data_list = data_with_encoded_columns | beam.ParDo(
_ConvertScalarValuesToListValues())
- raw_data_list = (keyed_raw_data | beam.ParDo(_MakeIdAsColumn()))
-
with tft_beam.Context(temp_dir=self.artifact_location):
- data = (raw_data_list, raw_data_metadata)
+ data = (data_list, raw_data_metadata)
if self.artifact_mode == ArtifactMode.PRODUCE:
transform_fn = (
data
@@ -474,7 +436,7 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
self.write_transform_artifacts(transform_fn, self.artifact_location)
else:
transform_fn = (
- raw_data_list.pipeline
+ data_list.pipeline
| "ReadTransformFn" >> tft_beam.ReadTransformFn(
self.artifact_location))
(transformed_dataset, transformed_metadata) = (
@@ -492,26 +454,15 @@ class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
# So we will use a RowTypeConstraint to create a schema'd PCollection.
# this is needed since new columns are included in the
# transformed_dataset.
- del self.transformed_schema[_ID_COLUMN]
+ del self.transformed_schema[_TEMP_KEY]
row_type = RowTypeConstraint.from_fields(
list(self.transformed_schema.items()))
- # If a non schema PCollection is passed, and one of the input columns
- # is not transformed by any of the transforms, then the output will
- # not have that column. So we will join the missing columns from the
- # raw_data to the transformed_dataset.
- keyed_transformed_dataset = (
- transformed_dataset | beam.ParDo(_ExtractIdAndKeyPColl()))
-
- # The grouping is needed here since tensorflow transform only outputs
- # columns that are transformed by any of the transforms. So we will
- # join the missing columns from the raw_data to the transformed_dataset
- # using the id.
+ # Decode the extra columns that were encoded as bytes.
transformed_dataset = (
- (keyed_transformed_dataset, keyed_columns_not_in_schema)
- | beam.CoGroupByKey()
- | beam.ParDo(_MergeDicts()))
-
+ transformed_dataset
+ |
+ "DecodeUnmodifiedColumns" >> beam.Map(lambda x:
data_coder.decode(x)))
# The schema only contains the columns that are transformed.
transformed_dataset = (
transformed_dataset