damccorm commented on code in PR #29564:
URL: https://github.com/apache/beam/pull/29564#discussion_r1410898194


##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -42,12 +64,68 @@
 OperationOutputT = TypeVar('OperationOutputT')
 
 
+def _convert_list_of_dicts_to_dict_of_lists(
+    list_of_dicts: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]:
+  keys_to_element_list = collections.defaultdict(list)
+  for d in list_of_dicts:
+    for key, value in d.items():
+      keys_to_element_list[key].append(value)
+  return keys_to_element_list
+
+
+def _convert_dict_of_lists_to_lists_of_dict(
+    dict_of_lists: Dict[str, List[Any]],
+    batch_length: int) -> List[Dict[str, Any]]:
+  result: List[Dict[str, Any]] = [{} for _ in range(batch_length)]
+  for key, values in dict_of_lists.items():
+    for i in range(len(values)):
+      result[i][key] = values[i]

Review Comment:
   Is it possible for `i` to ever be greater than or equal to `batch_length` 
(which would raise an `IndexError`)? A more robust way to do this might be to 
avoid `batch_length` altogether; you could just check whether `result[i]` 
exists and append a `{}` if not.



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -222,14 +334,21 @@ def with_transform(self, transform: BaseOperation):
     Returns:
       A MLTransform instance.
     """
-    self._validate_transform(transform)
-    self._process_handler.append_transform(transform)
+    # self._validate_transform(transform)

Review Comment:
   Any reason to not have this anymore?



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -254,3 +371,243 @@ def _increment_counters():
         pipeline
         | beam.Create([None])
         | beam.Map(lambda _: _increment_counters()))
+
+
+class _TransformAttributeManager:
+  """
+  Base class used for saving and loading the attributes.
+  """
+  @staticmethod
+  def save_attributes(artifact_location):
+    """
+    Save the attributes to json file using stdlib json.
+    """
+    raise NotImplementedError
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    """
+    Load the attributes from json file.
+    """
+    raise NotImplementedError
+
+
+class _JsonPickleTransformAttributeManager(_TransformAttributeManager):
+  """
+  Use Jsonpickle to save and load the attributes. Here the attributes refer
+  to the list of PTransforms that are used to process the data.
+
+  jsonpickle is used to serialize the PTransforms and save it to a json file 
and
+  is compatible across python versions.
+  """
+  @staticmethod
+  def _is_remote_path(path):
+    is_gcs = path.find('gs://') != -1
+    # TODO: Add support for other remote paths.
+    if not is_gcs and path.find('://') != -1:
+      raise RuntimeError(
+          "Artifact locations are currently supported for only available for "
+          "local paths and GCS paths. Got: %s" % path)
+    return is_gcs
+
+  @staticmethod
+  def save_attributes(
+      ptransform_list,
+      artifact_location,
+      **kwargs,
+  ):
+    if _JsonPickleTransformAttributeManager._is_remote_path(artifact_location):
+      try:
+        options = kwargs.get('options')
+      except KeyError:
+        raise RuntimeError(
+            'pipeline options are required to save the attributes.'
+            'in the artifact location %s' % artifact_location)
+
+      temp_dir = tempfile.mkdtemp()
+      temp_json_file = os.path.join(temp_dir, _ATTRIBUTE_FILE_NAME)
+      with open(temp_json_file, 'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+      with open(temp_json_file, 'rb') as f:
+        from apache_beam.runners.dataflow.internal import apiclient
+        _LOGGER.info('Creating artifact location: %s', artifact_location)
+        apiclient.DataflowApplicationClient(options=options).stage_file(
+            gcs_or_local_path=artifact_location,
+            file_name=_ATTRIBUTE_FILE_NAME,
+            stream=f,
+            mime_type='application/json')
+    else:
+      if not FileSystems.exists(artifact_location):
+        FileSystems.mkdirs(artifact_location)
+      # FileSystems.open() fails if the file does not exist.
+      with open(os.path.join(artifact_location, _ATTRIBUTE_FILE_NAME),
+                'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    with FileSystems.open(os.path.join(artifact_location, 
_ATTRIBUTE_FILE_NAME),
+                          'rb') as f:
+      return jsonpickle.decode(f.read())
+
+
+_transform_attribute_manager = _JsonPickleTransformAttributeManager
+
+
+class _MLTransformToPTransformMapper:
+  """
+  This class takes in a list of data processing transforms compatible to be
+  wrapped around MLTransform and returns a list of PTransforms that are used to
+  run the data processing transforms.
+
+  The _MLTransformToPTransformMapper is responsible for loading and saving the
+  PTransforms or attributes of PTransforms to the artifact location to seal
+  the gap between the training and inference pipelines.
+  """
+  def __init__(
+      self,
+      transforms: List[Union[BaseOperation, EmbeddingsManager]],
+      artifact_location: str,
+      artifact_mode: str,
+      pipeline_options: Optional[PipelineOptions] = None,
+  ):
+    self.transforms = transforms
+    self._parent_artifact_location = artifact_location
+    self.artifact_mode = artifact_mode
+    self.pipeline_options = pipeline_options
+
+  def create_and_save_ptransform_list(self):
+    ptransform_list = self.create_ptransform_list()
+    self.save_transforms_in_artifact_location(ptransform_list)
+    return ptransform_list
+
+  def create_ptransform_list(self):
+    previous_ptransform_type = None
+    current_ptransform = None
+    ptransform_list = []
+    for transform in self.transforms:
+      if not isinstance(transform, PTransformProvider):
+        raise RuntimeError(
+            'Transforms must be instances of PTransformProvider and '
+            'implement get_ptransform_for_processing() method.')
+      # for each instance of PTransform, create a new artifact location
+      current_ptransform = transform.get_ptransform_for_processing(
+          artifact_location=os.path.join(
+              self._parent_artifact_location, uuid.uuid4().hex[:6]),
+          artifact_mode=self.artifact_mode)
+      # Determine if a new ptransform should be added to the list
+      is_different_type = (type(current_ptransform) != 
previous_ptransform_type)
+      if is_different_type or not transform.requires_chaining():
+        ptransform_list.append(current_ptransform)
+        previous_ptransform_type = type(current_ptransform)
+
+      if hasattr(ptransform_list[-1], 'append_transform'):
+        ptransform_list[-1].append_transform(transform)

Review Comment:
   One possible idea is to mandate that all `PTransformProvider` objects have 
an `append_transform` function that returns `True/False` depending on whether 
the transform could be appended and just return `False` if it's a no-op (which 
could be done in the base class). Then we could (a) check if it's the same type, 
(b) try appending, and (c) add it to the list if appending failed.



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -42,12 +64,68 @@
 OperationOutputT = TypeVar('OperationOutputT')
 
 
+def _convert_list_of_dicts_to_dict_of_lists(
+    list_of_dicts: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]:
+  keys_to_element_list = collections.defaultdict(list)
+  for d in list_of_dicts:
+    for key, value in d.items():
+      keys_to_element_list[key].append(value)
+  return keys_to_element_list
+
+
+def _convert_dict_of_lists_to_lists_of_dict(
+    dict_of_lists: Dict[str, List[Any]],
+    batch_length: int) -> List[Dict[str, Any]]:
+  result: List[Dict[str, Any]] = [{} for _ in range(batch_length)]
+  for key, values in dict_of_lists.items():
+    for i in range(len(values)):
+      result[i][key] = values[i]
+  return result
+
+
 class ArtifactMode(object):
   PRODUCE = 'produce'
   CONSUME = 'consume'
 
 
-class BaseOperation(Generic[OperationInputT, OperationOutputT], abc.ABC):
+class PTransformProvider:
+  """
+  Data processing transforms that are intended to be used with MLTransform
+  should subclass PTransformProvider and implement the following methods:
+  1. get_ptransform_for_processing()
+  2. requires_chaining()
+
+  get_ptransform_for_processing() method should return a PTransform that can be
+  used to process the data.
+
+  requires_chaining() method should return True if the data processing
+  transforms needs to be chained sequentially with compatible data processing
+  transforms.
+  """
+  @abc.abstractmethod
+  def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform:
+    """
+    Returns a PTransform that can be used to process the data.
+    """
+
+  @abc.abstractmethod
+  def requires_chaining(self):
+    """
+    Returns True if the data processing transforms needs to be chained
+    sequentially with compatible data processing transforms.

