Polber commented on code in PR #33406:
URL: https://github.com/apache/beam/pull/33406#discussion_r1890761642
##########
sdks/python/apache_beam/yaml/yaml_ml.py:
##########
@@ -33,11 +40,419 @@
tft = None # type: ignore
+def normalize_ml(spec):
+  """Validates a 'RunInference' transform spec from a YAML pipeline.
+
+  Checks that the spec's config contains a 'model_handler' dict holding
+  exactly the keys 'type' and 'config', and that the handler type is
+  registered in ModelHandlerProvider.handler_types. Returns the spec
+  unchanged; specs of any other type pass through untouched.
+  """
+  if spec['type'] == 'RunInference':
+    config = spec.get('config')
+    # NOTE(review): if 'config' is absent, config is None and the 'in'
+    # test below raises TypeError instead of the intended ValueError —
+    # confirm callers always supply a config mapping.
+    for required in ('model_handler', ):
+      if required not in config:
+        raise ValueError(
+            f'Missing {required} parameter in RunInference config '
+            f'at line {SafeLineLoader.get_line(spec)}')
+    model_handler = config.get('model_handler')
+    # model_handler must be a mapping of exactly {'type', 'config'}.
+    if not isinstance(model_handler, dict):
+      raise ValueError(
+          'Invalid model_handler specification at line '
+          f'{SafeLineLoader.get_line(spec)}. Expected '
+          f'dict but was {type(model_handler)}.')
+    for required in ('type', 'config'):
+      if required not in model_handler:
+        raise ValueError(
+            f'Missing {required} in model handler '
+            f'at line {SafeLineLoader.get_line(model_handler)}')
+    typ = model_handler['type']
+    # Reject any extra keys; loader metadata is stripped before comparing.
+    extra_params = set(SafeLineLoader.strip_metadata(model_handler).keys()) - {
+        'type', 'config'
+    }
+    if extra_params:
+      raise ValueError(
+          f'Unexpected parameters in model handler of type {typ} '
+          f'at line {SafeLineLoader.get_line(spec)}: {extra_params}')
+    # Delegate per-type validation of the inner config to the registered
+    # provider class for this handler type.
+    model_handler_provider = ModelHandlerProvider.handler_types.get(typ, None)
+    if model_handler_provider:
+      model_handler_provider.validate(model_handler['config'])
+    else:
+      raise NotImplementedError(
+          f'Unknown model handler type: {typ} '
+          f'at line {SafeLineLoader.get_line(spec)}.')
+
+  return spec
+
+
+class ModelHandlerProvider:
+  """Base class wrapping a Beam ModelHandler for YAML RunInference.
+
+  Concrete providers register themselves in ``handler_types`` via
+  ``register_handler_type`` and are instantiated by ``create_handler``.
+  """
+  # Registry mapping YAML handler-type names to provider constructors.
+  handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {}
+
+  def __init__(
+      self, handler, preprocess: Callable = None, postprocess: Callable =
None):
+    # NOTE(review): 'Callable = None' defaults would be better typed as
+    # Optional[Callable]. (The stray unprefixed 'None):' line above is an
+    # email line-wrap artifact of the quoted diff.)
+    self._handler = handler
+    # Fall back to the class-level default pre/post-processing functions
+    # when no transform dict was supplied in the YAML config.
+    self._preprocess = self.parse_processing_transform(
+        preprocess, 'preprocess') or self.preprocess_fn
+    self._postprocess = self.parse_processing_transform(
+        postprocess, 'postprocess') or self.postprocess_fn
+
+  def get_output_schema(self):
+    # Default output schema; returns typing.Any. Subclasses may narrow this.
+    return Any
+
+  @staticmethod
+  def parse_processing_transform(processing_transform, typ):
+    """Parses a pre/postprocess spec dict into a Python callable.
+
+    Accepts either {'callable': <source>} or {'path': ..., 'name': ...};
+    returns None when no transform is given. 'typ' names the function
+    ('preprocess' or 'postprocess') for error messages only.
+    """
+    def _parse_config(callable=None, path=None, name=None):
+      # 'callable' is mutually exclusive with 'path'/'name'.
+      if callable and (path or name):
+        raise ValueError(
+            f"Cannot specify 'callable' with 'path' and 'name' for {typ} "
+            f"function.")
+      if path and name:
+        # Load the named callable from a script read via FileSystems.
+        return python_callable.PythonCallableWithSource.load_from_script(
+            FileSystems.open(path).read().decode(), name)
+      elif callable:
+        return python_callable.PythonCallableWithSource(callable)
+      else:
+        raise ValueError(
+            f"Must specify one of 'callable' or 'path' and 'name' for {typ} "
+            f"function.")
+
+    if processing_transform:
+      if isinstance(processing_transform, dict):
+        return _parse_config(**processing_transform)
+      else:
+        raise ValueError("Invalid model_handler specification.")
+
+  def underlying_handler(self):
+    # Returns the wrapped Beam ModelHandler instance.
+    return self._handler
+
+  def preprocess_fn(self, row):
+    # Default preprocess: subclasses without a natural default require the
+    # user to supply one via the 'preprocess' tag.
+    raise ValueError(
+        'Handler does not implement a default preprocess '
+        'method. Please define a preprocessing method using the '
+        '\'preprocess\' tag.')
+
+  def create_preprocess_fn(self):
+    # Keeps the original row alongside the preprocessed value so it can be
+    # re-joined with the inference result in create_postprocess_fn.
+    return lambda row: (row, self._preprocess(row))
+
+  @staticmethod
+  def postprocess_fn(x):
+    # Default postprocess: identity.
+    return x
+
+  def create_postprocess_fn(self):
+    # result is a (original_row, inference_output) pair produced upstream.
+    return lambda result: (result[0], self._postprocess(result[1]))
+
+  @staticmethod
+  def validate(model_handler_spec):
+    # Subclasses must override to validate their YAML config.
+    # NOTE(review): NotImplementedError(type(ModelHandlerProvider)) always
+    # reports the base class — likely intended to name the subclass.
+    raise NotImplementedError(type(ModelHandlerProvider))
+
+  @classmethod
+  def register_handler_type(cls, type_name):
+    """Class decorator registering a provider under a YAML type name."""
+    def apply(constructor):
+      cls.handler_types[type_name] = constructor
+      return constructor
+
+    return apply
+
+  @classmethod
+  def create_handler(cls, model_handler_spec) -> "ModelHandlerProvider":
+    """Instantiates the registered provider for a validated handler spec."""
+    typ = model_handler_spec['type']
+    config = model_handler_spec['config']
+    try:
+      result = cls.handler_types[typ](**config)
+      # Attach a to_json so the handler can be round-tripped back to YAML.
+      if not hasattr(result, 'to_json'):
+        result.to_json = lambda: model_handler_spec
+      return result
+    except Exception as exn:
+      # NOTE(review): consider 'raise ... from exn' to keep the traceback.
+      raise ValueError(
+          f'Unable to instantiate model handler of type {typ}. {exn}')
+
+
[email protected]_handler_type('VertexAIModelHandlerJSON')
+class VertexAIModelHandlerJSONProvider(ModelHandlerProvider):
+ def __init__(
+ self,
+ endpoint_id: str,
+ endpoint_project: str,
+ endpoint_region: str,
+ experiment: Optional[str] = None,
+ network: Optional[str] = None,
+ private: bool = False,
+ min_batch_size: Optional[int] = None,
+ max_batch_size: Optional[int] = None,
+ max_batch_duration_secs: Optional[int] = None,
+ env_vars=None,
+ preprocess: Callable = None,
+ postprocess: Callable = None):
+ """ModelHandler for Vertex AI.
+
+ For example: ::
+
+ - type: RunInference
+ config:
+ inference_tag: 'my_inference'
+ model_handler:
+ type: VertexAIModelHandlerJSON
+ config:
+ endpoint_id: 9876543210
Review Comment:
Not that I am aware of
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]