AnandInguva commented on code in PR #26795:
URL: https://github.com/apache/beam/pull/26795#discussion_r1213655552


##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -0,0 +1,406 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import typing
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+
+import apache_beam as beam
+from apache_beam.ml.transforms.base import MLTransformOutput
+from apache_beam.ml.transforms.base import ProcessHandler
+from apache_beam.ml.transforms.base import ProcessInputT
+from apache_beam.ml.transforms.base import ProcessOutputT
+from apache_beam.ml.transforms.tft_transforms import _TFTOperation
+from apache_beam.options.pipeline_options import GoogleCloudOptions
+from apache_beam.typehints import native_type_compatibility
+from apache_beam.typehints.row_type import RowTypeConstraint
+import tensorflow as tf
+import tensorflow_transform.beam as tft_beam
+from tensorflow_transform import common_types
+from tensorflow_transform.beam.tft_beam_io import transform_fn_io
+from tensorflow_transform.tf_metadata import dataset_metadata
+from tensorflow_transform.tf_metadata import schema_utils
+
+__all__ = [
+    'TFTProcessHandlerDict',
+]
+
+# tensorflow transform doesn't support the types other than tf.int64,
+# tf.float32 and tf.string.
+_default_type_to_tensor_type_map = {
+    int: tf.int64,
+    float: tf.float32,
+    str: tf.string,
+    bytes: tf.string,
+    np.int64: tf.int64,
+    np.int32: tf.int64,
+    np.float32: tf.float32,
+    np.float64: tf.float32,
+    np.bytes_: tf.string,
+    np.str_: tf.string,
+}
+
+tft_process_handler_dict_input_type = typing.Union[typing.NamedTuple, beam.Row]
+
+
+class ConvertNamedTupleToDict(
+    beam.PTransform[beam.PCollection[tft_process_handler_dict_input_type],
+                    beam.PCollection[Dict[str,
+                                          common_types.InstanceDictType]]]):
+  """
+    A PTransform that converts a collection of NamedTuples or Rows into a
+    collection of dictionaries.

Review Comment:
   I added a comment at the call site explaining why we need this
PTransform.



##########
sdks/python/apache_beam/ml/transforms/handlers.py:
##########
@@ -0,0 +1,406 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import typing
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+
+import apache_beam as beam
+from apache_beam.ml.transforms.base import MLTransformOutput
+from apache_beam.ml.transforms.base import ProcessHandler
+from apache_beam.ml.transforms.base import ProcessInputT
+from apache_beam.ml.transforms.base import ProcessOutputT
+from apache_beam.ml.transforms.tft_transforms import _TFTOperation
+from apache_beam.options.pipeline_options import GoogleCloudOptions
+from apache_beam.typehints import native_type_compatibility
+from apache_beam.typehints.row_type import RowTypeConstraint
+import tensorflow as tf
+import tensorflow_transform.beam as tft_beam
+from tensorflow_transform import common_types
+from tensorflow_transform.beam.tft_beam_io import transform_fn_io
+from tensorflow_transform.tf_metadata import dataset_metadata
+from tensorflow_transform.tf_metadata import schema_utils
+
+__all__ = [
+    'TFTProcessHandlerDict',
+]
+
+# tensorflow transform doesn't support the types other than tf.int64,
+# tf.float32 and tf.string.
+_default_type_to_tensor_type_map = {
+    int: tf.int64,
+    float: tf.float32,
+    str: tf.string,
+    bytes: tf.string,
+    np.int64: tf.int64,
+    np.int32: tf.int64,
+    np.float32: tf.float32,
+    np.float64: tf.float32,
+    np.bytes_: tf.string,
+    np.str_: tf.string,
+}
+
+tft_process_handler_dict_input_type = typing.Union[typing.NamedTuple, beam.Row]
+
+
+class ConvertNamedTupleToDict(
+    beam.PTransform[beam.PCollection[tft_process_handler_dict_input_type],
+                    beam.PCollection[Dict[str,
+                                          common_types.InstanceDictType]]]):
+  """
+    A PTransform that converts a collection of NamedTuples or Rows into a
+    collection of dictionaries.
+  """
+  def expand(
+      self, pcoll: beam.PCollection[tft_process_handler_dict_input_type]
+  ) -> beam.PCollection[common_types.InstanceDictType]:
+    """
+    Args:
+      pcoll: A PCollection of NamedTuples or Rows.
+    Returns:
+      A PCollection of dictionaries.
+    """
+    if isinstance(pcoll.element_type, RowTypeConstraint):
+      # Row instance
+      return pcoll | beam.Map(lambda x: x.asdict())
+    else:
+      # named tuple
+      return pcoll | beam.Map(lambda x: x._asdict())
+
+
+# TODO: Add metrics namespace.
+class TFTProcessHandler(ProcessHandler[ProcessInputT, ProcessOutputT]):
+  def __init__(
+      self,
+      *,
+      input_types: Optional[Dict[str, type]] = None,
+      output_record_batches=False,
+      transforms: List[_TFTOperation] = None,
+      namespace: str = 'TFTProcessHandler',
+  ):
+    """
+    A handler class for processing data with TensorFlow Transform (TFT)
+    operations. This class is intended to be subclassed, with subclasses
+    implementing the `preprocessing_fn` method.
+
+    Args:
+      input_types: A dictionary of column names and types.
+      output_record_batches: Whether to output RecordBatches instead of
+        dictionaries.
+      transforms: A list of transforms to apply to the data. All the transforms
+        are applied in the order they are specified. The input of the
+        i-th transform is the output of the (i-1)-th transform. Multi-input
+        transforms are not supported yet.
+      namespace: A metrics namespace for the TFTProcessHandler.
+    """
+    super().__init__()
+    self._input_types = input_types
+    self.transforms = transforms if transforms else []
+    self._input_types = input_types
+    self._output_record_batches = output_record_batches
+    self._artifact_location = None
+    self._namespace = namespace
+
+  def get_raw_data_feature_spec(
+      self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata:
+    """
+    Return a DatasetMetadata object to be used with
+    tft_beam.AnalyzeAndTransformDataset.
+    Args:
+      input_types: A dictionary of column names and types.
+    Returns:
+      A DatasetMetadata object.
+    """
+    raw_data_feature_spec = {}
+    for key, value in input_types.items():
+      raw_data_feature_spec[key] = self._get_raw_data_feature_spec_per_column(
+          typ=value, col_name=key)
+    raw_data_metadata = dataset_metadata.DatasetMetadata(
+        schema_utils.schema_from_feature_spec(raw_data_feature_spec))
+    return raw_data_metadata
+
+  def _get_raw_data_feature_spec_per_column(self, typ: type, col_name: str):
+    """
+    Return a FeatureSpec object to be used with
+    tft_beam.AnalyzeAndTransformDataset
+    Args:
+      typ: A type of the column.
+      col_name: A name of the column.
+    Returns:
+      A FeatureSpec object.
+    """
+    # lets conver the builtin types to typing types for consistency.
+    typ = native_type_compatibility.convert_builtin_to_typing(typ)
+    containers_type = (List._name, Tuple._name)
+    is_container = hasattr(typ, '_name') and typ._name in containers_type
+
+    if is_container:
+      dtype = typing.get_args(typ)[0]
+      if len(typing.get_args(typ)) > 1 or typing.get_origin(dtype) == Union:
+        raise RuntimeError(
+            f"Incorrect type specifications in {typ} for column {col_name}. "
+            f"Please specify a single type.")
+      if dtype not in _default_type_to_tensor_type_map:
+        raise TypeError(
+            f"Unable to identify type: {dtype} specified on column: {col_name}"
+            f". Please specify a valid type.")
+    else:
+      dtype = typ
+
+    is_container = is_container or issubclass(dtype, np.generic)
+    if is_container:
+      return tf.io.VarLenFeature(_default_type_to_tensor_type_map[dtype])
+    else:
+      return tf.io.FixedLenFeature([], _default_type_to_tensor_type_map[dtype])
+
+  def get_metadata(self, input_types: Dict[str, type]):
+    """
+    Return metadata to be used with tft_beam.AnalyzeAndTransformDataset
+    Args:
+      input_types: A dictionary of column names and types.
+    """
+    raise NotImplementedError
+
+  def write_transform_artifacts(self, transform_fn, location):
+    """
+    Write transform artifacts to the given location.
+    Args:
+      transform_fn: A transform_fn object.
+      location: A location to write the artifacts.
+    Returns:
+      A PCollection of WriteTransformFn writing a TF transform graph.
+    """
+    return (
+        transform_fn
+        | 'Write Transform Artifacts' >>
+        transform_fn_io.WriteTransformFn(location))
+
+  def infer_output_type(self, input_type):
+    if not isinstance(input_type, RowTypeConstraint):
+      row_type = RowTypeConstraint.from_user_type(input_type)
+    fields = row_type._inner_types()
+    return Dict[str, Union[tuple(fields)]]
+
+  def _get_artifact_location(self, pipeline: beam.Pipeline):

Review Comment:
   Added it as optional. If the user doesn't provide it, I fall back to this
approach.



##########
sdks/python/apache_beam/ml/transforms/tft_transforms.py:
##########
@@ -0,0 +1,301 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from apache_beam.ml.transforms.base import _BaseOperation
+import tensorflow as tf
+import tensorflow_transform as tft
+from tensorflow_transform import analyzers
+from tensorflow_transform import common_types
+from tensorflow_transform import tf_utils
+
+__all__ = [
+    'compute_and_apply_vocabulary',
+    'scale_to_z_score',
+    'scale_to_0_1',
+    'apply_buckets',
+    'bucketize'
+]
+
+
+class _TFTOperation(_BaseOperation):
+  def __init__(
+      self, columns, save_result=False, output_name=None, *args, **kwargs):
+    """
+    When subclassing _TFTOperation, please make sure
+    positional arguments are part of the instance variables.
+    """
+    self.columns = columns
+    self._args = args
+    self._kwargs = kwargs
+    self.has_artifacts = False

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to