AnandInguva commented on code in PR #29564:
URL: https://github.com/apache/beam/pull/29564#discussion_r1414141813


##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -254,3 +371,243 @@ def _increment_counters():
         pipeline
         | beam.Create([None])
         | beam.Map(lambda _: _increment_counters()))
+
+
+class _TransformAttributeManager:
+  """
+  Base class used for saving and loading the attributes.
+  """
+  @staticmethod
+  def save_attributes(artifact_location):
+    """
+    Save the attributes to json file using stdlib json.
+    """
+    raise NotImplementedError
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    """
+    Load the attributes from json file.
+    """
+    raise NotImplementedError
+
+
+class _JsonPickleTransformAttributeManager(_TransformAttributeManager):
+  """
+  Use Jsonpickle to save and load the attributes. Here the attributes refer
+  to the list of PTransforms that are used to process the data.
+
+  jsonpickle is used to serialize the PTransforms and save it to a json file 
and
+  is compatible across python versions.
+  """
+  @staticmethod
+  def _is_remote_path(path):
+    is_gcs = path.find('gs://') != -1
+    # TODO: Add support for other remote paths.
+    if not is_gcs and path.find('://') != -1:
+      raise RuntimeError(
+          "Artifact locations are currently supported for only available for "
+          "local paths and GCS paths. Got: %s" % path)
+    return is_gcs
+
+  @staticmethod
+  def save_attributes(
+      ptransform_list,
+      artifact_location,
+      **kwargs,
+  ):
+    if _JsonPickleTransformAttributeManager._is_remote_path(artifact_location):
+      try:
+        options = kwargs.get('options')
+      except KeyError:
+        raise RuntimeError(
+            'pipeline options are required to save the attributes.'
+            'in the artifact location %s' % artifact_location)
+
+      temp_dir = tempfile.mkdtemp()
+      temp_json_file = os.path.join(temp_dir, _ATTRIBUTE_FILE_NAME)
+      with open(temp_json_file, 'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+      with open(temp_json_file, 'rb') as f:
+        from apache_beam.runners.dataflow.internal import apiclient
+        _LOGGER.info('Creating artifact location: %s', artifact_location)
+        apiclient.DataflowApplicationClient(options=options).stage_file(
+            gcs_or_local_path=artifact_location,
+            file_name=_ATTRIBUTE_FILE_NAME,
+            stream=f,
+            mime_type='application/json')
+    else:
+      if not FileSystems.exists(artifact_location):
+        FileSystems.mkdirs(artifact_location)
+      # FileSystems.open() fails if the file does not exist.
+      with open(os.path.join(artifact_location, _ATTRIBUTE_FILE_NAME),
+                'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    with FileSystems.open(os.path.join(artifact_location, 
_ATTRIBUTE_FILE_NAME),
+                          'rb') as f:
+      return jsonpickle.decode(f.read())
+
+
+_transform_attribute_manager = _JsonPickleTransformAttributeManager
+
+
+class _MLTransformToPTransformMapper:
+  """
+  This class takes in a list of data processing transforms compatible to be
+  wrapped around MLTransform and returns a list of PTransforms that are used to
+  run the data processing transforms.
+
+  The _MLTransformToPTransformMapper is responsible for loading and saving the
+  PTransforms or attributes of PTransforms to the artifact location to seal
+  the gap between the training and inference pipelines.
+  """
+  def __init__(
+      self,
+      transforms: List[Union[BaseOperation, EmbeddingsManager]],
+      artifact_location: str,
+      artifact_mode: str,
+      pipeline_options: Optional[PipelineOptions] = None,
+  ):
+    self.transforms = transforms
+    self._parent_artifact_location = artifact_location
+    self.artifact_mode = artifact_mode
+    self.pipeline_options = pipeline_options
+
+  def create_and_save_ptransform_list(self):
+    ptransform_list = self.create_ptransform_list()
+    self.save_transforms_in_artifact_location(ptransform_list)
+    return ptransform_list
+
+  def create_ptransform_list(self):
+    previous_ptransform_type = None
+    current_ptransform = None
+    ptransform_list = []
+    for transform in self.transforms:
+      if not isinstance(transform, PTransformProvider):
+        raise RuntimeError(
+            'Transforms must be instances of PTransformProvider and '
+            'implement get_ptransform_for_processing() method.')
+      # for each instance of PTransform, create a new artifact location
+      current_ptransform = transform.get_ptransform_for_processing(
+          artifact_location=os.path.join(
+              self._parent_artifact_location, uuid.uuid4().hex[:6]),
+          artifact_mode=self.artifact_mode)
+      # Determine if a new ptransform should be added to the list
+      is_different_type = (type(current_ptransform) != 
previous_ptransform_type)
+      if is_different_type or not transform.requires_chaining():
+        ptransform_list.append(current_ptransform)
+        previous_ptransform_type = type(current_ptransform)
+
+      if hasattr(ptransform_list[-1], 'append_transform'):
+        ptransform_list[-1].append_transform(transform)

