damccorm commented on code in PR #26795: URL: https://github.com/apache/beam/pull/26795#discussion_r1211862021
########## sdks/python/apache_beam/ml/transforms/base.py: ########## @@ -0,0 +1,119 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing +from typing import Dict +from typing import Optional +from typing import TypeVar +from typing import Generic + +import apache_beam as beam + +# TODO: Abstract methods are not getting pickled with dill. +# https://github.com/uqfoundation/dill/issues/332 +# import abc + +__all__ = ['MLTransform', 'MLTransformOutput', 'ProcessHandler'] + +TransformedDatasetT = TypeVar('TransformedDatasetT') +TransformedMetadataT = TypeVar('TransformedMetadataT') + +# Input/Output types to the MLTransform. +ExampleT = TypeVar('ExampleT') +MLTransformOutputT = TypeVar('MLTransformOutputT') + +# Input to the process data. This could be same or different from ExampleT. Review Comment: As implemented in this base class, these are always the same, right? ########## sdks/python/apache_beam/ml/transforms/base.py: ########## @@ -0,0 +1,119 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing +from typing import Dict +from typing import Optional +from typing import TypeVar +from typing import Generic + +import apache_beam as beam + +# TODO: Abstract methods are not getting pickled with dill. +# https://github.com/uqfoundation/dill/issues/332 +# import abc + +__all__ = ['MLTransform', 'MLTransformOutput', 'ProcessHandler'] + +TransformedDatasetT = TypeVar('TransformedDatasetT') +TransformedMetadataT = TypeVar('TransformedMetadataT') + +# Input/Output types to the MLTransform. +ExampleT = TypeVar('ExampleT') +MLTransformOutputT = TypeVar('MLTransformOutputT') + +# Input to the process data. This could be same or different from ExampleT. +ProcessInputT = TypeVar('ProcessInputT') +# Output of the process data. This could be same or different +# from MLTransformOutputT +ProcessOutputT = TypeVar('ProcessOutputT') + + +class _BaseOperation(): + def apply(self, inputs, *args, **kwargs): + """ + Define any processing logic in the apply() method. + processing logics are applied on inputs and returns a transformed + output. + + Args: + inputs: input data. + """ + raise NotImplementedError + + +# TODO: Add metrics namespace. 
+class MLTransformOutput(typing.NamedTuple): + transformed_data: TransformedDatasetT + transformed_metadata: Optional[TransformedMetadataT] = None + asset_map: Optional[Dict[str, str]] = None + + +class ProcessHandler(Generic[ProcessInputT, ProcessOutputT]): + def process_data( + self, pcoll: beam.PCollection[ProcessInputT] + ) -> beam.PCollection[ProcessOutputT]: + """ + Logic to process the data. This will be the entrypoint in + beam.MLTransform to process incoming data. + """ + raise NotImplementedError + + +class MLTransform(beam.PTransform[beam.PCollection[ExampleT], + beam.PCollection[MLTransformOutputT]], + Generic[ExampleT, + MLTransformOutputT, + ProcessInputT, + ProcessOutputT]): + def __init__( + self, + process_handler: ProcessHandler[ProcessInputT, ProcessOutputT], + ): + """ + Args: + process_handler: A ProcessHandler instance that defines the logic to + process the data. + """ + self._process_handler = process_handler + + def expand( + self, pcoll: beam.PCollection[ExampleT] + ) -> beam.PCollection[MLTransformOutputT]: + """ + This is the entrypoint for the MLTransform. This method will + invoke the process_data() method of the ProcessHandler instance + to process the incoming data. + Args: + pcoll: A PCollection of ExampleT type. + Returns: + A PCollection of MLTransformOutputT type. + """ + return self._process_handler.process_data(pcoll) + + def with_transform(self, transform: _BaseOperation): + """ + Add a transform to the MLTransform pipeline. + Args: + transform: A _BaseOperation instance. + Returns: + A MLTransform instance. + """ + self._process_handler.transforms.append(transform) Review Comment: Anything we're planning on invoking on a generic `ProcessHandler` object should probably be exposed on the base object, not just the children (e.g. something like `process_handler.append_transform`). Otherwise we're leaking implementation details from the children classes. 
########## sdks/python/apache_beam/ml/transforms/handlers.py: ########## @@ -0,0 +1,406 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tempfile +import typing +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np + +import apache_beam as beam +from apache_beam.ml.transforms.base import MLTransformOutput +from apache_beam.ml.transforms.base import ProcessHandler +from apache_beam.ml.transforms.base import ProcessInputT +from apache_beam.ml.transforms.base import ProcessOutputT +from apache_beam.ml.transforms.tft_transforms import _TFTOperation +from apache_beam.options.pipeline_options import GoogleCloudOptions +from apache_beam.typehints import native_type_compatibility +from apache_beam.typehints.row_type import RowTypeConstraint +import tensorflow as tf +import tensorflow_transform.beam as tft_beam +from tensorflow_transform import common_types +from tensorflow_transform.beam.tft_beam_io import transform_fn_io +from tensorflow_transform.tf_metadata import dataset_metadata +from tensorflow_transform.tf_metadata import schema_utils + +__all__ = [ + 'TFTProcessHandlerDict', +] + +# tensorflow transform doesn't support the 
types other than tf.int64, +# tf.float32 and tf.string. +_default_type_to_tensor_type_map = { + int: tf.int64, + float: tf.float32, + str: tf.string, + bytes: tf.string, + np.int64: tf.int64, + np.int32: tf.int64, + np.float32: tf.float32, + np.float64: tf.float32, + np.bytes_: tf.string, + np.str_: tf.string, +} + +tft_process_handler_dict_input_type = typing.Union[typing.NamedTuple, beam.Row] + + +class ConvertNamedTupleToDict( + beam.PTransform[beam.PCollection[tft_process_handler_dict_input_type], + beam.PCollection[Dict[str, + common_types.InstanceDictType]]]): + """ + A PTransform that converts a collection of NamedTuples or Rows into a + collection of dictionaries. Review Comment: This would benefit from a description of why we want to do this (to convert to an acceptable TFT input type). ########## sdks/python/apache_beam/ml/transforms/handlers.py: ########## @@ -0,0 +1,406 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import tempfile +import typing +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np + +import apache_beam as beam +from apache_beam.ml.transforms.base import MLTransformOutput +from apache_beam.ml.transforms.base import ProcessHandler +from apache_beam.ml.transforms.base import ProcessInputT +from apache_beam.ml.transforms.base import ProcessOutputT +from apache_beam.ml.transforms.tft_transforms import _TFTOperation +from apache_beam.options.pipeline_options import GoogleCloudOptions +from apache_beam.typehints import native_type_compatibility +from apache_beam.typehints.row_type import RowTypeConstraint +import tensorflow as tf +import tensorflow_transform.beam as tft_beam +from tensorflow_transform import common_types +from tensorflow_transform.beam.tft_beam_io import transform_fn_io +from tensorflow_transform.tf_metadata import dataset_metadata +from tensorflow_transform.tf_metadata import schema_utils + +__all__ = [ + 'TFTProcessHandlerDict', +] + +# tensorflow transform doesn't support the types other than tf.int64, +# tf.float32 and tf.string. +_default_type_to_tensor_type_map = { + int: tf.int64, + float: tf.float32, + str: tf.string, + bytes: tf.string, + np.int64: tf.int64, + np.int32: tf.int64, + np.float32: tf.float32, + np.float64: tf.float32, + np.bytes_: tf.string, + np.str_: tf.string, +} + +tft_process_handler_dict_input_type = typing.Union[typing.NamedTuple, beam.Row] + + +class ConvertNamedTupleToDict( + beam.PTransform[beam.PCollection[tft_process_handler_dict_input_type], + beam.PCollection[Dict[str, + common_types.InstanceDictType]]]): + """ + A PTransform that converts a collection of NamedTuples or Rows into a + collection of dictionaries. 
+ """ + def expand( + self, pcoll: beam.PCollection[tft_process_handler_dict_input_type] + ) -> beam.PCollection[common_types.InstanceDictType]: + """ + Args: + pcoll: A PCollection of NamedTuples or Rows. + Returns: + A PCollection of dictionaries. + """ + if isinstance(pcoll.element_type, RowTypeConstraint): + # Row instance + return pcoll | beam.Map(lambda x: x.asdict()) + else: + # named tuple + return pcoll | beam.Map(lambda x: x._asdict()) + + +# TODO: Add metrics namespace. +class TFTProcessHandler(ProcessHandler[ProcessInputT, ProcessOutputT]): + def __init__( + self, + *, + input_types: Optional[Dict[str, type]] = None, + output_record_batches=False, + transforms: List[_TFTOperation] = None, + namespace: str = 'TFTProcessHandler', + ): + """ + A handler class for processing data with TensorFlow Transform (TFT) + operations. This class is intended to be subclassed, with subclasses + implementing the `preprocessing_fn` method. + + Args: + input_types: A dictionary of column names and types. + output_record_batches: Whether to output RecordBatches instead of + dictionaries. + transforms: A list of transforms to apply to the data. All the transforms + are applied in the order they are specified. The input of the + i-th transform is the output of the (i-1)-th transform. Multi-input + transforms are not supported yet. + namespace: A metrics namespace for the TFTProcessHandler. + """ + super().__init__() + self._input_types = input_types + self.transforms = transforms if transforms else [] + self._input_types = input_types + self._output_record_batches = output_record_batches + self._artifact_location = None + self._namespace = namespace + + def get_raw_data_feature_spec( + self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata: + """ + Return a DatasetMetadata object to be used with + tft_beam.AnalyzeAndTransformDataset. + Args: + input_types: A dictionary of column names and types. + Returns: + A DatasetMetadata object. 
+ """ + raw_data_feature_spec = {} + for key, value in input_types.items(): + raw_data_feature_spec[key] = self._get_raw_data_feature_spec_per_column( + typ=value, col_name=key) + raw_data_metadata = dataset_metadata.DatasetMetadata( + schema_utils.schema_from_feature_spec(raw_data_feature_spec)) + return raw_data_metadata + + def _get_raw_data_feature_spec_per_column(self, typ: type, col_name: str): + """ + Return a FeatureSpec object to be used with + tft_beam.AnalyzeAndTransformDataset + Args: + typ: A type of the column. + col_name: A name of the column. + Returns: + A FeatureSpec object. + """ + # lets conver the builtin types to typing types for consistency. + typ = native_type_compatibility.convert_builtin_to_typing(typ) + containers_type = (List._name, Tuple._name) + is_container = hasattr(typ, '_name') and typ._name in containers_type + + if is_container: + dtype = typing.get_args(typ)[0] + if len(typing.get_args(typ)) > 1 or typing.get_origin(dtype) == Union: + raise RuntimeError( + f"Incorrect type specifications in {typ} for column {col_name}. " + f"Please specify a single type.") + if dtype not in _default_type_to_tensor_type_map: + raise TypeError( + f"Unable to identify type: {dtype} specified on column: {col_name}" + f". Please specify a valid type.") Review Comment: In our error cases, it would be good to be clear about what a valid type is ########## sdks/python/apache_beam/ml/transforms/handlers.py: ########## @@ -0,0 +1,406 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tempfile +import typing +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np + +import apache_beam as beam +from apache_beam.ml.transforms.base import MLTransformOutput +from apache_beam.ml.transforms.base import ProcessHandler +from apache_beam.ml.transforms.base import ProcessInputT +from apache_beam.ml.transforms.base import ProcessOutputT +from apache_beam.ml.transforms.tft_transforms import _TFTOperation +from apache_beam.options.pipeline_options import GoogleCloudOptions +from apache_beam.typehints import native_type_compatibility +from apache_beam.typehints.row_type import RowTypeConstraint +import tensorflow as tf +import tensorflow_transform.beam as tft_beam +from tensorflow_transform import common_types +from tensorflow_transform.beam.tft_beam_io import transform_fn_io +from tensorflow_transform.tf_metadata import dataset_metadata +from tensorflow_transform.tf_metadata import schema_utils + +__all__ = [ + 'TFTProcessHandlerDict', +] + +# tensorflow transform doesn't support the types other than tf.int64, +# tf.float32 and tf.string. 
+_default_type_to_tensor_type_map = { + int: tf.int64, + float: tf.float32, + str: tf.string, + bytes: tf.string, + np.int64: tf.int64, + np.int32: tf.int64, + np.float32: tf.float32, + np.float64: tf.float32, + np.bytes_: tf.string, + np.str_: tf.string, +} + +tft_process_handler_dict_input_type = typing.Union[typing.NamedTuple, beam.Row] + + +class ConvertNamedTupleToDict( + beam.PTransform[beam.PCollection[tft_process_handler_dict_input_type], + beam.PCollection[Dict[str, + common_types.InstanceDictType]]]): + """ + A PTransform that converts a collection of NamedTuples or Rows into a + collection of dictionaries. + """ + def expand( + self, pcoll: beam.PCollection[tft_process_handler_dict_input_type] + ) -> beam.PCollection[common_types.InstanceDictType]: + """ + Args: + pcoll: A PCollection of NamedTuples or Rows. + Returns: + A PCollection of dictionaries. + """ + if isinstance(pcoll.element_type, RowTypeConstraint): + # Row instance + return pcoll | beam.Map(lambda x: x.asdict()) + else: + # named tuple + return pcoll | beam.Map(lambda x: x._asdict()) + + +# TODO: Add metrics namespace. +class TFTProcessHandler(ProcessHandler[ProcessInputT, ProcessOutputT]): + def __init__( + self, + *, + input_types: Optional[Dict[str, type]] = None, + output_record_batches=False, + transforms: List[_TFTOperation] = None, + namespace: str = 'TFTProcessHandler', + ): + """ + A handler class for processing data with TensorFlow Transform (TFT) + operations. This class is intended to be subclassed, with subclasses + implementing the `preprocessing_fn` method. + + Args: + input_types: A dictionary of column names and types. + output_record_batches: Whether to output RecordBatches instead of + dictionaries. + transforms: A list of transforms to apply to the data. All the transforms + are applied in the order they are specified. The input of the + i-th transform is the output of the (i-1)-th transform. Multi-input + transforms are not supported yet. 
+ namespace: A metrics namespace for the TFTProcessHandler. + """ + super().__init__() + self._input_types = input_types + self.transforms = transforms if transforms else [] + self._input_types = input_types + self._output_record_batches = output_record_batches + self._artifact_location = None + self._namespace = namespace + + def get_raw_data_feature_spec( + self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata: + """ + Return a DatasetMetadata object to be used with + tft_beam.AnalyzeAndTransformDataset. + Args: + input_types: A dictionary of column names and types. + Returns: + A DatasetMetadata object. + """ + raw_data_feature_spec = {} + for key, value in input_types.items(): + raw_data_feature_spec[key] = self._get_raw_data_feature_spec_per_column( + typ=value, col_name=key) + raw_data_metadata = dataset_metadata.DatasetMetadata( + schema_utils.schema_from_feature_spec(raw_data_feature_spec)) + return raw_data_metadata + + def _get_raw_data_feature_spec_per_column(self, typ: type, col_name: str): + """ + Return a FeatureSpec object to be used with + tft_beam.AnalyzeAndTransformDataset + Args: + typ: A type of the column. + col_name: A name of the column. + Returns: + A FeatureSpec object. + """ + # lets conver the builtin types to typing types for consistency. + typ = native_type_compatibility.convert_builtin_to_typing(typ) + containers_type = (List._name, Tuple._name) + is_container = hasattr(typ, '_name') and typ._name in containers_type + + if is_container: + dtype = typing.get_args(typ)[0] + if len(typing.get_args(typ)) > 1 or typing.get_origin(dtype) == Union: + raise RuntimeError( + f"Incorrect type specifications in {typ} for column {col_name}. " + f"Please specify a single type.") + if dtype not in _default_type_to_tensor_type_map: + raise TypeError( Review Comment: Does this check need to be applied in the `else` case as well? If so, should it be pulled out of the if/else? 
########## sdks/python/apache_beam/ml/transforms/tft_transforms.py: ########## @@ -0,0 +1,301 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +from apache_beam.ml.transforms.base import _BaseOperation +import tensorflow as tf +import tensorflow_transform as tft +from tensorflow_transform import analyzers +from tensorflow_transform import common_types +from tensorflow_transform import tf_utils + +__all__ = [ + 'compute_and_apply_vocabulary', + 'scale_to_z_score', + 'scale_to_0_1', + 'apply_buckets', + 'bucketize' +] + + +class _TFTOperation(_BaseOperation): + def __init__( + self, columns, save_result=False, output_name=None, *args, **kwargs): + """ + When subclassing _TFTOperation, please make sure + positional arguments are part of the instance variables. + """ + self.columns = columns + self._args = args + self._kwargs = kwargs + self.has_artifacts = False Review Comment: Nit: we could probably get rid of this variable by just returning `None` from `get_analyzer_artifacts` and checking that. Would save some boilerplate (which could get lost) in this file. 
########## sdks/python/apache_beam/ml/transforms/handlers.py: ########## @@ -0,0 +1,406 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tempfile +import typing +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np + +import apache_beam as beam +from apache_beam.ml.transforms.base import MLTransformOutput +from apache_beam.ml.transforms.base import ProcessHandler +from apache_beam.ml.transforms.base import ProcessInputT +from apache_beam.ml.transforms.base import ProcessOutputT +from apache_beam.ml.transforms.tft_transforms import _TFTOperation +from apache_beam.options.pipeline_options import GoogleCloudOptions +from apache_beam.typehints import native_type_compatibility +from apache_beam.typehints.row_type import RowTypeConstraint +import tensorflow as tf +import tensorflow_transform.beam as tft_beam +from tensorflow_transform import common_types +from tensorflow_transform.beam.tft_beam_io import transform_fn_io +from tensorflow_transform.tf_metadata import dataset_metadata +from tensorflow_transform.tf_metadata import schema_utils + +__all__ = [ + 'TFTProcessHandlerDict', +] + +# tensorflow transform doesn't support the 
types other than tf.int64, +# tf.float32 and tf.string. +_default_type_to_tensor_type_map = { + int: tf.int64, + float: tf.float32, + str: tf.string, + bytes: tf.string, + np.int64: tf.int64, + np.int32: tf.int64, + np.float32: tf.float32, + np.float64: tf.float32, + np.bytes_: tf.string, + np.str_: tf.string, +} + +tft_process_handler_dict_input_type = typing.Union[typing.NamedTuple, beam.Row] + + +class ConvertNamedTupleToDict( + beam.PTransform[beam.PCollection[tft_process_handler_dict_input_type], + beam.PCollection[Dict[str, + common_types.InstanceDictType]]]): + """ + A PTransform that converts a collection of NamedTuples or Rows into a + collection of dictionaries. + """ + def expand( + self, pcoll: beam.PCollection[tft_process_handler_dict_input_type] + ) -> beam.PCollection[common_types.InstanceDictType]: + """ + Args: + pcoll: A PCollection of NamedTuples or Rows. + Returns: + A PCollection of dictionaries. + """ + if isinstance(pcoll.element_type, RowTypeConstraint): + # Row instance + return pcoll | beam.Map(lambda x: x.asdict()) + else: + # named tuple + return pcoll | beam.Map(lambda x: x._asdict()) + + +# TODO: Add metrics namespace. +class TFTProcessHandler(ProcessHandler[ProcessInputT, ProcessOutputT]): + def __init__( + self, + *, + input_types: Optional[Dict[str, type]] = None, + output_record_batches=False, + transforms: List[_TFTOperation] = None, + namespace: str = 'TFTProcessHandler', + ): + """ + A handler class for processing data with TensorFlow Transform (TFT) + operations. This class is intended to be subclassed, with subclasses + implementing the `preprocessing_fn` method. + + Args: + input_types: A dictionary of column names and types. + output_record_batches: Whether to output RecordBatches instead of + dictionaries. + transforms: A list of transforms to apply to the data. All the transforms + are applied in the order they are specified. The input of the + i-th transform is the output of the (i-1)-th transform. 
Multi-input + transforms are not supported yet. + namespace: A metrics namespace for the TFTProcessHandler. + """ + super().__init__() + self._input_types = input_types + self.transforms = transforms if transforms else [] + self._input_types = input_types + self._output_record_batches = output_record_batches + self._artifact_location = None + self._namespace = namespace + + def get_raw_data_feature_spec( + self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata: + """ + Return a DatasetMetadata object to be used with + tft_beam.AnalyzeAndTransformDataset. + Args: + input_types: A dictionary of column names and types. + Returns: + A DatasetMetadata object. + """ + raw_data_feature_spec = {} + for key, value in input_types.items(): + raw_data_feature_spec[key] = self._get_raw_data_feature_spec_per_column( + typ=value, col_name=key) + raw_data_metadata = dataset_metadata.DatasetMetadata( + schema_utils.schema_from_feature_spec(raw_data_feature_spec)) + return raw_data_metadata + + def _get_raw_data_feature_spec_per_column(self, typ: type, col_name: str): + """ + Return a FeatureSpec object to be used with + tft_beam.AnalyzeAndTransformDataset + Args: + typ: A type of the column. + col_name: A name of the column. + Returns: + A FeatureSpec object. + """ + # lets conver the builtin types to typing types for consistency. + typ = native_type_compatibility.convert_builtin_to_typing(typ) + containers_type = (List._name, Tuple._name) + is_container = hasattr(typ, '_name') and typ._name in containers_type + + if is_container: + dtype = typing.get_args(typ)[0] + if len(typing.get_args(typ)) > 1 or typing.get_origin(dtype) == Union: + raise RuntimeError( + f"Incorrect type specifications in {typ} for column {col_name}. " + f"Please specify a single type.") + if dtype not in _default_type_to_tensor_type_map: + raise TypeError( + f"Unable to identify type: {dtype} specified on column: {col_name}" + f". 
Please specify a valid type.") + else: + dtype = typ + + is_container = is_container or issubclass(dtype, np.generic) + if is_container: + return tf.io.VarLenFeature(_default_type_to_tensor_type_map[dtype]) + else: + return tf.io.FixedLenFeature([], _default_type_to_tensor_type_map[dtype]) + + def get_metadata(self, input_types: Dict[str, type]): + """ + Return metadata to be used with tft_beam.AnalyzeAndTransformDataset + Args: + input_types: A dictionary of column names and types. + """ + raise NotImplementedError + + def write_transform_artifacts(self, transform_fn, location): + """ + Write transform artifacts to the given location. + Args: + transform_fn: A transform_fn object. + location: A location to write the artifacts. + Returns: + A PCollection of WriteTransformFn writing a TF transform graph. + """ + return ( + transform_fn + | 'Write Transform Artifacts' >> + transform_fn_io.WriteTransformFn(location)) + + def infer_output_type(self, input_type): + if not isinstance(input_type, RowTypeConstraint): + row_type = RowTypeConstraint.from_user_type(input_type) + fields = row_type._inner_types() + return Dict[str, Union[tuple(fields)]] + + def _get_artifact_location(self, pipeline: beam.Pipeline): Review Comment: Instead of dumping this in the staging directory or a temp directory, should we require users to provide an output directory? Presumably, users will want a well-defined location for retrieving their artifacts. ########## sdks/python/apache_beam/ml/transforms/tft_transforms.py: ########## @@ -0,0 +1,301 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +from apache_beam.ml.transforms.base import _BaseOperation +import tensorflow as tf +import tensorflow_transform as tft +from tensorflow_transform import analyzers +from tensorflow_transform import common_types +from tensorflow_transform import tf_utils + +__all__ = [ + 'compute_and_apply_vocabulary', + 'scale_to_z_score', + 'scale_to_0_1', + 'apply_buckets', + 'bucketize' +] + + +class _TFTOperation(_BaseOperation): + def __init__( + self, columns, save_result=False, output_name=None, *args, **kwargs): + """ + When subclassing _TFTOperation, please make sure + positional arguments are part of the instance variables. + """ + self.columns = columns + self._args = args + self._kwargs = kwargs + self.has_artifacts = False + + self._save_result = save_result + self._output_name = output_name + if not columns: + raise RuntimeError( + "Columns are not specified. Please specify the column for the " + " op %s" % self) + + if self._save_result and not self._output_name: + raise RuntimeError( + "Inplace is set to True. " + "but output name in which transformed data is stored" + " is not specified. 
Please specify the output name for " + " the op %s" % self) + + def apply(self, inputs, *args, **kwargs): + raise NotImplementedError + + def validate_args(self): + raise NotImplementedError + + def __call__(self, data): + return self.apply(data, *self._args, **self._kwargs) + + def get_analyzer_artifacts(self, data, col_name): Review Comment: There are a bunch of places in this file where type hints would be helpful. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