Review Comment:
   We should add some detail on what this means. It is unclear to me on my 
initial pass what this does without understanding what other transforms do.



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -42,12 +64,68 @@
 OperationOutputT = TypeVar('OperationOutputT')
 
 
+def _convert_list_of_dicts_to_dict_of_lists(
+    list_of_dicts: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]:
+  keys_to_element_list = collections.defaultdict(list)
+  for d in list_of_dicts:
+    for key, value in d.items():
+      keys_to_element_list[key].append(value)
+  return keys_to_element_list
+
+
+def _convert_dict_of_lists_to_lists_of_dict(
+    dict_of_lists: Dict[str, List[Any]],
+    batch_length: int) -> List[Dict[str, Any]]:
+  result: List[Dict[str, Any]] = [{} for _ in range(batch_length)]
+  for key, values in dict_of_lists.items():
+    for i in range(len(values)):
+      result[i][key] = values[i]
+  return result
+
+
 class ArtifactMode(object):
   PRODUCE = 'produce'
   CONSUME = 'consume'
 
 
-class BaseOperation(Generic[OperationInputT, OperationOutputT], abc.ABC):
+class PTransformProvider:

Review Comment:
   Maybe MLTransformProvider?



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -209,10 +298,33 @@ def expand(
     Returns:
       A PCollection of MLTransformOutputT type
     """
+    _ = [self._validate_transform(transform) for transform in self.transforms]
+    if self._artifact_mode == ArtifactMode.PRODUCE:
+      ptransform_partitioner = _MLTransformToPTransformMapper(
+          transforms=self.transforms,
+          artifact_location=self._parent_artifact_location,
+          artifact_mode=self._artifact_mode,
+          pipeline_options=pcoll.pipeline.options)
+      ptransform_list = 
ptransform_partitioner.create_and_save_ptransform_list()
+    else:
+      ptransform_list = (
+          
_MLTransformToPTransformMapper.load_transforms_from_artifact_location(
+              self._parent_artifact_location))
+
+    # the saved transforms has artifact mode set to PRODUCE.
+    # set the artifact mode to CONSUME.
+    if self._artifact_mode == ArtifactMode.CONSUME:
+      for i in range(len(ptransform_list)):
+        if hasattr(ptransform_list[i], 'artifact_mode'):
+          ptransform_list[i].artifact_mode = self._artifact_mode

Review Comment:
   Nit: can we stick this block in the else above?



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -209,10 +298,33 @@ def expand(
     Returns:
       A PCollection of MLTransformOutputT type
     """
+    _ = [self._validate_transform(transform) for transform in self.transforms]
+    if self._artifact_mode == ArtifactMode.PRODUCE:
+      ptransform_partitioner = _MLTransformToPTransformMapper(
+          transforms=self.transforms,
+          artifact_location=self._parent_artifact_location,
+          artifact_mode=self._artifact_mode,
+          pipeline_options=pcoll.pipeline.options)
+      ptransform_list = 
ptransform_partitioner.create_and_save_ptransform_list()
+    else:
+      ptransform_list = (
+          
_MLTransformToPTransformMapper.load_transforms_from_artifact_location(
+              self._parent_artifact_location))
+
+    # the saved transforms has artifact mode set to PRODUCE.
+    # set the artifact mode to CONSUME.
+    if self._artifact_mode == ArtifactMode.CONSUME:
+      for i in range(len(ptransform_list)):
+        if hasattr(ptransform_list[i], 'artifact_mode'):
+          ptransform_list[i].artifact_mode = self._artifact_mode
+
+    for ptransform in ptransform_list:
+      pcoll = pcoll | ptransform
+
     _ = (
         pcoll.pipeline
         | "MLTransformMetricsUsage" >> MLTransformMetricsUsage(self))
-    return self._process_handler.process_data(pcoll)
+    return pcoll  # type: ignore[return-value]
 
   def with_transform(self, transform: BaseOperation):

Review Comment:
   Nit: Type needs to be updated to `PTransformProvider`



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -254,3 +371,243 @@ def _increment_counters():
         pipeline
         | beam.Create([None])
         | beam.Map(lambda _: _increment_counters()))