Review Comment:
   >> Should this be an else if for the previous condition? E.g. if 
is_different_type is True, we'll append current_ptransform, but if that has an 
append_transform function we'll end up appending it to itself.
   
   We need to append (if applicable) the data processing transform to the most
recent PTransform irrespective of the condition `if is_different_type or not
transform.requires_chaining():`, because `append_transform` relates to whether a
data processing transform can be appended to the list of transforms present in
the most recent PTransform.
   
   



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -222,14 +334,21 @@ def with_transform(self, transform: BaseOperation):
     Returns:
       A MLTransform instance.
     """
-    self._validate_transform(transform)
-    self._process_handler.append_transform(transform)
+    # self._validate_transform(transform)

Review Comment:
   Uncommented it.



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -254,3 +371,243 @@ def _increment_counters():
         pipeline
         | beam.Create([None])
         | beam.Map(lambda _: _increment_counters()))
+
+
+class _TransformAttributeManager:
+  """
+  Base class used for saving and loading the attributes.
+  """
+  @staticmethod
+  def save_attributes(artifact_location):
+    """
+    Save the attributes to json file using stdlib json.
+    """
+    raise NotImplementedError
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    """
+    Load the attributes from json file.
+    """
+    raise NotImplementedError
+
+
+class _JsonPickleTransformAttributeManager(_TransformAttributeManager):
+  """
+  Use Jsonpickle to save and load the attributes. Here the attributes refer
+  to the list of PTransforms that are used to process the data.
+
+  jsonpickle is used to serialize the PTransforms and save it to a json file 
and
+  is compatible across python versions.
+  """
+  @staticmethod
+  def _is_remote_path(path):
+    is_gcs = path.find('gs://') != -1
+    # TODO: Add support for other remote paths.
+    if not is_gcs and path.find('://') != -1:
+      raise RuntimeError(
+          "Artifact locations are currently supported for only available for "
+          "local paths and GCS paths. Got: %s" % path)
+    return is_gcs
+
+  @staticmethod
+  def save_attributes(
+      ptransform_list,
+      artifact_location,
+      **kwargs,
+  ):
+    if _JsonPickleTransformAttributeManager._is_remote_path(artifact_location):
+      try:
+        options = kwargs.get('options')
+      except KeyError:
+        raise RuntimeError(
+            'pipeline options are required to save the attributes.'
+            'in the artifact location %s' % artifact_location)
+
+      temp_dir = tempfile.mkdtemp()
+      temp_json_file = os.path.join(temp_dir, _ATTRIBUTE_FILE_NAME)
+      with open(temp_json_file, 'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+      with open(temp_json_file, 'rb') as f:
+        from apache_beam.runners.dataflow.internal import apiclient
+        _LOGGER.info('Creating artifact location: %s', artifact_location)
+        apiclient.DataflowApplicationClient(options=options).stage_file(
+            gcs_or_local_path=artifact_location,
+            file_name=_ATTRIBUTE_FILE_NAME,
+            stream=f,
+            mime_type='application/json')
+    else:
+      if not FileSystems.exists(artifact_location):
+        FileSystems.mkdirs(artifact_location)
+      # FileSystems.open() fails if the file does not exist.
+      with open(os.path.join(artifact_location, _ATTRIBUTE_FILE_NAME),
+                'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    with FileSystems.open(os.path.join(artifact_location, 
_ATTRIBUTE_FILE_NAME),
+                          'rb') as f:
+      return jsonpickle.decode(f.read())
+
+
+_transform_attribute_manager = _JsonPickleTransformAttributeManager
+
+
+class _MLTransformToPTransformMapper:
+  """
+  This class takes in a list of data processing transforms compatible to be
+  wrapped around MLTransform and returns a list of PTransforms that are used to
+  run the data processing transforms.
+
+  The _MLTransformToPTransformMapper is responsible for loading and saving the
+  PTransforms or attributes of PTransforms to the artifact location to seal
+  the gap between the training and inference pipelines.
+  """
+  def __init__(
+      self,
+      transforms: List[Union[BaseOperation, EmbeddingsManager]],
+      artifact_location: str,
+      artifact_mode: str,
+      pipeline_options: Optional[PipelineOptions] = None,
+  ):
+    self.transforms = transforms
+    self._parent_artifact_location = artifact_location
+    self.artifact_mode = artifact_mode
+    self.pipeline_options = pipeline_options
+
+  def create_and_save_ptransform_list(self):
+    ptransform_list = self.create_ptransform_list()
+    self.save_transforms_in_artifact_location(ptransform_list)
+    return ptransform_list
+
+  def create_ptransform_list(self):
+    previous_ptransform_type = None
+    current_ptransform = None
+    ptransform_list = []
+    for transform in self.transforms:
+      if not isinstance(transform, PTransformProvider):
+        raise RuntimeError(
+            'Transforms must be instances of PTransformProvider and '
+            'implement get_ptransform_for_processing() method.')
+      # for each instance of PTransform, create a new artifact location
+      current_ptransform = transform.get_ptransform_for_processing(
+          artifact_location=os.path.join(
+              self._parent_artifact_location, uuid.uuid4().hex[:6]),
+          artifact_mode=self.artifact_mode)
+      # Determine if a new ptransform should be added to the list
+      is_different_type = (type(current_ptransform) != 
previous_ptransform_type)
+      if is_different_type or not transform.requires_chaining():
+        ptransform_list.append(current_ptransform)
+        previous_ptransform_type = type(current_ptransform)
+
+      if hasattr(ptransform_list[-1], 'append_transform'):
+        ptransform_list[-1].append_transform(transform)

Review Comment:
   Removing the requires_chaining() method. Instead, I am depending on whether
the ptransform has an append_transform method; if it does, I append the
transform to ptransform_list[-1].



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -254,3 +371,243 @@ def _increment_counters():
         pipeline
         | beam.Create([None])
         | beam.Map(lambda _: _increment_counters()))
+
+
+class _TransformAttributeManager:
+  """
+  Base class used for saving and loading the attributes.
+  """
+  @staticmethod
+  def save_attributes(artifact_location):
+    """
+    Save the attributes to json file using stdlib json.
+    """
+    raise NotImplementedError
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    """
+    Load the attributes from json file.
+    """
+    raise NotImplementedError
+
+
+class _JsonPickleTransformAttributeManager(_TransformAttributeManager):
+  """
+  Use Jsonpickle to save and load the attributes. Here the attributes refer
+  to the list of PTransforms that are used to process the data.
+
+  jsonpickle is used to serialize the PTransforms and save it to a json file 
and
+  is compatible across python versions.
+  """
+  @staticmethod
+  def _is_remote_path(path):
+    is_gcs = path.find('gs://') != -1
+    # TODO: Add support for other remote paths.
+    if not is_gcs and path.find('://') != -1:
+      raise RuntimeError(
+          "Artifact locations are currently supported for only available for "
+          "local paths and GCS paths. Got: %s" % path)
+    return is_gcs
+
+  @staticmethod
+  def save_attributes(
+      ptransform_list,
+      artifact_location,
+      **kwargs,
+  ):
+    if _JsonPickleTransformAttributeManager._is_remote_path(artifact_location):
+      try:
+        options = kwargs.get('options')
+      except KeyError:
+        raise RuntimeError(
+            'pipeline options are required to save the attributes.'
+            'in the artifact location %s' % artifact_location)
+
+      temp_dir = tempfile.mkdtemp()
+      temp_json_file = os.path.join(temp_dir, _ATTRIBUTE_FILE_NAME)
+      with open(temp_json_file, 'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+      with open(temp_json_file, 'rb') as f:
+        from apache_beam.runners.dataflow.internal import apiclient
+        _LOGGER.info('Creating artifact location: %s', artifact_location)
+        apiclient.DataflowApplicationClient(options=options).stage_file(
+            gcs_or_local_path=artifact_location,
+            file_name=_ATTRIBUTE_FILE_NAME,
+            stream=f,
+            mime_type='application/json')
+    else:
+      if not FileSystems.exists(artifact_location):
+        FileSystems.mkdirs(artifact_location)
+      # FileSystems.open() fails if the file does not exist.
+      with open(os.path.join(artifact_location, _ATTRIBUTE_FILE_NAME),
+                'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    with FileSystems.open(os.path.join(artifact_location, 
_ATTRIBUTE_FILE_NAME),
+                          'rb') as f:
+      return jsonpickle.decode(f.read())
+
+
+_transform_attribute_manager = _JsonPickleTransformAttributeManager
+
+
+class _MLTransformToPTransformMapper:
+  """
+  This class takes in a list of data processing transforms compatible to be
+  wrapped around MLTransform and returns a list of PTransforms that are used to
+  run the data processing transforms.
+
+  The _MLTransformToPTransformMapper is responsible for loading and saving the
+  PTransforms or attributes of PTransforms to the artifact location to seal
+  the gap between the training and inference pipelines.
+  """
+  def __init__(
+      self,
+      transforms: List[Union[BaseOperation, EmbeddingsManager]],
+      artifact_location: str,
+      artifact_mode: str,
+      pipeline_options: Optional[PipelineOptions] = None,
+  ):
+    self.transforms = transforms
+    self._parent_artifact_location = artifact_location
+    self.artifact_mode = artifact_mode
+    self.pipeline_options = pipeline_options
+
+  def create_and_save_ptransform_list(self):
+    ptransform_list = self.create_ptransform_list()
+    self.save_transforms_in_artifact_location(ptransform_list)
+    return ptransform_list
+
+  def create_ptransform_list(self):
+    previous_ptransform_type = None
+    current_ptransform = None
+    ptransform_list = []
+    for transform in self.transforms:
+      if not isinstance(transform, PTransformProvider):
+        raise RuntimeError(
+            'Transforms must be instances of PTransformProvider and '
+            'implement get_ptransform_for_processing() method.')
+      # for each instance of PTransform, create a new artifact location
+      current_ptransform = transform.get_ptransform_for_processing(
+          artifact_location=os.path.join(
+              self._parent_artifact_location, uuid.uuid4().hex[:6]),
+          artifact_mode=self.artifact_mode)
+      # Determine if a new ptransform should be added to the list
+      is_different_type = (type(current_ptransform) != 
previous_ptransform_type)
+      if is_different_type or not transform.requires_chaining():
+        ptransform_list.append(current_ptransform)
+        previous_ptransform_type = type(current_ptransform)
+
+      if hasattr(ptransform_list[-1], 'append_transform'):
+        ptransform_list[-1].append_transform(transform)

Review Comment:
   requires_chaining is used here to break down PTransforms. If we have two
embedding configs in the `transforms` attr, transforms[0].requires_chaining()
will be False, so we will append RunInference to `ptransform_list`;
transforms[1].requires_chaining() will also be False, so we will again append
RunInference to `ptransform_list`. For both transforms[0] and transforms[1],
the type of PTransform would be the same, but based on requires_chaining, we
would create two RunInference PTransforms.



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -222,14 +334,21 @@ def with_transform(self, transform: BaseOperation):
     Returns:
       A MLTransform instance.
     """
-    self._validate_transform(transform)
-    self._process_handler.append_transform(transform)
+    # self._validate_transform(transform)
+    # avoid circular import
+    # pylint: disable=wrong-import-order, wrong-import-position

Review Comment:
   It was from my old code. Removed it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to