damccorm commented on code in PR #26795:
URL: https://github.com/apache/beam/pull/26795#discussion_r1239891302
##########
sdks/python/tox.ini:
##########
@@ -326,6 +326,12 @@ commands =
# Run all DataFrame API unit tests
bash {toxinidir}/scripts/run_pytest.sh {envname} 'apache_beam/dataframe'
+[testenv:py{38,39}-tft-113]
Review Comment:
Any reason to limit this to 3.8/9 (and not 3.10/11)?
##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Generic
+from typing import TypeVar
+
+import apache_beam as beam
+
+# TODO: Abstract methods are not getting pickled with dill.
Review Comment:
Does this TODO still apply? What are the consequences?
##########
sdks/python/tox.ini:
##########
@@ -326,6 +326,12 @@ commands =
# Run all DataFrame API unit tests
bash {toxinidir}/scripts/run_pytest.sh {envname} 'apache_beam/dataframe'
+[testenv:py{38,39}-tft-113]
Review Comment:
And should we trigger any non-3.8 versions in a precommit? Maybe 3.11, so that
we cover the lowest and highest supported versions?
##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -0,0 +1,436 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import collections
+import logging
+import os
+import tempfile
+import typing
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+import numpy as np
+
+import apache_beam as beam
+from apache_beam.ml.transforms.base import ArtifactMode
+from apache_beam.ml.transforms.base import _ProcessHandler
+from apache_beam.ml.transforms.base import ProcessInputT
+from apache_beam.ml.transforms.base import ProcessOutputT
+from apache_beam.ml.transforms.tft_transforms import TFTOperation
+from apache_beam.ml.transforms.tft_transforms import _EXPECTED_TYPES
+from apache_beam.typehints import native_type_compatibility
+from apache_beam.typehints.row_type import RowTypeConstraint
+import pyarrow as pa
+import tensorflow as tf
+from tensorflow_metadata.proto.v0 import schema_pb2
+import tensorflow_transform.beam as tft_beam
+from tensorflow_transform import common_types
+from tensorflow_transform.beam.tft_beam_io import beam_metadata_io
+from tensorflow_transform.beam.tft_beam_io import transform_fn_io
+from tensorflow_transform.tf_metadata import dataset_metadata
+from tensorflow_transform.tf_metadata import metadata_io
+from tensorflow_transform.tf_metadata import schema_utils
+from tfx_bsl.tfxio import tf_example_record
+
+__all__ = [
+ 'TFTProcessHandler',
+]
+
+RAW_DATA_METADATA_DIR = 'raw_data_metadata'
+SCHEMA_FILE = 'schema.pbtxt'
+# tensorflow transform doesn't support the types other than tf.int64,
+# tf.float32 and tf.string.
+_default_type_to_tensor_type_map = {
+ int: tf.int64,
+ float: tf.float32,
+ str: tf.string,
+ bytes: tf.string,
+ np.int64: tf.int64,
+ np.int32: tf.int64,
+ np.float32: tf.float32,
+ np.float64: tf.float32,
+ np.bytes_: tf.string,
+ np.str_: tf.string,
+}
+_primitive_types_to_typing_container_type = {
+ int: List[int], float: List[float], str: List[str], bytes: List[bytes]
+}
+
+tft_process_handler_input_type = typing.Union[typing.NamedTuple,
+ beam.Row,
+ Dict[str,
+ typing.Union[str,
+ float,
+ int,
+ bytes,
+ np.ndarray]]]
+
+
+class ConvertScalarValuesToListValues(beam.DoFn):
+ def process(
+ self, element: Dict[str, typing.Any]
+ ) -> typing.Iterable[Dict[str, typing.List[typing.Any]]]:
+ new_dict = {}
+ for key, value in element.items():
+ if isinstance(value,
+ tuple(_primitive_types_to_typing_container_type.keys())):
+ new_dict[key] = [value]
+ else:
+ new_dict[key] = value
+ yield new_dict
+
+
+class ConvertNamedTupleToDict(
+ beam.PTransform[beam.PCollection[typing.Union[beam.Row,
typing.NamedTuple]],
+ beam.PCollection[Dict[str,
+ common_types.InstanceDictType]]]):
+ """
+ A PTransform that converts a collection of NamedTuples or Rows into a
+ collection of dictionaries.
+ """
+ def expand(
+ self, pcoll: beam.PCollection[typing.Union[beam.Row, typing.NamedTuple]]
+ ) -> beam.PCollection[common_types.InstanceDictType]:
+ """
+ Args:
+ pcoll: A PCollection of NamedTuples or Rows.
+ Returns:
+ A PCollection of dictionaries.
+ """
+ if isinstance(pcoll.element_type, RowTypeConstraint):
+ # Row instance
+ return pcoll | beam.Map(lambda x: x.as_dict())
+ else:
+ # named tuple
+ return pcoll | beam.Map(lambda x: x._asdict())
+
+
+class TFTProcessHandler(_ProcessHandler[ProcessInputT, ProcessOutputT]):
+ def __init__(
+ self,
+ *,
+ artifact_location: str = None,
+ transforms: Optional[List[TFTOperation]] = None,
+ preprocessing_fn: typing.Optional[typing.Callable] = None,
+ is_input_record_batches: bool = False,
+ output_record_batches: bool = False,
+ artifact_mode: str = ArtifactMode.PRODUCE):
+ """
+ A handler class for processing data with TensorFlow Transform (TFT)
+ operations. This class is intended to be subclassed, with subclasses
+ implementing the `preprocessing_fn` method.
+ """
+ self.transforms = transforms if transforms else []
+ self.transformed_schema = None
+ self.artifact_location = artifact_location
+ self.preprocessing_fn = preprocessing_fn
+ self.is_input_record_batches = is_input_record_batches
+ self.output_record_batches = output_record_batches
+ self.artifact_mode = artifact_mode
+ if artifact_mode not in ['produce', 'consume']:
+ raise ValueError('artifact_mode must be either `produce` or `consume`.')
+
+ if not self.artifact_location:
+ self.artifact_location = tempfile.mkdtemp()
+
+ def append_transform(self, transform):
+ self.transforms.append(transform)
+
+ def _map_column_names_to_types(self, row_type):
+ """
+ Return a dictionary of column names and types.
+ Args:
+ element_type: A type of the element. This could be a NamedTuple or a Row.
+ Returns:
+ A dictionary of column names and types.
+ """
+ try:
+ if not isinstance(row_type, RowTypeConstraint):
+ row_type = RowTypeConstraint.from_user_type(row_type)
+
+ inferred_types = {name: typ for name, typ in row_type._fields}
+
+ for k, t in inferred_types.items():
+ if t in _primitive_types_to_typing_container_type:
+ inferred_types[k] = _primitive_types_to_typing_container_type[t]
+
+ # sometimes a numpy type can be provided as np.dtype('int64').
+ # convert numpy.dtype to numpy type since both are same.
+ for name, typ in inferred_types.items():
+ if isinstance(typ, np.dtype):
+ inferred_types[name] = typ.type
+
+ return inferred_types
+ except: # pylint: disable=bare-except
+ return {}
+
+ def _map_column_names_to_types_from_transforms(self):
+ column_type_mapping = {}
+ for transform in self.transforms:
+ for col in transform.columns:
+ if col not in column_type_mapping:
+ # we just need to dtype of first occurance of column in transforms.
+ class_name = transform.__class__.__name__
+ if class_name not in _EXPECTED_TYPES:
+ raise KeyError(
+ f"Transform {class_name} is not registered with a supported "
+ "type. Please register the transform with a supported type "
+ "using register_input_dtype decorator.")
+ column_type_mapping[col] = _EXPECTED_TYPES[
+ transform.__class__.__name__]
+ return column_type_mapping
+
+ def get_raw_data_feature_spec(
+ self, input_types: Dict[str, type]) -> Dict[str, tf.io.VarLenFeature]:
+ """
+ Return a DatasetMetadata object to be used with
+ tft_beam.AnalyzeAndTransformDataset.
+ Args:
+ input_types: A dictionary of column names and types.
+ Returns:
+ A DatasetMetadata object.
+ """
+ raw_data_feature_spec = {}
+ for key, value in input_types.items():
+ raw_data_feature_spec[key] = self._get_raw_data_feature_spec_per_column(
+ typ=value, col_name=key)
+ return raw_data_feature_spec
+
+ def convert_raw_data_feature_spec_to_dataset_metadata(
+ self, raw_data_feature_spec) -> dataset_metadata.DatasetMetadata:
+ raw_data_metadata = dataset_metadata.DatasetMetadata(
+ schema_utils.schema_from_feature_spec(raw_data_feature_spec))
+ return raw_data_metadata
+
+ def _get_raw_data_feature_spec_per_column(
+ self, typ: type, col_name: str) -> tf.io.VarLenFeature:
+ """
+ Return a FeatureSpec object to be used with
+ tft_beam.AnalyzeAndTransformDataset
+ Args:
+ typ: A type of the column.
+ col_name: A name of the column.
+ Returns:
+ A FeatureSpec object.
+ """
+ # lets conver the builtin types to typing types for consistency.
+ typ = native_type_compatibility.convert_builtin_to_typing(typ)
+ primitive_containers_type = (
+ list,
+ collections.abc.Sequence,
+ )
+ is_primitive_container = (
+ typing.get_origin(typ) in primitive_containers_type)
+
+ if is_primitive_container:
+ dtype = typing.get_args(typ)[0] # type: ignore[attr-defined]
+ if len(typing.get_args(typ)) > 1 or typing.get_origin(dtype) == Union:
# type: ignore[attr-defined]
+ raise RuntimeError(
+ f"Union type is not supported for column: {col_name}. "
+ f"Please pass a PCollection with valid schema for column "
+ f"{col_name} by passing a single type "
+ "in container. For example, List[int].")
+ elif issubclass(typ, np.generic) or typ in
_default_type_to_tensor_type_map:
+ dtype = typ
+ else:
+ raise TypeError(
+ f"Unable to identify type: {typ} specified on column: {col_name}. "
+ f"Please provide a valid type from the following: "
+ f"{_default_type_to_tensor_type_map.keys()}")
+ return tf.io.VarLenFeature(_default_type_to_tensor_type_map[dtype])
+
+ def get_raw_data_metadata(
+ self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata:
+ raw_data_feature_spec = self.get_raw_data_feature_spec(input_types)
+ return self.convert_raw_data_feature_spec_to_dataset_metadata(
+ raw_data_feature_spec)
+
+ def write_transform_artifacts(self, transform_fn, location):
+ """
+ Write transform artifacts to the given location.
+ Args:
+ transform_fn: A transform_fn object.
+ location: A location to write the artifacts.
+ Returns:
+ A PCollection of WriteTransformFn writing a TF transform graph.
+ """
+ return (
+ transform_fn
+ | 'Write Transform Artifacts' >>
+ transform_fn_io.WriteTransformFn(location))
+
+ def _fail_on_non_default_windowing(self, pcoll: beam.PCollection):
+ if not pcoll.windowing.is_default():
+ raise RuntimeError(
+ "TFTProcessHandler only supports GlobalWindows when producing "
+ "artifacts such as min, max, variance etc over the dataset."
+ "Please use beam.WindowInto(beam.transforms.window.GlobalWindows()) "
+ "to convert your PCollection to GlobalWindow.")
+
+ def process_data_fn(
+ self, inputs: Dict[str, common_types.ConsistentTensorType]
+ ) -> Dict[str, common_types.ConsistentTensorType]:
+ """
+ This method is used in the AnalyzeAndTransformDataset step. It applies
+ the transforms to the `inputs` in sequential order on the columns
+ provided for a given transform.
+ Args:
+ inputs: A dictionary of column names and data.
+ Returns:
+ A dictionary of column names and transformed data.
+ """
+ outputs = inputs.copy()
+ for transform in self.transforms:
+ columns = transform.columns
+ for col in columns:
+ intermediate_result = transform.apply(
+ outputs[col], output_column_name=col)
+ for key, value in intermediate_result.items():
+ outputs[key] = value
+ return outputs
+
+ def _get_transformed_data_schema(
+ self,
+ metadata: dataset_metadata.DatasetMetadata,
+ ) -> Dict[str, typing.Sequence[typing.Union[np.float32, np.int64, bytes]]]:
+ schema = metadata._schema
+ transformed_types = {}
+ logging.info("Schema: %s", schema)
+ for feature in schema.feature:
+ name = feature.name
+ feature_type = feature.type
+ if feature_type == schema_pb2.FeatureType.FLOAT:
+ transformed_types[name] = typing.Sequence[np.float32]
+ elif feature_type == schema_pb2.FeatureType.INT:
+ transformed_types[name] = typing.Sequence[np.int64]
+ elif feature_type == schema_pb2.FeatureType.BYTES:
+ transformed_types[name] = typing.Sequence[bytes]
+ else:
+ # TODO: This else condition won't be hit since TFT doesn't output
+ # other than float, int and bytes. Refactor the code here.
+ raise RuntimeError(
+ 'Unsupported feature type: %s encountered' % feature_type)
+ logging.info(transformed_types)
+ return transformed_types
+
+ def process_data(
+ self, raw_data: beam.PCollection[tft_process_handler_input_type]
+ ) -> beam.PCollection[typing.Union[
+ beam.Row, Dict[str, np.ndarray], pa.RecordBatch]]:
+ """
+ This method also computes the required dataset metadata for the tft
+ AnalyzeDataset/TransformDataset step.
+
+ This method uses tensorflow_transform's Analyze step to produce the
+ artifacts and Transform step to apply the transforms on the data.
+ Artifacts are only produced if the artifact_mode is set to `produce`.
+ If artifact_mode is set to `consume`, then the artifacts are read from the
+ artifact_location, which was previously used to store the produced
+ artifacts.
+ """
+ if self.artifact_mode == ArtifactMode.PRODUCE:
+ # If we are computing artifacts, we should fail for windows other than
+ # default windowing since for example, for a fixed window, each window
can
+ # be treated as a separate dataset and we might need to compute artifacts
+ # for each window. This is not supported yet.
+ self._fail_on_non_default_windowing(raw_data)
+ element_type = raw_data.element_type
+ column_type_mapping = {}
+ if (isinstance(element_type, RowTypeConstraint) or
+ native_type_compatibility.match_is_named_tuple(element_type)):
+ column_type_mapping = self._map_column_names_to_types(
+ row_type=element_type)
+ # convert Row or NamedTuple to Dict
+ raw_data = (
+ raw_data
+ | ConvertNamedTupleToDict().with_output_types(
+ Dict[str, typing.Union[tuple(column_type_mapping.values())]]))
+ # AnalyzeAndTransformDataset raise type hint since this is
+ # schema'd PCollection and the current output type would be a
+ # custom type(NamedTuple) or a beam.Row type.
+ else:
+ column_type_mapping = self._map_column_names_to_types_from_transforms()
+ raw_data_metadata = self.get_raw_data_metadata(
+ input_types=column_type_mapping)
+ # Write untransformed metadata to a file so that it can be re-used
+ # during Transform step.
+ metadata_io.write_metadata(
+ metadata=raw_data_metadata,
+ path=os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
+ else:
+ # Read the metadata from the artifact_location.
+ if not os.path.exists(os.path.join(
+ self.artifact_location, RAW_DATA_METADATA_DIR, SCHEMA_FILE)):
+ raise FileNotFoundError(
+ "Raw data metadata not found at %s" %
Review Comment:
We should be more descriptive in these errors — something along the lines of
"you're running in consume mode", an explanation of what that means, and a
prompt like "have you already run this pipeline in produce mode?"
##########
sdks/python/apache_beam/ml/transforms/tft_transforms.py:
##########
@@ -0,0 +1,442 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This module defines a set of data processing transforms that can be used
+to perform common data transformations on a dataset. These transforms are
+implemented using the TensorFlow Transform (TFT) library. The transforms
+in this module are intended to be used in conjunction with the
+beam.ml.MLTransform class, which provides a convenient interface for
+applying a sequence of data processing transforms to a dataset with the
+help of the TFTProcessHandler class.
+
+See the documentation for beam.ml.MLTransform for more details.
+
+Since the transforms in this module are implemented using TFT, they
+should be wrapped inside a TFTProcessHandler object before being passed
+to the beam.ml.MLTransform class. The TFTProcessHandler will let MLTransform
+know which type of input is expected and infers the relevant schema required
+for the TFT library.
Review Comment:
This is outdated
##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import typing
+from typing import Dict
+from typing import Optional
+from typing import TypeVar
+from typing import Generic
+
+import apache_beam as beam
+
+# TODO: Abstract methods are not getting pickled with dill.
+# https://github.com/uqfoundation/dill/issues/332
+# import abc
+
+__all__ = ['MLTransform', 'MLTransformOutput', 'ProcessHandler']
+
+TransformedDatasetT = TypeVar('TransformedDatasetT')
+TransformedMetadataT = TypeVar('TransformedMetadataT')
+
+# Input/Output types to the MLTransform.
+ExampleT = TypeVar('ExampleT')
+MLTransformOutputT = TypeVar('MLTransformOutputT')
+
+# Input to the process data. This could be same or different from ExampleT.
Review Comment:
I am still wondering about this one.
##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -0,0 +1,165 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Generic
+from typing import List
+from typing import Optional
+from typing import TypeVar
+
+import apache_beam as beam
+
+# TODO: Abstract methods are not getting pickled with dill.
+# https://github.com/uqfoundation/dill/issues/332
+# import abc
+
+__all__ = ['MLTransform']
+
+TransformedDatasetT = TypeVar('TransformedDatasetT')
+TransformedMetadataT = TypeVar('TransformedMetadataT')
+
+# Input/Output types to the MLTransform.
+ExampleT = TypeVar('ExampleT')
+MLTransformOutputT = TypeVar('MLTransformOutputT')
+
+# Input to the process data. This could be same or different from ExampleT.
+ProcessInputT = TypeVar('ProcessInputT')
+# Output of the process data. This could be same or different
+# from MLTransformOutputT
+ProcessOutputT = TypeVar('ProcessOutputT')
+
+# Input to the apply() method of BaseOperation.
+OperationInputT = TypeVar('OperationInputT')
+# Output of the apply() method of BaseOperation.
+OperationOutputT = TypeVar('OperationOutputT')
+
+
+class ArtifactMode(object):
+ PRODUCE = 'produce'
+ CONSUME = 'consume'
+
+
+class BaseOperation(Generic[OperationInputT, OperationOutputT]):
+ def apply(
+ self, inputs: OperationInputT, column_name: str, *args,
+ **kwargs) -> OperationOutputT:
+ """
+ Define any processing logic in the apply() method.
+ processing logics are applied on inputs and returns a transformed
+ output.
+ Args:
+ inputs: input data.
+ """
+ raise NotImplementedError
+
+
+class _ProcessHandler(Generic[ProcessInputT, ProcessOutputT]):
+ """
+ Only for internal use. No backwards compatibility guarantees.
+ """
+ def process_data(
+ self, pcoll: beam.PCollection[ProcessInputT]
+ ) -> beam.PCollection[ProcessOutputT]:
+ """
+ Logic to process the data. This will be the entrypoint in
+ beam.MLTransform to process incoming data.
+ """
+ raise NotImplementedError
+
+ def append_transform(self, transform: BaseOperation):
+ raise NotImplementedError
+
+
+class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
+ beam.PCollection[MLTransformOutputT]],
+ Generic[ExampleT, MLTransformOutputT]):
+ def __init__(
+ self,
+ *,
+ artifact_location: str,
+ artifact_mode: str = ArtifactMode.PRODUCE,
+ transforms: Optional[List[BaseOperation]] = None,
+ is_input_record_batches: bool = False,
+ output_record_batches: bool = False,
Review Comment:
If we need something like this in `MLTransform` itself, I'd significantly
prefer a public `TFTProcessHandler` to take in that config. These options will
be meaningless for most other frameworks we might support.
##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -0,0 +1,165 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Generic
+from typing import List
+from typing import Optional
+from typing import TypeVar
+
+import apache_beam as beam
+
+# TODO: Abstract methods are not getting pickled with dill.
+# https://github.com/uqfoundation/dill/issues/332
+# import abc
+
+__all__ = ['MLTransform']
+
+TransformedDatasetT = TypeVar('TransformedDatasetT')
+TransformedMetadataT = TypeVar('TransformedMetadataT')
+
+# Input/Output types to the MLTransform.
+ExampleT = TypeVar('ExampleT')
+MLTransformOutputT = TypeVar('MLTransformOutputT')
+
+# Input to the process data. This could be same or different from ExampleT.
+ProcessInputT = TypeVar('ProcessInputT')
+# Output of the process data. This could be same or different
+# from MLTransformOutputT
+ProcessOutputT = TypeVar('ProcessOutputT')
+
+# Input to the apply() method of BaseOperation.
+OperationInputT = TypeVar('OperationInputT')
+# Output of the apply() method of BaseOperation.
+OperationOutputT = TypeVar('OperationOutputT')
+
+
+class ArtifactMode(object):
+ PRODUCE = 'produce'
+ CONSUME = 'consume'
+
+
+class BaseOperation(Generic[OperationInputT, OperationOutputT]):
+ def apply(
+ self, inputs: OperationInputT, column_name: str, *args,
+ **kwargs) -> OperationOutputT:
+ """
+ Define any processing logic in the apply() method.
+ processing logics are applied on inputs and returns a transformed
+ output.
+ Args:
+ inputs: input data.
+ """
+ raise NotImplementedError
+
+
+class _ProcessHandler(Generic[ProcessInputT, ProcessOutputT]):
+ """
+ Only for internal use. No backwards compatibility guarantees.
+ """
+ def process_data(
+ self, pcoll: beam.PCollection[ProcessInputT]
+ ) -> beam.PCollection[ProcessOutputT]:
+ """
+ Logic to process the data. This will be the entrypoint in
+ beam.MLTransform to process incoming data.
+ """
+ raise NotImplementedError
+
+ def append_transform(self, transform: BaseOperation):
+ raise NotImplementedError
+
+
+class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
+ beam.PCollection[MLTransformOutputT]],
+ Generic[ExampleT, MLTransformOutputT]):
+ def __init__(
+ self,
+ *,
+ artifact_location: str,
+ artifact_mode: str = ArtifactMode.PRODUCE,
+ transforms: Optional[List[BaseOperation]] = None,
+ is_input_record_batches: bool = False,
+ output_record_batches: bool = False,
+ ):
+ """
+ Args:
+ artifact_location: A storage location for artifacts resulting from
+ MLTransform. These artifacts include transformations applied to
+ the dataset and generated values like min, max from ScaleTo01,
+ and mean, var from ScaleToZScore. Artifacts are produced and stored
+ in this location when the `artifact_mode` is set to 'produce'.
+ Conversely, when `artifact_mode` is set to 'consume', artifacts are
+ retrieved from this location. Note that when consuming artifacts,
+ it is not necessary to pass the transforms since they are inherently
+ stored within the artifacts themselves. The value assigned to
+ `artifact_location` should be a valid storage path where the artifacts
+ can be written to or read from.
+ transforms: A list of transforms to apply to the data. All the transforms
+ are applied in the order they are specified. The input of the
+ i-th transform is the output of the (i-1)-th transform. Multi-input
+ transforms are not supported yet.
+ is_input_record_batches: Whether the input is a RecordBatch.
+ output_record_batches: Output RecordBatches instead of beam.Row().
+ artifact_mode: Whether to produce or consume artifacts. If set to
+ 'consume', the handler will assume that the artifacts are already
+ computed and stored in the artifact_location. Pass the same artifact
+ location that was passed during produce phase to ensure that the
+ right artifacts are read. If set to 'produce', the handler
+ will compute the artifacts and store them in the artifact_location.
+ The artifacts will be read from this location during the consume phase.
+ There is no need to pass the transforms in this case since they are
+ already embedded in the stored artifacts.
+ """
+ # avoid circular import
+ # pylint: disable=wrong-import-order, wrong-import-position
+ from apache_beam.ml.transforms.handlers import TFTProcessHandler
+ process_handler = TFTProcessHandler(
Review Comment:
Might be worth dropping a TODO here. Eventually, we'll want to map
transforms to process_handlers (so if you say `ComputeAndApplyVocabulary` we
use the `TFTProcessHandler`, if you say `MyCoolJaxFn` we use
`JaxProcessHandler`). Don't need to implement anything, but it would be good to
be clear on our path forward there.
##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -0,0 +1,436 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import collections
+import logging
+import os
+import tempfile
+import typing
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+import numpy as np
+
+import apache_beam as beam
+from apache_beam.ml.transforms.base import ArtifactMode
+from apache_beam.ml.transforms.base import _ProcessHandler
+from apache_beam.ml.transforms.base import ProcessInputT
+from apache_beam.ml.transforms.base import ProcessOutputT
+from apache_beam.ml.transforms.tft_transforms import TFTOperation
+from apache_beam.ml.transforms.tft_transforms import _EXPECTED_TYPES
+from apache_beam.typehints import native_type_compatibility
+from apache_beam.typehints.row_type import RowTypeConstraint
+import pyarrow as pa
+import tensorflow as tf
+from tensorflow_metadata.proto.v0 import schema_pb2
+import tensorflow_transform.beam as tft_beam
+from tensorflow_transform import common_types
+from tensorflow_transform.beam.tft_beam_io import beam_metadata_io
+from tensorflow_transform.beam.tft_beam_io import transform_fn_io
+from tensorflow_transform.tf_metadata import dataset_metadata
+from tensorflow_transform.tf_metadata import metadata_io
+from tensorflow_transform.tf_metadata import schema_utils
+from tfx_bsl.tfxio import tf_example_record
+
+__all__ = [
+ 'TFTProcessHandler',
+]
+
+RAW_DATA_METADATA_DIR = 'raw_data_metadata'
+SCHEMA_FILE = 'schema.pbtxt'
+# tensorflow transform doesn't support types other than tf.int64,
+# tf.float32, and tf.string.
+_default_type_to_tensor_type_map = {
+ int: tf.int64,
+ float: tf.float32,
+ str: tf.string,
+ bytes: tf.string,
+ np.int64: tf.int64,
+ np.int32: tf.int64,
+ np.float32: tf.float32,
+ np.float64: tf.float32,
+ np.bytes_: tf.string,
+ np.str_: tf.string,
+}
+_primitive_types_to_typing_container_type = {
+ int: List[int], float: List[float], str: List[str], bytes: List[bytes]
+}
+
+tft_process_handler_input_type = typing.Union[typing.NamedTuple,
+ beam.Row,
+ Dict[str,
+ typing.Union[str,
+ float,
+ int,
+ bytes,
+ np.ndarray]]]
+
+
+class ConvertScalarValuesToListValues(beam.DoFn):
+ def process(
+ self, element: Dict[str, typing.Any]
+ ) -> typing.Iterable[Dict[str, typing.List[typing.Any]]]:
+ new_dict = {}
+ for key, value in element.items():
+ if isinstance(value,
+ tuple(_primitive_types_to_typing_container_type.keys())):
+ new_dict[key] = [value]
+ else:
+ new_dict[key] = value
+ yield new_dict
+
+
+class ConvertNamedTupleToDict(
+ beam.PTransform[beam.PCollection[typing.Union[beam.Row,
typing.NamedTuple]],
+ beam.PCollection[Dict[str,
+ common_types.InstanceDictType]]]):
+ """
+ A PTransform that converts a collection of NamedTuples or Rows into a
+ collection of dictionaries.
+ """
+ def expand(
+ self, pcoll: beam.PCollection[typing.Union[beam.Row, typing.NamedTuple]]
+ ) -> beam.PCollection[common_types.InstanceDictType]:
+ """
+ Args:
+ pcoll: A PCollection of NamedTuples or Rows.
+ Returns:
+ A PCollection of dictionaries.
+ """
+ if isinstance(pcoll.element_type, RowTypeConstraint):
+ # Row instance
+ return pcoll | beam.Map(lambda x: x.as_dict())
+ else:
+ # named tuple
+ return pcoll | beam.Map(lambda x: x._asdict())
+
+
+class TFTProcessHandler(_ProcessHandler[ProcessInputT, ProcessOutputT]):
+ def __init__(
+ self,
+ *,
+ artifact_location: str = None,
+ transforms: Optional[List[TFTOperation]] = None,
+ preprocessing_fn: typing.Optional[typing.Callable] = None,
+ is_input_record_batches: bool = False,
+ output_record_batches: bool = False,
+ artifact_mode: str = ArtifactMode.PRODUCE):
+ """
+ A handler class for processing data with TensorFlow Transform (TFT)
+ operations. This class is intended to be subclassed, with subclasses
+ implementing the `preprocessing_fn` method.
+ """
+ self.transforms = transforms if transforms else []
+ self.transformed_schema = None
+ self.artifact_location = artifact_location
+ self.preprocessing_fn = preprocessing_fn
+ self.is_input_record_batches = is_input_record_batches
+ self.output_record_batches = output_record_batches
+ self.artifact_mode = artifact_mode
+ if artifact_mode not in ['produce', 'consume']:
+ raise ValueError('artifact_mode must be either `produce` or `consume`.')
+
+ if not self.artifact_location:
+ self.artifact_location = tempfile.mkdtemp()
+
+ def append_transform(self, transform):
+ self.transforms.append(transform)
+
+ def _map_column_names_to_types(self, row_type):
+ """
+ Return a dictionary of column names and types.
+ Args:
+ element_type: A type of the element. This could be a NamedTuple or a Row.
+ Returns:
+ A dictionary of column names and types.
+ """
+ try:
+ if not isinstance(row_type, RowTypeConstraint):
+ row_type = RowTypeConstraint.from_user_type(row_type)
+
+ inferred_types = {name: typ for name, typ in row_type._fields}
+
+ for k, t in inferred_types.items():
+ if t in _primitive_types_to_typing_container_type:
+ inferred_types[k] = _primitive_types_to_typing_container_type[t]
+
+ # sometimes a numpy type can be provided as np.dtype('int64').
+ # convert numpy.dtype to numpy type since both are the same.
+ for name, typ in inferred_types.items():
+ if isinstance(typ, np.dtype):
+ inferred_types[name] = typ.type
+
+ return inferred_types
+ except: # pylint: disable=bare-except
+ return {}
+
+ def _map_column_names_to_types_from_transforms(self):
+ column_type_mapping = {}
+ for transform in self.transforms:
+ for col in transform.columns:
+ if col not in column_type_mapping:
+ # we just need the dtype of the first occurrence of a column in transforms.
+ class_name = transform.__class__.__name__
+ if class_name not in _EXPECTED_TYPES:
+ raise KeyError(
+ f"Transform {class_name} is not registered with a supported "
+ "type. Please register the transform with a supported type "
+ "using register_input_dtype decorator.")
+ column_type_mapping[col] = _EXPECTED_TYPES[
+ transform.__class__.__name__]
+ return column_type_mapping
+
+ def get_raw_data_feature_spec(
+ self, input_types: Dict[str, type]) -> Dict[str, tf.io.VarLenFeature]:
+ """
+ Return a DatasetMetadata object to be used with
+ tft_beam.AnalyzeAndTransformDataset.
+ Args:
+ input_types: A dictionary of column names and types.
+ Returns:
+ A DatasetMetadata object.
+ """
+ raw_data_feature_spec = {}
+ for key, value in input_types.items():
+ raw_data_feature_spec[key] = self._get_raw_data_feature_spec_per_column(
+ typ=value, col_name=key)
+ return raw_data_feature_spec
+
+ def convert_raw_data_feature_spec_to_dataset_metadata(
+ self, raw_data_feature_spec) -> dataset_metadata.DatasetMetadata:
+ raw_data_metadata = dataset_metadata.DatasetMetadata(
+ schema_utils.schema_from_feature_spec(raw_data_feature_spec))
+ return raw_data_metadata
+
+ def _get_raw_data_feature_spec_per_column(
+ self, typ: type, col_name: str) -> tf.io.VarLenFeature:
+ """
+ Return a FeatureSpec object to be used with
+ tft_beam.AnalyzeAndTransformDataset
+ Args:
+ typ: A type of the column.
+ col_name: A name of the column.
+ Returns:
+ A FeatureSpec object.
+ """
+ # let's convert the builtin types to typing types for consistency.
+ typ = native_type_compatibility.convert_builtin_to_typing(typ)
+ primitive_containers_type = (
+ list,
+ collections.abc.Sequence,
+ )
+ is_primitive_container = (
+ typing.get_origin(typ) in primitive_containers_type)
+
+ if is_primitive_container:
+ dtype = typing.get_args(typ)[0] # type: ignore[attr-defined]
+ if len(typing.get_args(typ)) > 1 or typing.get_origin(dtype) == Union:
# type: ignore[attr-defined]
+ raise RuntimeError(
+ f"Union type is not supported for column: {col_name}. "
+ f"Please pass a PCollection with valid schema for column "
+ f"{col_name} by passing a single type "
+ "in container. For example, List[int].")
+ elif issubclass(typ, np.generic) or typ in
_default_type_to_tensor_type_map:
+ dtype = typ
+ else:
+ raise TypeError(
+ f"Unable to identify type: {typ} specified on column: {col_name}. "
+ f"Please provide a valid type from the following: "
+ f"{_default_type_to_tensor_type_map.keys()}")
+ return tf.io.VarLenFeature(_default_type_to_tensor_type_map[dtype])
+
+ def get_raw_data_metadata(
+ self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata:
+ raw_data_feature_spec = self.get_raw_data_feature_spec(input_types)
+ return self.convert_raw_data_feature_spec_to_dataset_metadata(
+ raw_data_feature_spec)
+
+ def write_transform_artifacts(self, transform_fn, location):
+ """
+ Write transform artifacts to the given location.
+ Args:
+ transform_fn: A transform_fn object.
+ location: A location to write the artifacts.
+ Returns:
+ A PCollection of WriteTransformFn writing a TF transform graph.
+ """
+ return (
+ transform_fn
+ | 'Write Transform Artifacts' >>
+ transform_fn_io.WriteTransformFn(location))
+
+ def _fail_on_non_default_windowing(self, pcoll: beam.PCollection):
+ if not pcoll.windowing.is_default():
+ raise RuntimeError(
+ "TFTProcessHandler only supports GlobalWindows when producing "
Review Comment:
Nit: errors shouldn't explicitly reference `TFTProcessHandler` anymore,
instead we should provide the transform name where we ran into this (or just
the list of transform names)
##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -0,0 +1,436 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import collections
+import logging
+import os
+import tempfile
+import typing
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+import numpy as np
+
+import apache_beam as beam
+from apache_beam.ml.transforms.base import ArtifactMode
+from apache_beam.ml.transforms.base import _ProcessHandler
+from apache_beam.ml.transforms.base import ProcessInputT
+from apache_beam.ml.transforms.base import ProcessOutputT
+from apache_beam.ml.transforms.tft_transforms import TFTOperation
+from apache_beam.ml.transforms.tft_transforms import _EXPECTED_TYPES
+from apache_beam.typehints import native_type_compatibility
+from apache_beam.typehints.row_type import RowTypeConstraint
+import pyarrow as pa
+import tensorflow as tf
+from tensorflow_metadata.proto.v0 import schema_pb2
+import tensorflow_transform.beam as tft_beam
+from tensorflow_transform import common_types
+from tensorflow_transform.beam.tft_beam_io import beam_metadata_io
+from tensorflow_transform.beam.tft_beam_io import transform_fn_io
+from tensorflow_transform.tf_metadata import dataset_metadata
+from tensorflow_transform.tf_metadata import metadata_io
+from tensorflow_transform.tf_metadata import schema_utils
+from tfx_bsl.tfxio import tf_example_record
+
+__all__ = [
+ 'TFTProcessHandler',
+]
+
+RAW_DATA_METADATA_DIR = 'raw_data_metadata'
+SCHEMA_FILE = 'schema.pbtxt'
+# tensorflow transform doesn't support types other than tf.int64,
+# tf.float32, and tf.string.
+_default_type_to_tensor_type_map = {
+ int: tf.int64,
+ float: tf.float32,
+ str: tf.string,
+ bytes: tf.string,
+ np.int64: tf.int64,
+ np.int32: tf.int64,
+ np.float32: tf.float32,
+ np.float64: tf.float32,
+ np.bytes_: tf.string,
+ np.str_: tf.string,
+}
+_primitive_types_to_typing_container_type = {
+ int: List[int], float: List[float], str: List[str], bytes: List[bytes]
+}
+
+tft_process_handler_input_type = typing.Union[typing.NamedTuple,
+ beam.Row,
+ Dict[str,
+ typing.Union[str,
+ float,
+ int,
+ bytes,
+ np.ndarray]]]
+
+
+class ConvertScalarValuesToListValues(beam.DoFn):
+ def process(
+ self, element: Dict[str, typing.Any]
+ ) -> typing.Iterable[Dict[str, typing.List[typing.Any]]]:
+ new_dict = {}
+ for key, value in element.items():
+ if isinstance(value,
+ tuple(_primitive_types_to_typing_container_type.keys())):
+ new_dict[key] = [value]
+ else:
+ new_dict[key] = value
+ yield new_dict
+
+
+class ConvertNamedTupleToDict(
+ beam.PTransform[beam.PCollection[typing.Union[beam.Row,
typing.NamedTuple]],
+ beam.PCollection[Dict[str,
+ common_types.InstanceDictType]]]):
+ """
+ A PTransform that converts a collection of NamedTuples or Rows into a
+ collection of dictionaries.
+ """
+ def expand(
+ self, pcoll: beam.PCollection[typing.Union[beam.Row, typing.NamedTuple]]
+ ) -> beam.PCollection[common_types.InstanceDictType]:
+ """
+ Args:
+ pcoll: A PCollection of NamedTuples or Rows.
+ Returns:
+ A PCollection of dictionaries.
+ """
+ if isinstance(pcoll.element_type, RowTypeConstraint):
+ # Row instance
+ return pcoll | beam.Map(lambda x: x.as_dict())
+ else:
+ # named tuple
+ return pcoll | beam.Map(lambda x: x._asdict())
+
+
+class TFTProcessHandler(_ProcessHandler[ProcessInputT, ProcessOutputT]):
+ def __init__(
+ self,
+ *,
+ artifact_location: str = None,
+ transforms: Optional[List[TFTOperation]] = None,
+ preprocessing_fn: typing.Optional[typing.Callable] = None,
+ is_input_record_batches: bool = False,
+ output_record_batches: bool = False,
+ artifact_mode: str = ArtifactMode.PRODUCE):
+ """
+ A handler class for processing data with TensorFlow Transform (TFT)
+ operations. This class is intended to be subclassed, with subclasses
+ implementing the `preprocessing_fn` method.
+ """
+ self.transforms = transforms if transforms else []
+ self.transformed_schema = None
+ self.artifact_location = artifact_location
+ self.preprocessing_fn = preprocessing_fn
+ self.is_input_record_batches = is_input_record_batches
+ self.output_record_batches = output_record_batches
+ self.artifact_mode = artifact_mode
+ if artifact_mode not in ['produce', 'consume']:
+ raise ValueError('artifact_mode must be either `produce` or `consume`.')
+
+ if not self.artifact_location:
+ self.artifact_location = tempfile.mkdtemp()
+
+ def append_transform(self, transform):
+ self.transforms.append(transform)
+
+ def _map_column_names_to_types(self, row_type):
+ """
+ Return a dictionary of column names and types.
+ Args:
+ element_type: A type of the element. This could be a NamedTuple or a Row.
+ Returns:
+ A dictionary of column names and types.
+ """
+ try:
+ if not isinstance(row_type, RowTypeConstraint):
+ row_type = RowTypeConstraint.from_user_type(row_type)
+
+ inferred_types = {name: typ for name, typ in row_type._fields}
+
+ for k, t in inferred_types.items():
+ if t in _primitive_types_to_typing_container_type:
+ inferred_types[k] = _primitive_types_to_typing_container_type[t]
+
+ # sometimes a numpy type can be provided as np.dtype('int64').
+ # convert numpy.dtype to numpy type since both are the same.
+ for name, typ in inferred_types.items():
+ if isinstance(typ, np.dtype):
+ inferred_types[name] = typ.type
+
+ return inferred_types
+ except: # pylint: disable=bare-except
+ return {}
+
+ def _map_column_names_to_types_from_transforms(self):
+ column_type_mapping = {}
+ for transform in self.transforms:
+ for col in transform.columns:
+ if col not in column_type_mapping:
+ # we just need the dtype of the first occurrence of a column in transforms.
+ class_name = transform.__class__.__name__
+ if class_name not in _EXPECTED_TYPES:
+ raise KeyError(
+ f"Transform {class_name} is not registered with a supported "
+ "type. Please register the transform with a supported type "
+ "using register_input_dtype decorator.")
+ column_type_mapping[col] = _EXPECTED_TYPES[
+ transform.__class__.__name__]
+ return column_type_mapping
+
+ def get_raw_data_feature_spec(
+ self, input_types: Dict[str, type]) -> Dict[str, tf.io.VarLenFeature]:
+ """
+ Return a DatasetMetadata object to be used with
+ tft_beam.AnalyzeAndTransformDataset.
+ Args:
+ input_types: A dictionary of column names and types.
+ Returns:
+ A DatasetMetadata object.
+ """
+ raw_data_feature_spec = {}
+ for key, value in input_types.items():
+ raw_data_feature_spec[key] = self._get_raw_data_feature_spec_per_column(
+ typ=value, col_name=key)
+ return raw_data_feature_spec
+
+ def convert_raw_data_feature_spec_to_dataset_metadata(
+ self, raw_data_feature_spec) -> dataset_metadata.DatasetMetadata:
+ raw_data_metadata = dataset_metadata.DatasetMetadata(
+ schema_utils.schema_from_feature_spec(raw_data_feature_spec))
+ return raw_data_metadata
+
+ def _get_raw_data_feature_spec_per_column(
+ self, typ: type, col_name: str) -> tf.io.VarLenFeature:
+ """
+ Return a FeatureSpec object to be used with
+ tft_beam.AnalyzeAndTransformDataset
+ Args:
+ typ: A type of the column.
+ col_name: A name of the column.
+ Returns:
+ A FeatureSpec object.
+ """
+ # let's convert the builtin types to typing types for consistency.
+ typ = native_type_compatibility.convert_builtin_to_typing(typ)
+ primitive_containers_type = (
+ list,
+ collections.abc.Sequence,
+ )
+ is_primitive_container = (
+ typing.get_origin(typ) in primitive_containers_type)
+
+ if is_primitive_container:
+ dtype = typing.get_args(typ)[0] # type: ignore[attr-defined]
+ if len(typing.get_args(typ)) > 1 or typing.get_origin(dtype) == Union:
# type: ignore[attr-defined]
+ raise RuntimeError(
+ f"Union type is not supported for column: {col_name}. "
+ f"Please pass a PCollection with valid schema for column "
+ f"{col_name} by passing a single type "
+ "in container. For example, List[int].")
+ elif issubclass(typ, np.generic) or typ in
_default_type_to_tensor_type_map:
+ dtype = typ
+ else:
+ raise TypeError(
+ f"Unable to identify type: {typ} specified on column: {col_name}. "
+ f"Please provide a valid type from the following: "
+ f"{_default_type_to_tensor_type_map.keys()}")
+ return tf.io.VarLenFeature(_default_type_to_tensor_type_map[dtype])
+
+ def get_raw_data_metadata(
+ self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata:
+ raw_data_feature_spec = self.get_raw_data_feature_spec(input_types)
+ return self.convert_raw_data_feature_spec_to_dataset_metadata(
+ raw_data_feature_spec)
+
+ def write_transform_artifacts(self, transform_fn, location):
+ """
+ Write transform artifacts to the given location.
+ Args:
+ transform_fn: A transform_fn object.
+ location: A location to write the artifacts.
+ Returns:
+ A PCollection of WriteTransformFn writing a TF transform graph.
+ """
+ return (
+ transform_fn
+ | 'Write Transform Artifacts' >>
+ transform_fn_io.WriteTransformFn(location))
+
+ def _fail_on_non_default_windowing(self, pcoll: beam.PCollection):
+ if not pcoll.windowing.is_default():
+ raise RuntimeError(
+ "TFTProcessHandler only supports GlobalWindows when producing "
+ "artifacts such as min, max, variance etc over the dataset."
+ "Please use beam.WindowInto(beam.transforms.window.GlobalWindows()) "
+ "to convert your PCollection to GlobalWindow.")
+
+ def process_data_fn(
+ self, inputs: Dict[str, common_types.ConsistentTensorType]
+ ) -> Dict[str, common_types.ConsistentTensorType]:
+ """
+ This method is used in the AnalyzeAndTransformDataset step. It applies
+ the transforms to the `inputs` in sequential order on the columns
+ provided for a given transform.
+ Args:
+ inputs: A dictionary of column names and data.
+ Returns:
+ A dictionary of column names and transformed data.
+ """
+ outputs = inputs.copy()
+ for transform in self.transforms:
+ columns = transform.columns
+ for col in columns:
+ intermediate_result = transform.apply(
+ outputs[col], output_column_name=col)
+ for key, value in intermediate_result.items():
+ outputs[key] = value
+ return outputs
+
+ def _get_transformed_data_schema(
+ self,
+ metadata: dataset_metadata.DatasetMetadata,
+ ) -> Dict[str, typing.Sequence[typing.Union[np.float32, np.int64, bytes]]]:
+ schema = metadata._schema
+ transformed_types = {}
+ logging.info("Schema: %s", schema)
+ for feature in schema.feature:
+ name = feature.name
+ feature_type = feature.type
+ if feature_type == schema_pb2.FeatureType.FLOAT:
+ transformed_types[name] = typing.Sequence[np.float32]
+ elif feature_type == schema_pb2.FeatureType.INT:
+ transformed_types[name] = typing.Sequence[np.int64]
+ elif feature_type == schema_pb2.FeatureType.BYTES:
+ transformed_types[name] = typing.Sequence[bytes]
+ else:
+ # TODO: This else condition won't be hit since TFT doesn't output
+ # other than float, int and bytes. Refactor the code here.
+ raise RuntimeError(
+ 'Unsupported feature type: %s encountered' % feature_type)
+ logging.info(transformed_types)
+ return transformed_types
+
+ def process_data(
+ self, raw_data: beam.PCollection[tft_process_handler_input_type]
+ ) -> beam.PCollection[typing.Union[
+ beam.Row, Dict[str, np.ndarray], pa.RecordBatch]]:
+ """
+ This method also computes the required dataset metadata for the tft
+ AnalyzeDataset/TransformDataset step.
+
+ This method uses tensorflow_transform's Analyze step to produce the
+ artifacts and Transform step to apply the transforms on the data.
+ Artifacts are only produced if the artifact_mode is set to `produce`.
+ If artifact_mode is set to `consume`, then the artifacts are read from the
+ artifact_location, which was previously used to store the produced
+ artifacts.
+ """
+ if self.artifact_mode == ArtifactMode.PRODUCE:
+ # If we are computing artifacts, we should fail for windows other than
+ # default windowing since for example, for a fixed window, each window
can
+ # be treated as a separate dataset and we might need to compute artifacts
+ # for each window. This is not supported yet.
+ self._fail_on_non_default_windowing(raw_data)
+ element_type = raw_data.element_type
+ column_type_mapping = {}
+ if (isinstance(element_type, RowTypeConstraint) or
+ native_type_compatibility.match_is_named_tuple(element_type)):
+ column_type_mapping = self._map_column_names_to_types(
+ row_type=element_type)
+ # convert Row or NamedTuple to Dict
+ raw_data = (
+ raw_data
+ | ConvertNamedTupleToDict().with_output_types(
+ Dict[str, typing.Union[tuple(column_type_mapping.values())]]))
+ # AnalyzeAndTransformDataset raise type hint since this is
+ # schema'd PCollection and the current output type would be a
+ # custom type(NamedTuple) or a beam.Row type.
+ else:
+ column_type_mapping = self._map_column_names_to_types_from_transforms()
+ raw_data_metadata = self.get_raw_data_metadata(
+ input_types=column_type_mapping)
+ # Write untransformed metadata to a file so that it can be re-used
+ # during Transform step.
+ metadata_io.write_metadata(
+ metadata=raw_data_metadata,
+ path=os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
+ else:
+ # Read the metadata from the artifact_location.
+ if not os.path.exists(os.path.join(
+ self.artifact_location, RAW_DATA_METADATA_DIR, SCHEMA_FILE)):
+ raise FileNotFoundError(
+ "Raw data metadata not found at %s" %
+ os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
+ raw_data_metadata = metadata_io.read_metadata(
+ os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
+
+ # To maintain consistency by outputting a numpy array all the time,
+ # whether a scalar value, a list, or a np array is passed as input,
+ # we will convert scalar values to list values and TFT will output
+ # a numpy array all the time.
+ if not self.is_input_record_batches:
+ raw_data |= beam.ParDo(ConvertScalarValuesToListValues())
Review Comment:
The main advantage is that if we can kill this, it defers the problem of
having framework specific inputs. With that said, I think we need an idea of
how we'll solve that regardless
##########
sdks/python/apache_beam/ml/transforms/handlers_test.py:
##########
@@ -0,0 +1,380 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import typing
+from typing import NamedTuple
+from typing import List
+from typing import Union
+
+import unittest
+import numpy as np
+from parameterized import param
+from parameterized import parameterized
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+
+# pylint: disable=wrong-import-position, ungrouped-imports
+try:
+ from apache_beam.ml.transforms import base
+ from apache_beam.ml.transforms import handlers
+ from apache_beam.ml.transforms import tft_transforms
+ from apache_beam.ml.transforms.tft_transforms import TFTOperation
+ import tensorflow as tf
+except ImportError:
+ tft_transforms = None
+
+if not tft_transforms:
+ raise unittest.SkipTest('tensorflow_transform is not installed.')
+
+
+class _FakeOperation(TFTOperation):
+ def __init__(self, name, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.name = name
+
+ def apply(self, inputs, output_column_name, **kwargs):
+ return {output_column_name: inputs}
+
+
+class _AddOperation(TFTOperation):
+ def apply(self, inputs, output_column_name, **kwargs):
+ return {output_column_name: inputs + 1}
+
+
+class _MultiplyOperation(TFTOperation):
+ def apply(self, inputs, output_column_name, **kwargs):
+ return {output_column_name: inputs * 10}
+
+
+class _FakeOperationWithArtifacts(TFTOperation):
+ def apply(self, inputs, output_column_name, **kwargs):
+ return {
+ **{
+ output_column_name: inputs
+ },
+ **(self.get_artifacts(inputs, 'artifact'))
+ }
+
+ def get_artifacts(self, data, col_name):
+ return {'artifact': tf.convert_to_tensor([1])}
+
+
+class UnBatchedIntType(NamedTuple):
+ x: int
+
+
+class BatchedIntType(NamedTuple):
+ x: List[int]
+
+
+class BatchedNumpyType(NamedTuple):
+ x: np.int64
+
+
+class TFTProcessHandlerSchemaTest(unittest.TestCase):
+ def setUp(self) -> None:
+ self.pipeline = TestPipeline()
+
+ @parameterized.expand([
+ ({
+ 'x': 1, 'y': 2
+ }, ['x'], {
+ 'x': 20, 'y': 2
+ }),
+ ({
+ 'x': 1, 'y': 2
+ }, ['x', 'y'], {
+ 'x': 20, 'y': 30
+ }),
+ ])
+ def test_tft_operation_preprocessing_fn(
+ self, inputs, columns, expected_result):
+ add_fn = _AddOperation(columns=columns)
+ mul_fn = _MultiplyOperation(columns=columns)
+ process_handler = handlers.TFTProcessHandlerSchema(
+ transforms=[add_fn, mul_fn])
+
+ actual_result = process_handler.process_data_fn(inputs)
+ self.assertDictEqual(actual_result, expected_result)
+
+ def test_preprocessing_fn_with_artifacts(self):
+ process_handler = handlers.TFTProcessHandlerSchema(
+ transforms=[_FakeOperationWithArtifacts(columns=['x'])])
+ inputs = {'x': [1, 2, 3]}
+ preprocessing_fn = process_handler.process_data_fn
+ actual_result = preprocessing_fn(inputs)
+ expected_result = {'x': [1, 2, 3], 'artifact': tf.convert_to_tensor([1])}
+ self.assertDictEqual(actual_result, expected_result)
+
+ def test_ml_transform_appends_transforms_to_process_handler_correctly(self):
+ fake_fn_1 = _FakeOperation(name='fake_fn_1', columns=['x'])
+ transforms = [fake_fn_1]
+ process_handler = handlers.TFTProcessHandlerSchema(transforms=transforms)
+ ml_transform = base.MLTransform(process_handler=process_handler)
+ ml_transform = ml_transform.with_transform(
+ transform=_FakeOperation(name='fake_fn_2', columns=['x']))
+
+ self.assertEqual(len(ml_transform._process_handler.transforms), 2)
+ self.assertEqual(
+ ml_transform._process_handler.transforms[0].name, 'fake_fn_1')
+ self.assertEqual(
+ ml_transform._process_handler.transforms[1].name, 'fake_fn_2')
+
+ def test_input_type_from_schema_named_tuple_pcoll_unbatched(self):
+ non_batched_data = [{'x': 1}]
+ with beam.Pipeline() as p:
+ data = (
+ p | beam.Create(non_batched_data)
+ | beam.Map(lambda x: UnBatchedIntType(**x)).with_output_types(
+ UnBatchedIntType))
+ element_type = data.element_type
+ process_handler = handlers.TFTProcessHandlerSchema()
+ inferred_input_type = process_handler._map_column_names_to_types(
+ element_type)
+ expected_input_type = dict(x=List[int])
+
+ self.assertEqual(inferred_input_type, expected_input_type)
+
+ def test_input_type_from_schema_named_tuple_pcoll_batched(self):
+ batched_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}]
+ with beam.Pipeline() as p:
+ data = (
+ p | beam.Create(batched_data)
+ | beam.Map(lambda x: BatchedIntType(**x)).with_output_types(
+ BatchedIntType))
+ element_type = data.element_type
+ process_handler = handlers.TFTProcessHandlerSchema()
+ inferred_input_type = process_handler._map_column_names_to_types(
+ element_type)
+ expected_input_type = dict(x=List[int])
+ self.assertEqual(inferred_input_type, expected_input_type)
+
+ def test_input_type_from_row_type_pcoll_unbatched(self):
+ non_batched_data = [{'x': 1}]
+ with beam.Pipeline() as p:
+ data = (
+ p | beam.Create(non_batched_data)
+ | beam.Map(lambda ele: beam.Row(x=int(ele['x']))))
+ element_type = data.element_type
+ process_handler = handlers.TFTProcessHandlerSchema()
+ inferred_input_type = process_handler._map_column_names_to_types(
+ element_type)
+ expected_input_type = dict(x=List[int])
+ self.assertEqual(inferred_input_type, expected_input_type)
+
+ def test_input_type_from_row_type_pcoll_batched(self):
+ batched_data = [{'x': [1, 2, 3]}, {'x': [4, 5, 6]}]
+ with beam.Pipeline() as p:
+ data = (
+ p | beam.Create(batched_data)
+ | beam.Map(lambda ele: beam.Row(x=list(ele['x']))).with_output_types(
+ beam.row_type.RowTypeConstraint.from_fields([('x', List[int])])))
+
+ element_type = data.element_type
+ process_handler = handlers.TFTProcessHandlerSchema()
+ inferred_input_type = process_handler._map_column_names_to_types(
+ element_type)
+ expected_input_type = dict(x=List[int])
+ self.assertEqual(inferred_input_type, expected_input_type)
+
+ def test_input_type_from_named_tuple_pcoll_batched_numpy(self):
+ batched = [{
+ 'x': np.array([1, 2, 3], dtype=np.int64)
+ }, {
+ 'x': np.array([4, 5, 6], dtype=np.int64)
+ }]
+ with beam.Pipeline() as p:
+ data = (
+ p | beam.Create(batched)
+ | beam.Map(lambda x: BatchedNumpyType(**x)).with_output_types(
+ BatchedNumpyType))
+ element_type = data.element_type
+ process_handler = handlers.TFTProcessHandlerSchema()
+ inferred_input_type = process_handler._map_column_names_to_types(
+ element_type)
+ expected_type = dict(x=np.int64)
+ self.assertEqual(inferred_input_type, expected_type)
+
+ def test_input_type_non_schema_pcoll(self):
+ non_batched_data = [{'x': 1}]
+ with beam.Pipeline() as p:
+ data = (p | beam.Create(non_batched_data))
+ element_type = data.element_type
+ process_handler = handlers.TFTProcessHandlerSchema()
+ with self.assertRaises(TypeError):
+ _ = process_handler._map_column_names_to_types(element_type)
+
+ def test_tensorflow_raw_data_metadata_primitive_types(self):
+ input_types = dict(x=int, y=float, k=bytes, l=str)
+ process_handler = handlers.TFTProcessHandlerSchema()
+
+ for col_name, typ in input_types.items():
+ feature_spec = process_handler._get_raw_data_feature_spec_per_column(
+ typ=typ, col_name=col_name)
+ self.assertEqual(
+ handlers._default_type_to_tensor_type_map[typ], feature_spec.dtype)
+ self.assertIsInstance(feature_spec, tf.io.VarLenFeature)
+
+ def test_tensorflow_raw_data_metadata_primitive_types_in_containers(self):
+ input_types = dict([("x", List[int]), ("y", List[float]),
+ ("k", List[bytes]), ("l", List[str])])
+ process_handler = handlers.TFTProcessHandlerSchema()
+ for col_name, typ in input_types.items():
+ feature_spec = process_handler._get_raw_data_feature_spec_per_column(
+ typ=typ, col_name=col_name)
+ self.assertIsInstance(feature_spec, tf.io.VarLenFeature)
+
+ def test_tensorflow_raw_data_metadata_primitive_native_container_types(self):
+ input_types = dict([("x", list[int]), ("y", list[float]),
+ ("k", list[bytes]), ("l", list[str])])
+ process_handler = handlers.TFTProcessHandlerSchema()
+ for col_name, typ in input_types.items():
+ feature_spec = process_handler._get_raw_data_feature_spec_per_column(
+ typ=typ, col_name=col_name)
+ self.assertIsInstance(feature_spec, tf.io.VarLenFeature)
+
+ def test_tensorflow_raw_data_metadata_numpy_types(self):
+ input_types = dict(x=np.int64, y=np.float32, z=List[np.int64])
+ process_handler = handlers.TFTProcessHandlerSchema()
+ for col_name, typ in input_types.items():
+ feature_spec = process_handler._get_raw_data_feature_spec_per_column(
+ typ=typ, col_name=col_name)
+ self.assertIsInstance(feature_spec, tf.io.VarLenFeature)
+
+ def test_tensorflow_raw_data_metadata_union_type_in_single_column(self):
+ input_types = dict(x=Union[int, float])
+ process_handler = handlers.TFTProcessHandlerSchema()
+ with self.assertRaises(TypeError):
+ for col_name, typ in input_types.items():
+ _ = process_handler._get_raw_data_feature_spec_per_column(
+ typ=typ, col_name=col_name)
+
+ def test_tensorflow_raw_data_metadata_dtypes(self):
+ input_types = dict(x=np.int32, y=np.float64)
+ expected_dtype = dict(x=np.int64, y=np.float32)
+ process_handler = handlers.TFTProcessHandlerSchema()
+ for col_name, typ in input_types.items():
+ feature_spec = process_handler._get_raw_data_feature_spec_per_column(
+ typ=typ, col_name=col_name)
+ self.assertEqual(expected_dtype[col_name], feature_spec.dtype)
+
+ @parameterized.expand([
+ param(
+ input_data=[{
+ 'x': 1,
+ 'y': 2.0,
+ }],
+ input_types={
+ 'x': int, 'y': float
+ },
+ expected_dtype={
+ 'x': typing.Sequence[np.float32],
+ 'y': typing.Sequence[np.float32]
+ }),
+ param(
+ input_data=[{
+ 'x': np.array([1], dtype=np.int64),
+ 'y': np.array([2.0], dtype=np.float32)
+ }],
+ input_types={
+ 'x': np.int32, 'y': np.float32
+ },
+ expected_dtype={
+ 'x': typing.Sequence[np.float32],
+ 'y': typing.Sequence[np.float32]
+ }),
+ param(
+ input_data=[{
+ 'x': [1, 2, 3], 'y': [2.0, 3.0, 4.0]
+ }],
+ input_types={
+ 'x': List[int], 'y': List[float]
+ },
+ expected_dtype={
+ 'x': typing.Sequence[np.float32],
+ 'y': typing.Sequence[np.float32]
+ }),
+ param(
+ input_data=[{
+ 'x': [1, 2, 3], 'y': [2.0, 3.0, 4.0]
+ }],
+ input_types={
+ 'x': typing.Sequence[int], 'y': typing.Sequence[float]
+ },
+ expected_dtype={
+ 'x': typing.Sequence[np.float32],
+ 'y': typing.Sequence[np.float32]
+ }),
+ # this fails on Python 3.8 since type subscripting is not supported
+ # param(
+ # input_data=[{
+ # 'x': [1, 2, 3], 'y': [2.0, 3.0, 4.0]
+ # }],
+ # input_types={
+ # 'x': list[int], 'y': list[float]
+ # },
+ # expected_dtype={
+ # 'x': typing.Sequence[np.float32],
+ # 'y': typing.Sequence[np.float32]
+ # }),
+ ])
+ def test_tft_process_handler_dict_output_pcoll_schema(
+ self, input_data, input_types, expected_dtype):
+ transforms = [tft_transforms.ScaleTo01(columns=['x'])]
Review Comment:
We should probably have at least 1 multi-column test case
##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -0,0 +1,436 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import collections
+import logging
+import os
+import tempfile
+import typing
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+import numpy as np
+
+import apache_beam as beam
+from apache_beam.ml.transforms.base import ArtifactMode
+from apache_beam.ml.transforms.base import _ProcessHandler
+from apache_beam.ml.transforms.base import ProcessInputT
+from apache_beam.ml.transforms.base import ProcessOutputT
+from apache_beam.ml.transforms.tft_transforms import TFTOperation
+from apache_beam.ml.transforms.tft_transforms import _EXPECTED_TYPES
+from apache_beam.typehints import native_type_compatibility
+from apache_beam.typehints.row_type import RowTypeConstraint
+import pyarrow as pa
+import tensorflow as tf
+from tensorflow_metadata.proto.v0 import schema_pb2
+import tensorflow_transform.beam as tft_beam
+from tensorflow_transform import common_types
+from tensorflow_transform.beam.tft_beam_io import beam_metadata_io
+from tensorflow_transform.beam.tft_beam_io import transform_fn_io
+from tensorflow_transform.tf_metadata import dataset_metadata
+from tensorflow_transform.tf_metadata import metadata_io
+from tensorflow_transform.tf_metadata import schema_utils
+from tfx_bsl.tfxio import tf_example_record
+
+__all__ = [
+ 'TFTProcessHandler',
+]
+
+RAW_DATA_METADATA_DIR = 'raw_data_metadata'
+SCHEMA_FILE = 'schema.pbtxt'
+# tensorflow transform doesn't support the types other than tf.int64,
+# tf.float32 and tf.string.
+_default_type_to_tensor_type_map = {
+ int: tf.int64,
+ float: tf.float32,
+ str: tf.string,
+ bytes: tf.string,
+ np.int64: tf.int64,
+ np.int32: tf.int64,
+ np.float32: tf.float32,
+ np.float64: tf.float32,
+ np.bytes_: tf.string,
+ np.str_: tf.string,
+}
+_primitive_types_to_typing_container_type = {
+ int: List[int], float: List[float], str: List[str], bytes: List[bytes]
+}
+
+tft_process_handler_input_type = typing.Union[typing.NamedTuple,
+ beam.Row,
+ Dict[str,
+ typing.Union[str,
+ float,
+ int,
+ bytes,
+ np.ndarray]]]
+
+
+class ConvertScalarValuesToListValues(beam.DoFn):
+ def process(
+ self, element: Dict[str, typing.Any]
+ ) -> typing.Iterable[Dict[str, typing.List[typing.Any]]]:
+ new_dict = {}
+ for key, value in element.items():
+ if isinstance(value,
+ tuple(_primitive_types_to_typing_container_type.keys())):
+ new_dict[key] = [value]
+ else:
+ new_dict[key] = value
+ yield new_dict
+
+
+class ConvertNamedTupleToDict(
+ beam.PTransform[beam.PCollection[typing.Union[beam.Row,
typing.NamedTuple]],
+ beam.PCollection[Dict[str,
+ common_types.InstanceDictType]]]):
+ """
+ A PTransform that converts a collection of NamedTuples or Rows into a
+ collection of dictionaries.
+ """
+ def expand(
+ self, pcoll: beam.PCollection[typing.Union[beam.Row, typing.NamedTuple]]
+ ) -> beam.PCollection[common_types.InstanceDictType]:
+ """
+ Args:
+ pcoll: A PCollection of NamedTuples or Rows.
+ Returns:
+ A PCollection of dictionaries.
+ """
+ if isinstance(pcoll.element_type, RowTypeConstraint):
+ # Row instance
+ return pcoll | beam.Map(lambda x: x.as_dict())
+ else:
+ # named tuple
+ return pcoll | beam.Map(lambda x: x._asdict())
+
+
+class TFTProcessHandler(_ProcessHandler[ProcessInputT, ProcessOutputT]):
+ def __init__(
+ self,
+ *,
+ artifact_location: str = None,
+ transforms: Optional[List[TFTOperation]] = None,
+ preprocessing_fn: typing.Optional[typing.Callable] = None,
+ is_input_record_batches: bool = False,
+ output_record_batches: bool = False,
+ artifact_mode: str = ArtifactMode.PRODUCE):
+ """
+ A handler class for processing data with TensorFlow Transform (TFT)
+ operations. This class is intended to be subclassed, with subclasses
+ implementing the `preprocessing_fn` method.
+ """
+ self.transforms = transforms if transforms else []
+ self.transformed_schema = None
+ self.artifact_location = artifact_location
+ self.preprocessing_fn = preprocessing_fn
+ self.is_input_record_batches = is_input_record_batches
+ self.output_record_batches = output_record_batches
+ self.artifact_mode = artifact_mode
+ if artifact_mode not in ['produce', 'consume']:
+ raise ValueError('artifact_mode must be either `produce` or `consume`.')
+
+ if not self.artifact_location:
+ self.artifact_location = tempfile.mkdtemp()
+
+ def append_transform(self, transform):
+ self.transforms.append(transform)
+
+ def _map_column_names_to_types(self, row_type):
+ """
+ Return a dictionary of column names and types.
+ Args:
+ element_type: A type of the element. This could be a NamedTuple or a Row.
+ Returns:
+ A dictionary of column names and types.
+ """
+ try:
+ if not isinstance(row_type, RowTypeConstraint):
+ row_type = RowTypeConstraint.from_user_type(row_type)
+
+ inferred_types = {name: typ for name, typ in row_type._fields}
+
+ for k, t in inferred_types.items():
+ if t in _primitive_types_to_typing_container_type:
+ inferred_types[k] = _primitive_types_to_typing_container_type[t]
+
+ # sometimes a numpy type can be provided as np.dtype('int64').
+ # convert numpy.dtype to numpy type since both are same.
+ for name, typ in inferred_types.items():
+ if isinstance(typ, np.dtype):
+ inferred_types[name] = typ.type
+
+ return inferred_types
+ except: # pylint: disable=bare-except
+ return {}
+
+ def _map_column_names_to_types_from_transforms(self):
+ column_type_mapping = {}
+ for transform in self.transforms:
+ for col in transform.columns:
+ if col not in column_type_mapping:
+ # we just need the dtype of the first occurrence of a column in transforms.
+ class_name = transform.__class__.__name__
+ if class_name not in _EXPECTED_TYPES:
+ raise KeyError(
+ f"Transform {class_name} is not registered with a supported "
+ "type. Please register the transform with a supported type "
+ "using register_input_dtype decorator.")
+ column_type_mapping[col] = _EXPECTED_TYPES[
+ transform.__class__.__name__]
+ return column_type_mapping
+
+ def get_raw_data_feature_spec(
+ self, input_types: Dict[str, type]) -> Dict[str, tf.io.VarLenFeature]:
+ """
+ Return a DatasetMetadata object to be used with
+ tft_beam.AnalyzeAndTransformDataset.
+ Args:
+ input_types: A dictionary of column names and types.
+ Returns:
+ A DatasetMetadata object.
+ """
+ raw_data_feature_spec = {}
+ for key, value in input_types.items():
+ raw_data_feature_spec[key] = self._get_raw_data_feature_spec_per_column(
+ typ=value, col_name=key)
+ return raw_data_feature_spec
+
+ def convert_raw_data_feature_spec_to_dataset_metadata(
+ self, raw_data_feature_spec) -> dataset_metadata.DatasetMetadata:
+ raw_data_metadata = dataset_metadata.DatasetMetadata(
+ schema_utils.schema_from_feature_spec(raw_data_feature_spec))
+ return raw_data_metadata
+
+ def _get_raw_data_feature_spec_per_column(
+ self, typ: type, col_name: str) -> tf.io.VarLenFeature:
+ """
+ Return a FeatureSpec object to be used with
+ tft_beam.AnalyzeAndTransformDataset
+ Args:
+ typ: A type of the column.
+ col_name: A name of the column.
+ Returns:
+ A FeatureSpec object.
+ """
+ # let's convert the builtin types to typing types for consistency.
+ typ = native_type_compatibility.convert_builtin_to_typing(typ)
+ primitive_containers_type = (
+ list,
+ collections.abc.Sequence,
+ )
+ is_primitive_container = (
+ typing.get_origin(typ) in primitive_containers_type)
+
+ if is_primitive_container:
+ dtype = typing.get_args(typ)[0] # type: ignore[attr-defined]
+ if len(typing.get_args(typ)) > 1 or typing.get_origin(dtype) == Union:
# type: ignore[attr-defined]
+ raise RuntimeError(
+ f"Union type is not supported for column: {col_name}. "
+ f"Please pass a PCollection with valid schema for column "
+ f"{col_name} by passing a single type "
+ "in container. For example, List[int].")
+ elif issubclass(typ, np.generic) or typ in
_default_type_to_tensor_type_map:
+ dtype = typ
+ else:
+ raise TypeError(
+ f"Unable to identify type: {typ} specified on column: {col_name}. "
+ f"Please provide a valid type from the following: "
+ f"{_default_type_to_tensor_type_map.keys()}")
+ return tf.io.VarLenFeature(_default_type_to_tensor_type_map[dtype])
+
+ def get_raw_data_metadata(
+ self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata:
+ raw_data_feature_spec = self.get_raw_data_feature_spec(input_types)
+ return self.convert_raw_data_feature_spec_to_dataset_metadata(
+ raw_data_feature_spec)
+
+ def write_transform_artifacts(self, transform_fn, location):
+ """
+ Write transform artifacts to the given location.
+ Args:
+ transform_fn: A transform_fn object.
+ location: A location to write the artifacts.
+ Returns:
+ A PCollection of WriteTransformFn writing a TF transform graph.
+ """
+ return (
+ transform_fn
+ | 'Write Transform Artifacts' >>
+ transform_fn_io.WriteTransformFn(location))
+
+ def _fail_on_non_default_windowing(self, pcoll: beam.PCollection):
+ if not pcoll.windowing.is_default():
+ raise RuntimeError(
+ "TFTProcessHandler only supports GlobalWindows when producing "
+ "artifacts such as min, max, variance etc over the dataset."
+ "Please use beam.WindowInto(beam.transforms.window.GlobalWindows()) "
+ "to convert your PCollection to GlobalWindow.")
+
+ def process_data_fn(
+ self, inputs: Dict[str, common_types.ConsistentTensorType]
+ ) -> Dict[str, common_types.ConsistentTensorType]:
+ """
+ This method is used in the AnalyzeAndTransformDataset step. It applies
+ the transforms to the `inputs` in sequential order on the columns
+ provided for a given transform.
+ Args:
+ inputs: A dictionary of column names and data.
+ Returns:
+ A dictionary of column names and transformed data.
+ """
+ outputs = inputs.copy()
+ for transform in self.transforms:
+ columns = transform.columns
+ for col in columns:
+ intermediate_result = transform.apply(
+ outputs[col], output_column_name=col)
+ for key, value in intermediate_result.items():
+ outputs[key] = value
+ return outputs
+
+ def _get_transformed_data_schema(
+ self,
+ metadata: dataset_metadata.DatasetMetadata,
+ ) -> Dict[str, typing.Sequence[typing.Union[np.float32, np.int64, bytes]]]:
+ schema = metadata._schema
+ transformed_types = {}
+ logging.info("Schema: %s", schema)
+ for feature in schema.feature:
+ name = feature.name
+ feature_type = feature.type
+ if feature_type == schema_pb2.FeatureType.FLOAT:
+ transformed_types[name] = typing.Sequence[np.float32]
+ elif feature_type == schema_pb2.FeatureType.INT:
+ transformed_types[name] = typing.Sequence[np.int64]
+ elif feature_type == schema_pb2.FeatureType.BYTES:
+ transformed_types[name] = typing.Sequence[bytes]
+ else:
+ # TODO: This else condition won't be hit since TFT doesn't output
+ # other than float, int and bytes. Refactor the code here.
+ raise RuntimeError(
+ 'Unsupported feature type: %s encountered' % feature_type)
+ logging.info(transformed_types)
+ return transformed_types
+
+ def process_data(
+ self, raw_data: beam.PCollection[tft_process_handler_input_type]
+ ) -> beam.PCollection[typing.Union[
+ beam.Row, Dict[str, np.ndarray], pa.RecordBatch]]:
+ """
+ This method also computes the required dataset metadata for the tft
+ AnalyzeDataset/TransformDataset step.
+
+ This method uses tensorflow_transform's Analyze step to produce the
+ artifacts and Transform step to apply the transforms on the data.
+ Artifacts are only produced if the artifact_mode is set to `produce`.
+ If artifact_mode is set to `consume`, then the artifacts are read from the
+ artifact_location, which was previously used to store the produced
+ artifacts.
+ """
+ if self.artifact_mode == ArtifactMode.PRODUCE:
+ # If we are computing artifacts, we should fail for windows other than
+ # default windowing since for example, for a fixed window, each window
can
+ # be treated as a separate dataset and we might need to compute artifacts
+ # for each window. This is not supported yet.
+ self._fail_on_non_default_windowing(raw_data)
+ element_type = raw_data.element_type
+ column_type_mapping = {}
+ if (isinstance(element_type, RowTypeConstraint) or
+ native_type_compatibility.match_is_named_tuple(element_type)):
+ column_type_mapping = self._map_column_names_to_types(
+ row_type=element_type)
+ # convert Row or NamedTuple to Dict
+ raw_data = (
+ raw_data
+ | ConvertNamedTupleToDict().with_output_types(
+ Dict[str, typing.Union[tuple(column_type_mapping.values())]]))
+ # AnalyzeAndTransformDataset raise type hint since this is
+ # schema'd PCollection and the current output type would be a
+ # custom type(NamedTuple) or a beam.Row type.
+ else:
+ column_type_mapping = self._map_column_names_to_types_from_transforms()
+ raw_data_metadata = self.get_raw_data_metadata(
+ input_types=column_type_mapping)
+ # Write untransformed metadata to a file so that it can be re-used
+ # during Transform step.
+ metadata_io.write_metadata(
+ metadata=raw_data_metadata,
+ path=os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
+ else:
+ # Read the metadata from the artifact_location.
+ if not os.path.exists(os.path.join(
+ self.artifact_location, RAW_DATA_METADATA_DIR, SCHEMA_FILE)):
+ raise FileNotFoundError(
+ "Raw data metadata not found at %s" %
+ os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
+ raw_data_metadata = metadata_io.read_metadata(
+ os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
+
+ # To maintain consistency by outputting numpy array all the time,
+ # whether a scalar value or list or np array is passed as input,
+ # we will convert scalar values to list values and TFT will output
+ # numpy array all the time.
+ if not self.is_input_record_batches:
+ raw_data |= beam.ParDo(ConvertScalarValuesToListValues())
Review Comment:
Actually, I guess we'll need to solve this now for the output_record_batches
param anyways
##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -0,0 +1,436 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import collections
+import logging
+import os
+import tempfile
+import typing
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+import numpy as np
+
+import apache_beam as beam
+from apache_beam.ml.transforms.base import ArtifactMode
+from apache_beam.ml.transforms.base import _ProcessHandler
+from apache_beam.ml.transforms.base import ProcessInputT
+from apache_beam.ml.transforms.base import ProcessOutputT
+from apache_beam.ml.transforms.tft_transforms import TFTOperation
+from apache_beam.ml.transforms.tft_transforms import _EXPECTED_TYPES
+from apache_beam.typehints import native_type_compatibility
+from apache_beam.typehints.row_type import RowTypeConstraint
+import pyarrow as pa
+import tensorflow as tf
+from tensorflow_metadata.proto.v0 import schema_pb2
+import tensorflow_transform.beam as tft_beam
+from tensorflow_transform import common_types
+from tensorflow_transform.beam.tft_beam_io import beam_metadata_io
+from tensorflow_transform.beam.tft_beam_io import transform_fn_io
+from tensorflow_transform.tf_metadata import dataset_metadata
+from tensorflow_transform.tf_metadata import metadata_io
+from tensorflow_transform.tf_metadata import schema_utils
+from tfx_bsl.tfxio import tf_example_record
+
+__all__ = [
+ 'TFTProcessHandler',
+]
+
+RAW_DATA_METADATA_DIR = 'raw_data_metadata'
+SCHEMA_FILE = 'schema.pbtxt'
+# tensorflow transform doesn't support the types other than tf.int64,
+# tf.float32 and tf.string.
+_default_type_to_tensor_type_map = {
+ int: tf.int64,
+ float: tf.float32,
+ str: tf.string,
+ bytes: tf.string,
+ np.int64: tf.int64,
+ np.int32: tf.int64,
+ np.float32: tf.float32,
+ np.float64: tf.float32,
+ np.bytes_: tf.string,
+ np.str_: tf.string,
+}
+_primitive_types_to_typing_container_type = {
+ int: List[int], float: List[float], str: List[str], bytes: List[bytes]
+}
+
+tft_process_handler_input_type = typing.Union[typing.NamedTuple,
+ beam.Row,
+ Dict[str,
+ typing.Union[str,
+ float,
+ int,
+ bytes,
+ np.ndarray]]]
+
+
+class ConvertScalarValuesToListValues(beam.DoFn):
+ def process(
+ self, element: Dict[str, typing.Any]
+ ) -> typing.Iterable[Dict[str, typing.List[typing.Any]]]:
+ new_dict = {}
+ for key, value in element.items():
+ if isinstance(value,
+ tuple(_primitive_types_to_typing_container_type.keys())):
+ new_dict[key] = [value]
+ else:
+ new_dict[key] = value
+ yield new_dict
+
+
+class ConvertNamedTupleToDict(
+ beam.PTransform[beam.PCollection[typing.Union[beam.Row,
typing.NamedTuple]],
+ beam.PCollection[Dict[str,
+ common_types.InstanceDictType]]]):
+ """
+ A PTransform that converts a collection of NamedTuples or Rows into a
+ collection of dictionaries.
+ """
+ def expand(
+ self, pcoll: beam.PCollection[typing.Union[beam.Row, typing.NamedTuple]]
+ ) -> beam.PCollection[common_types.InstanceDictType]:
+ """
+ Args:
+ pcoll: A PCollection of NamedTuples or Rows.
+ Returns:
+ A PCollection of dictionaries.
+ """
+ if isinstance(pcoll.element_type, RowTypeConstraint):
+ # Row instance
+ return pcoll | beam.Map(lambda x: x.as_dict())
+ else:
+ # named tuple
+ return pcoll | beam.Map(lambda x: x._asdict())
+
+
+class TFTProcessHandler(_ProcessHandler[ProcessInputT, ProcessOutputT]):
+ def __init__(
+ self,
+ *,
+ artifact_location: str = None,
+ transforms: Optional[List[TFTOperation]] = None,
+ preprocessing_fn: typing.Optional[typing.Callable] = None,
+ is_input_record_batches: bool = False,
+ output_record_batches: bool = False,
+ artifact_mode: str = ArtifactMode.PRODUCE):
+ """
+ A handler class for processing data with TensorFlow Transform (TFT)
+ operations. This class is intended to be subclassed, with subclasses
+ implementing the `preprocessing_fn` method.
+ """
+ self.transforms = transforms if transforms else []
+ self.transformed_schema = None
+ self.artifact_location = artifact_location
+ self.preprocessing_fn = preprocessing_fn
+ self.is_input_record_batches = is_input_record_batches
+ self.output_record_batches = output_record_batches
+ self.artifact_mode = artifact_mode
+ if artifact_mode not in ['produce', 'consume']:
+ raise ValueError('artifact_mode must be either `produce` or `consume`.')
+
+ if not self.artifact_location:
+ self.artifact_location = tempfile.mkdtemp()
+
+ def append_transform(self, transform):
+ self.transforms.append(transform)
+
+ def _map_column_names_to_types(self, row_type):
+ """
+ Return a dictionary of column names and types.
+ Args:
+ element_type: A type of the element. This could be a NamedTuple or a Row.
+ Returns:
+ A dictionary of column names and types.
+ """
+ try:
+ if not isinstance(row_type, RowTypeConstraint):
+ row_type = RowTypeConstraint.from_user_type(row_type)
+
+ inferred_types = {name: typ for name, typ in row_type._fields}
+
+ for k, t in inferred_types.items():
+ if t in _primitive_types_to_typing_container_type:
+ inferred_types[k] = _primitive_types_to_typing_container_type[t]
+
+ # sometimes a numpy type can be provided as np.dtype('int64').
+ # convert numpy.dtype to numpy type since both are same.
+ for name, typ in inferred_types.items():
+ if isinstance(typ, np.dtype):
+ inferred_types[name] = typ.type
+
+ return inferred_types
+ except: # pylint: disable=bare-except
+ return {}
+
+ def _map_column_names_to_types_from_transforms(self):
+ column_type_mapping = {}
+ for transform in self.transforms:
+ for col in transform.columns:
+ if col not in column_type_mapping:
+ # we just need the dtype of the first occurrence of a column in transforms.
+ class_name = transform.__class__.__name__
+ if class_name not in _EXPECTED_TYPES:
+ raise KeyError(
+ f"Transform {class_name} is not registered with a supported "
+ "type. Please register the transform with a supported type "
+ "using register_input_dtype decorator.")
+ column_type_mapping[col] = _EXPECTED_TYPES[
+ transform.__class__.__name__]
+ return column_type_mapping
+
+ def get_raw_data_feature_spec(
+ self, input_types: Dict[str, type]) -> Dict[str, tf.io.VarLenFeature]:
+ """
+ Return a DatasetMetadata object to be used with
+ tft_beam.AnalyzeAndTransformDataset.
+ Args:
+ input_types: A dictionary of column names and types.
+ Returns:
+ A DatasetMetadata object.
+ """
+ raw_data_feature_spec = {}
+ for key, value in input_types.items():
+ raw_data_feature_spec[key] = self._get_raw_data_feature_spec_per_column(
+ typ=value, col_name=key)
+ return raw_data_feature_spec
+
+ def convert_raw_data_feature_spec_to_dataset_metadata(
+ self, raw_data_feature_spec) -> dataset_metadata.DatasetMetadata:
+ raw_data_metadata = dataset_metadata.DatasetMetadata(
+ schema_utils.schema_from_feature_spec(raw_data_feature_spec))
+ return raw_data_metadata
+
+ def _get_raw_data_feature_spec_per_column(
+ self, typ: type, col_name: str) -> tf.io.VarLenFeature:
+ """
+ Return a FeatureSpec object to be used with
+ tft_beam.AnalyzeAndTransformDataset
+ Args:
+ typ: A type of the column.
+ col_name: A name of the column.
+ Returns:
+ A FeatureSpec object.
+ """
+ # lets conver the builtin types to typing types for consistency.
+ typ = native_type_compatibility.convert_builtin_to_typing(typ)
+ primitive_containers_type = (
+ list,
+ collections.abc.Sequence,
+ )
+ is_primitive_container = (
+ typing.get_origin(typ) in primitive_containers_type)
+
+ if is_primitive_container:
+ dtype = typing.get_args(typ)[0] # type: ignore[attr-defined]
+ if len(typing.get_args(typ)) > 1 or typing.get_origin(dtype) == Union:
# type: ignore[attr-defined]
+ raise RuntimeError(
+ f"Union type is not supported for column: {col_name}. "
+ f"Please pass a PCollection with valid schema for column "
+ f"{col_name} by passing a single type "
+ "in container. For example, List[int].")
+ elif issubclass(typ, np.generic) or typ in
_default_type_to_tensor_type_map:
+ dtype = typ
+ else:
+ raise TypeError(
+ f"Unable to identify type: {typ} specified on column: {col_name}. "
+ f"Please provide a valid type from the following: "
+ f"{_default_type_to_tensor_type_map.keys()}")
+ return tf.io.VarLenFeature(_default_type_to_tensor_type_map[dtype])
+
+ def get_raw_data_metadata(
+ self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata:
+ raw_data_feature_spec = self.get_raw_data_feature_spec(input_types)
+ return self.convert_raw_data_feature_spec_to_dataset_metadata(
+ raw_data_feature_spec)
+
+ def write_transform_artifacts(self, transform_fn, location):
+ """
+ Write transform artifacts to the given location.
+ Args:
+ transform_fn: A transform_fn object.
+ location: A location to write the artifacts.
+ Returns:
+ A PCollection of WriteTransformFn writing a TF transform graph.
+ """
+ return (
+ transform_fn
+ | 'Write Transform Artifacts' >>
+ transform_fn_io.WriteTransformFn(location))
+
+ def _fail_on_non_default_windowing(self, pcoll: beam.PCollection):
+ if not pcoll.windowing.is_default():
+ raise RuntimeError(
+ "TFTProcessHandler only supports GlobalWindows when producing "
+ "artifacts such as min, max, variance etc over the dataset."
+ "Please use beam.WindowInto(beam.transforms.window.GlobalWindows()) "
+ "to convert your PCollection to GlobalWindow.")
+
+ def process_data_fn(
+ self, inputs: Dict[str, common_types.ConsistentTensorType]
+ ) -> Dict[str, common_types.ConsistentTensorType]:
+ """
+ This method is used in the AnalyzeAndTransformDataset step. It applies
+ the transforms to the `inputs` in sequential order on the columns
+ provided for a given transform.
+ Args:
+ inputs: A dictionary of column names and data.
+ Returns:
+ A dictionary of column names and transformed data.
+ """
+ outputs = inputs.copy()
+ for transform in self.transforms:
+ columns = transform.columns
+ for col in columns:
+ intermediate_result = transform.apply(
+ outputs[col], output_column_name=col)
+ for key, value in intermediate_result.items():
+ outputs[key] = value
+ return outputs
+
+ def _get_transformed_data_schema(
+ self,
+ metadata: dataset_metadata.DatasetMetadata,
+ ) -> Dict[str, typing.Sequence[typing.Union[np.float32, np.int64, bytes]]]:
+ schema = metadata._schema
+ transformed_types = {}
+ logging.info("Schema: %s", schema)
+ for feature in schema.feature:
+ name = feature.name
+ feature_type = feature.type
+ if feature_type == schema_pb2.FeatureType.FLOAT:
+ transformed_types[name] = typing.Sequence[np.float32]
+ elif feature_type == schema_pb2.FeatureType.INT:
+ transformed_types[name] = typing.Sequence[np.int64]
+ elif feature_type == schema_pb2.FeatureType.BYTES:
+ transformed_types[name] = typing.Sequence[bytes]
+ else:
+ # TODO: This else condition won't be hit since TFT doesn't output
+ # other than float, int and bytes. Refactor the code here.
+ raise RuntimeError(
+ 'Unsupported feature type: %s encountered' % feature_type)
+ logging.info(transformed_types)
+ return transformed_types
+
+ def process_data(
+ self, raw_data: beam.PCollection[tft_process_handler_input_type]
+ ) -> beam.PCollection[typing.Union[
+ beam.Row, Dict[str, np.ndarray], pa.RecordBatch]]:
+ """
+ This method also computes the required dataset metadata for the tft
+ AnalyzeDataset/TransformDataset step.
+
+ This method uses tensorflow_transform's Analyze step to produce the
+ artifacts and Transform step to apply the transforms on the data.
+ Artifacts are only produced if the artifact_mode is set to `produce`.
+ If artifact_mode is set to `consume`, then the artifacts are read from the
+ artifact_location, which was previously used to store the produced
+ artifacts.
+ """
+ if self.artifact_mode == ArtifactMode.PRODUCE:
+ # If we are computing artifacts, we should fail for windows other than
+ # default windowing since for example, for a fixed window, each window
can
+ # be treated as a separate dataset and we might need to compute artifacts
+ # for each window. This is not supported yet.
+ self._fail_on_non_default_windowing(raw_data)
+ element_type = raw_data.element_type
+ column_type_mapping = {}
+ if (isinstance(element_type, RowTypeConstraint) or
+ native_type_compatibility.match_is_named_tuple(element_type)):
+ column_type_mapping = self._map_column_names_to_types(
+ row_type=element_type)
+ # convert Row or NamedTuple to Dict
+ raw_data = (
+ raw_data
+ | ConvertNamedTupleToDict().with_output_types(
+ Dict[str, typing.Union[tuple(column_type_mapping.values())]]))
+ # AnalyzeAndTransformDataset raise type hint since this is
+ # schema'd PCollection and the current output type would be a
+ # custom type(NamedTuple) or a beam.Row type.
+ else:
+ column_type_mapping = self._map_column_names_to_types_from_transforms()
+ raw_data_metadata = self.get_raw_data_metadata(
+ input_types=column_type_mapping)
+ # Write untransformed metadata to a file so that it can be re-used
+ # during Transform step.
+ metadata_io.write_metadata(
+ metadata=raw_data_metadata,
+ path=os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
+ else:
+ # Read the metadata from the artifact_location.
+ if not os.path.exists(os.path.join(
+ self.artifact_location, RAW_DATA_METADATA_DIR, SCHEMA_FILE)):
+ raise FileNotFoundError(
+ "Raw data metadata not found at %s" %
+ os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
+ raw_data_metadata = metadata_io.read_metadata(
+ os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))
+
+ # To maintain consistency by outputting numpy array all the time,
+ # whether a scalar value or list or np array is passed as input,
+ # we will convert scalar values to list values and TFT will output
+ # numpy array all the time.
+ if not self.is_input_record_batches:
+ raw_data |= beam.ParDo(ConvertScalarValuesToListValues())
Review Comment:
Rather than having the `is_input_record_batches` param here, could we:
1) do a type check in `ConvertScalarValuesToListValues` and no-op if it's a
record batch
2) introspect `schema` to determine if it needs the `TensorAdapter`
Not totally sure about the second piece, but if we can do something like
that it would be very helpful. I don't like the idea of having
`is_input_record_batches` as a top level config, especially once we extend this
to other frameworks.
##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -0,0 +1,165 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Generic
+from typing import List
+from typing import Optional
+from typing import TypeVar
+
+import apache_beam as beam
+
+# TODO: Abstract methods are not getting pickled with dill.
+# https://github.com/uqfoundation/dill/issues/332
+# import abc
+
+__all__ = ['MLTransform']
+
+TransformedDatasetT = TypeVar('TransformedDatasetT')
+TransformedMetadataT = TypeVar('TransformedMetadataT')
+
+# Input/Output types to the MLTransform.
+ExampleT = TypeVar('ExampleT')
+MLTransformOutputT = TypeVar('MLTransformOutputT')
+
+# Input to the process data. This could be same or different from ExampleT.
+ProcessInputT = TypeVar('ProcessInputT')
+# Output of the process data. This could be same or different
+# from MLTransformOutputT
+ProcessOutputT = TypeVar('ProcessOutputT')
+
+# Input to the apply() method of BaseOperation.
+OperationInputT = TypeVar('OperationInputT')
+# Output of the apply() method of BaseOperation.
+OperationOutputT = TypeVar('OperationOutputT')
+
+
+class ArtifactMode(object):
+ PRODUCE = 'produce'
+ CONSUME = 'consume'
+
+
+class BaseOperation(Generic[OperationInputT, OperationOutputT]):
+ def apply(
+ self, inputs: OperationInputT, column_name: str, *args,
+ **kwargs) -> OperationOutputT:
+ """
+ Define any processing logic in the apply() method.
+ The processing logic is applied to the inputs and returns a
+ transformed output.
+ Args:
+ inputs: input data.
+ """
+ raise NotImplementedError
+
+
+class _ProcessHandler(Generic[ProcessInputT, ProcessOutputT]):
+ """
+ Only for internal use. No backwards compatibility guarantees.
+ """
+ def process_data(
+ self, pcoll: beam.PCollection[ProcessInputT]
+ ) -> beam.PCollection[ProcessOutputT]:
+ """
+ Logic to process the data. This will be the entrypoint in
+ beam.MLTransform to process incoming data.
+ """
+ raise NotImplementedError
+
+ def append_transform(self, transform: BaseOperation):
+ raise NotImplementedError
+
+
+class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
+ beam.PCollection[MLTransformOutputT]],
+ Generic[ExampleT, MLTransformOutputT]):
+ def __init__(
+ self,
+ *,
+ artifact_location: str,
+ artifact_mode: str = ArtifactMode.PRODUCE,
+ transforms: Optional[List[BaseOperation]] = None,
+ is_input_record_batches: bool = False,
+ output_record_batches: bool = False,
Review Comment:
I think that would be my preferred experience. So users could do something
like:
```
MLTransform(
artifact_location=args.artifact_location,
artifact_mode=ArtifactMode.PRODUCE,
).with_transform(ComputeAndApplyVocabulary(
columns=['x'],
is_input_record_batches=True)).with_transform(TFIDF(columns=['x'],
output_record_batches=True))
```
For TFT, we'd then fuse together any consecutive TFT transforms into a
single TFTProcessHandler, resolve any conflicting arguments (e.g. throw if one
transform says output_record_batches and the next doesn't take recordBatches or
something), and construct the graph
##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -0,0 +1,165 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Generic
+from typing import List
+from typing import Optional
+from typing import TypeVar
+
+import apache_beam as beam
+
+# TODO: Abstract methods are not getting pickled with dill.
+# https://github.com/uqfoundation/dill/issues/332
+# import abc
+
+__all__ = ['MLTransform']
+
+TransformedDatasetT = TypeVar('TransformedDatasetT')
+TransformedMetadataT = TypeVar('TransformedMetadataT')
+
+# Input/Output types to the MLTransform.
+ExampleT = TypeVar('ExampleT')
+MLTransformOutputT = TypeVar('MLTransformOutputT')
+
+# Input to the process data. This could be same or different from ExampleT.
+ProcessInputT = TypeVar('ProcessInputT')
+# Output of the process data. This could be same or different
+# from MLTransformOutputT
+ProcessOutputT = TypeVar('ProcessOutputT')
+
+# Input to the apply() method of BaseOperation.
+OperationInputT = TypeVar('OperationInputT')
+# Output of the apply() method of BaseOperation.
+OperationOutputT = TypeVar('OperationOutputT')
+
+
+class ArtifactMode(object):
+ PRODUCE = 'produce'
+ CONSUME = 'consume'
+
+
+class BaseOperation(Generic[OperationInputT, OperationOutputT]):
+ def apply(
+ self, inputs: OperationInputT, column_name: str, *args,
+ **kwargs) -> OperationOutputT:
+ """
+ Define any processing logic in the apply() method.
+ The processing logic is applied to the inputs and returns a
+ transformed output.
+ Args:
+ inputs: input data.
+ """
+ raise NotImplementedError
+
+
+class _ProcessHandler(Generic[ProcessInputT, ProcessOutputT]):
+ """
+ Only for internal use. No backwards compatibility guarantees.
+ """
+ def process_data(
+ self, pcoll: beam.PCollection[ProcessInputT]
+ ) -> beam.PCollection[ProcessOutputT]:
+ """
+ Logic to process the data. This will be the entrypoint in
+ beam.MLTransform to process incoming data.
+ """
+ raise NotImplementedError
+
+ def append_transform(self, transform: BaseOperation):
+ raise NotImplementedError
+
+
+class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
+ beam.PCollection[MLTransformOutputT]],
+ Generic[ExampleT, MLTransformOutputT]):
+ def __init__(
+ self,
+ *,
+ artifact_location: str,
+ artifact_mode: str = ArtifactMode.PRODUCE,
+ transforms: Optional[List[BaseOperation]] = None,
+ is_input_record_batches: bool = False,
+ output_record_batches: bool = False,
Review Comment:
One option would be to make these properties of the operations instead of
the top level transform. Chained TFT operations could then just use the values
from the first/last operation.
I also expect that we're going to run into this problem with other
frameworks in the future, so I think we need a way for adding additional
framework or transform specific parameters.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]