+
+
+class _TransformAttributeManager:
+  """
+  Base class used for saving and loading the attributes.
+  """
+  @staticmethod
+  def save_attributes(artifact_location):
+    """
+    Save the attributes to json file using stdlib json.
+    """
+    raise NotImplementedError
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    """
+    Load the attributes from json file.
+    """
+    raise NotImplementedError
+
+
+class _JsonPickleTransformAttributeManager(_TransformAttributeManager):
+  """
+  Use Jsonpickle to save and load the attributes. Here the attributes refer
+  to the list of PTransforms that are used to process the data.
+
+  jsonpickle is used to serialize the PTransforms and save it to a json file 
and
+  is compatible across python versions.
+  """
+  @staticmethod
+  def _is_remote_path(path):
+    is_gcs = path.find('gs://') != -1
+    # TODO: Add support for other remote paths.
+    if not is_gcs and path.find('://') != -1:
+      raise RuntimeError(
+          "Artifact locations are currently supported for only available for "
+          "local paths and GCS paths. Got: %s" % path)
+    return is_gcs
+
+  @staticmethod
+  def save_attributes(
+      ptransform_list,
+      artifact_location,
+      **kwargs,
+  ):
+    if _JsonPickleTransformAttributeManager._is_remote_path(artifact_location):
+      try:
+        options = kwargs.get('options')
+      except KeyError:
+        raise RuntimeError(
+            'pipeline options are required to save the attributes.'
+            'in the artifact location %s' % artifact_location)
+
+      temp_dir = tempfile.mkdtemp()
+      temp_json_file = os.path.join(temp_dir, _ATTRIBUTE_FILE_NAME)
+      with open(temp_json_file, 'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+      with open(temp_json_file, 'rb') as f:
+        from apache_beam.runners.dataflow.internal import apiclient
+        _LOGGER.info('Creating artifact location: %s', artifact_location)
+        apiclient.DataflowApplicationClient(options=options).stage_file(
+            gcs_or_local_path=artifact_location,
+            file_name=_ATTRIBUTE_FILE_NAME,
+            stream=f,
+            mime_type='application/json')
+    else:
+      if not FileSystems.exists(artifact_location):
+        FileSystems.mkdirs(artifact_location)
+      # FileSystems.open() fails if the file does not exist.
+      with open(os.path.join(artifact_location, _ATTRIBUTE_FILE_NAME),
+                'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    with FileSystems.open(os.path.join(artifact_location, 
_ATTRIBUTE_FILE_NAME),
+                          'rb') as f:
+      return jsonpickle.decode(f.read())
+
+
+_transform_attribute_manager = _JsonPickleTransformAttributeManager
+
+
+class _MLTransformToPTransformMapper:
+  """
+  This class takes in a list of data processing transforms compatible to be
+  wrapped around MLTransform and returns a list of PTransforms that are used to
+  run the data processing transforms.
+
+  The _MLTransformToPTransformMapper is responsible for loading and saving the
+  PTransforms or attributes of PTransforms to the artifact location to seal
+  the gap between the training and inference pipelines.
+  """
+  def __init__(
+      self,
+      transforms: List[Union[BaseOperation, EmbeddingsManager]],
+      artifact_location: str,
+      artifact_mode: str,
+      pipeline_options: Optional[PipelineOptions] = None,
+  ):
+    self.transforms = transforms
+    self._parent_artifact_location = artifact_location
+    self.artifact_mode = artifact_mode
+    self.pipeline_options = pipeline_options
+
+  def create_and_save_ptransform_list(self):
+    ptransform_list = self.create_ptransform_list()
+    self.save_transforms_in_artifact_location(ptransform_list)
+    return ptransform_list
+
+  def create_ptransform_list(self):
+    previous_ptransform_type = None
+    current_ptransform = None
+    ptransform_list = []
+    for transform in self.transforms:
+      if not isinstance(transform, PTransformProvider):
+        raise RuntimeError(
+            'Transforms must be instances of PTransformProvider and '
+            'implement get_ptransform_for_processing() method.')
+      # for each instance of PTransform, create a new artifact location
+      current_ptransform = transform.get_ptransform_for_processing(
+          artifact_location=os.path.join(
+              self._parent_artifact_location, uuid.uuid4().hex[:6]),
+          artifact_mode=self.artifact_mode)
+      # Determine if a new ptransform should be added to the list
+      is_different_type = (type(current_ptransform) != 
previous_ptransform_type)
+      if is_different_type or not transform.requires_chaining():
+        ptransform_list.append(current_ptransform)
+        previous_ptransform_type = type(current_ptransform)
+
+      if hasattr(ptransform_list[-1], 'append_transform'):
+        ptransform_list[-1].append_transform(transform)

Review Comment:
   Should this be an else if for the previous condition? E.g. if 
`is_different_type` is True, we'll append `current_ptransform`, but if that has 
an `append_transform` function we'll end up appending it to itself.



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -112,7 +212,8 @@ def __init__(
       *,
       write_artifact_location: Optional[str] = None,
       read_artifact_location: Optional[str] = None,
-      transforms: Optional[Sequence[BaseOperation]] = None):
+      transforms: Optional[List[Union[BaseOperation,

Review Comment:
   Nit: type here could be `List[PTransformProvider]`, right?



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -254,3 +371,243 @@ def _increment_counters():
         pipeline
         | beam.Create([None])
         | beam.Map(lambda _: _increment_counters()))
+
+
+class _TransformAttributeManager:
+  """
+  Base class used for saving and loading the attributes.
+  """
+  @staticmethod
+  def save_attributes(artifact_location):
+    """
+    Save the attributes to json file using stdlib json.
+    """
+    raise NotImplementedError
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    """
+    Load the attributes from json file.
+    """
+    raise NotImplementedError
+
+
+class _JsonPickleTransformAttributeManager(_TransformAttributeManager):
+  """
+  Use Jsonpickle to save and load the attributes. Here the attributes refer
+  to the list of PTransforms that are used to process the data.
+
+  jsonpickle is used to serialize the PTransforms and save it to a json file 
and
+  is compatible across python versions.
+  """
+  @staticmethod
+  def _is_remote_path(path):
+    is_gcs = path.find('gs://') != -1
+    # TODO: Add support for other remote paths.
+    if not is_gcs and path.find('://') != -1:
+      raise RuntimeError(
+          "Artifact locations are currently supported for only available for "
+          "local paths and GCS paths. Got: %s" % path)
+    return is_gcs
+
+  @staticmethod
+  def save_attributes(
+      ptransform_list,
+      artifact_location,
+      **kwargs,
+  ):
+    if _JsonPickleTransformAttributeManager._is_remote_path(artifact_location):
+      try:
+        options = kwargs.get('options')
+      except KeyError:
+        raise RuntimeError(
+            'pipeline options are required to save the attributes.'
+            'in the artifact location %s' % artifact_location)
+
+      temp_dir = tempfile.mkdtemp()
+      temp_json_file = os.path.join(temp_dir, _ATTRIBUTE_FILE_NAME)
+      with open(temp_json_file, 'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+      with open(temp_json_file, 'rb') as f:
+        from apache_beam.runners.dataflow.internal import apiclient
+        _LOGGER.info('Creating artifact location: %s', artifact_location)
+        apiclient.DataflowApplicationClient(options=options).stage_file(
+            gcs_or_local_path=artifact_location,
+            file_name=_ATTRIBUTE_FILE_NAME,
+            stream=f,
+            mime_type='application/json')
+    else:
+      if not FileSystems.exists(artifact_location):
+        FileSystems.mkdirs(artifact_location)
+      # FileSystems.open() fails if the file does not exist.
+      with open(os.path.join(artifact_location, _ATTRIBUTE_FILE_NAME),
+                'w+') as f:
+        f.write(jsonpickle.encode(ptransform_list))
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    with FileSystems.open(os.path.join(artifact_location, 
_ATTRIBUTE_FILE_NAME),
+                          'rb') as f:
+      return jsonpickle.decode(f.read())
+
+
+_transform_attribute_manager = _JsonPickleTransformAttributeManager
+
+
+class _MLTransformToPTransformMapper:
+  """
+  This class takes in a list of data processing transforms compatible to be
+  wrapped around MLTransform and returns a list of PTransforms that are used to
+  run the data processing transforms.
+
+  The _MLTransformToPTransformMapper is responsible for loading and saving the
+  PTransforms or attributes of PTransforms to the artifact location to seal
+  the gap between the training and inference pipelines.
+  """
+  def __init__(
+      self,
+      transforms: List[Union[BaseOperation, EmbeddingsManager]],
+      artifact_location: str,
+      artifact_mode: str,
+      pipeline_options: Optional[PipelineOptions] = None,
+  ):
+    self.transforms = transforms
+    self._parent_artifact_location = artifact_location
+    self.artifact_mode = artifact_mode
+    self.pipeline_options = pipeline_options
+
+  def create_and_save_ptransform_list(self):
+    ptransform_list = self.create_ptransform_list()
+    self.save_transforms_in_artifact_location(ptransform_list)
+    return ptransform_list
+
+  def create_ptransform_list(self):
+    previous_ptransform_type = None
+    current_ptransform = None
+    ptransform_list = []
+    for transform in self.transforms:
+      if not isinstance(transform, PTransformProvider):
+        raise RuntimeError(
+            'Transforms must be instances of PTransformProvider and '
+            'implement get_ptransform_for_processing() method.')
+      # for each instance of PTransform, create a new artifact location
+      current_ptransform = transform.get_ptransform_for_processing(
+          artifact_location=os.path.join(
+              self._parent_artifact_location, uuid.uuid4().hex[:6]),
+          artifact_mode=self.artifact_mode)
+      # Determine if a new ptransform should be added to the list
+      is_different_type = (type(current_ptransform) != 
previous_ptransform_type)
+      if is_different_type or not transform.requires_chaining():
+        ptransform_list.append(current_ptransform)
+        previous_ptransform_type = type(current_ptransform)
+
+      if hasattr(ptransform_list[-1], 'append_transform'):
+        ptransform_list[-1].append_transform(transform)

Review Comment:
   Also, do we actually need the `requires_chaining` function? Isn't it enough 
to just know that the transform itself has an `append_transform` function?



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -254,3 +371,243 @@ def _increment_counters():
         pipeline
         | beam.Create([None])
         | beam.Map(lambda _: _increment_counters()))
+
+
+class _TransformAttributeManager:
+  """
+  Base class used for saving and loading the attributes.
+  """
+  @staticmethod
+  def save_attributes(artifact_location):
+    """
+    Save the attributes to json file using stdlib json.
+    """
+    raise NotImplementedError
+
+  @staticmethod
+  def load_attributes(artifact_location):
+    """
+    Load the attributes from json file.
+    """
+    raise NotImplementedError
+
+
+class _JsonPickleTransformAttributeManager(_TransformAttributeManager):
+  """
+  Use Jsonpickle to save and load the attributes. Here the attributes refer
+  to the list of PTransforms that are used to process the data.
+
+  jsonpickle is used to serialize the PTransforms and save it to a json file 
and
+  is compatible across python versions.
+  """
+  @staticmethod
+  def _is_remote_path(path):
+    is_gcs = path.find('gs://') != -1
+    # TODO: Add support for other remote paths.
+    if not is_gcs and path.find('://') != -1:
+      raise RuntimeError(
+          "Artifact locations are currently supported for only available for "
+          "local paths and GCS paths. Got: %s" % path)
+    return is_gcs
+
+  @staticmethod
+  def save_attributes(
+      ptransform_list,
+      artifact_location,
+      **kwargs,
+  ):
+    if _JsonPickleTransformAttributeManager._is_remote_path(artifact_location):
+      try:
+        options = kwargs.get('options')
+      except KeyError:
+        raise RuntimeError(
+            'pipeline options are required to save the attributes.'
+            'in the artifact location %s' % artifact_location)

Review Comment:
   Rather than throwing here, could we try creating the client and, if that 
fails, re-raise its exception wrapped with this info?
   
   That will potentially provide a richer exception with more info about what 
options may be missing



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -222,14 +334,21 @@ def with_transform(self, transform: BaseOperation):
     Returns:
       A MLTransform instance.
     """
-    self._validate_transform(transform)
-    self._process_handler.append_transform(transform)
+    # self._validate_transform(transform)
+    # avoid circular import
+    # pylint: disable=wrong-import-order, wrong-import-position

Review Comment:
   Why do we need this here?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to