This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 722dc6daa32 Allow inference args to be passed in for most cases (#37094)
722dc6daa32 is described below
commit 722dc6daa32958677fd72cfbc97faad3bc29b8b4
Author: Danny McCormick <[email protected]>
AuthorDate: Mon Dec 15 10:27:19 2025 -0500
Allow inference args to be passed in for most cases (#37094)
* Allow inference args to be passed in for most cases
* CHANGES
* tests
* yapf
---
CHANGES.md | 1 +
sdks/python/apache_beam/ml/inference/base.py | 13 +++++--------
sdks/python/apache_beam/ml/inference/base_test.py | 6 ++++++
sdks/python/apache_beam/ml/inference/pytorch_inference.py | 6 ------
sdks/python/apache_beam/ml/inference/sklearn_inference.py | 3 ++-
.../python/apache_beam/ml/inference/tensorflow_inference.py | 6 ------
sdks/python/apache_beam/ml/inference/tensorrt_inference.py | 10 ++++++++++
sdks/python/apache_beam/ml/inference/vertex_ai_inference.py | 3 ---
8 files changed, 24 insertions(+), 24 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 09e24963044..cdb56a28c00 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -73,6 +73,7 @@
## New Features / Improvements
* Support configuring Firestore database on ReadFn transforms (Java) ([#36904](https://github.com/apache/beam/issues/36904)).
+* (Python) Inference args are now allowed in most model handlers, except where they are explicitly/intentionally disallowed ([#37093](https://github.com/apache/beam/issues/37093)).
## Breaking Changes
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 2e1c4963f11..d79565ee24d 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -213,15 +213,12 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
return {}
def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
- """Validates inference_args passed in the inference call.
-
- Because most frameworks do not need extra arguments in their predict() call,
- the default behavior is to error out if inference_args are present.
"""
- if inference_args:
- raise ValueError(
- 'inference_args were provided, but should be None because this '
- 'framework does not expect extra arguments on inferences.')
+ Allows model handlers to provide some validation to make sure passed in
+ inference args are valid. Some ModelHandlers throw here to disallow
+ inference args altogether.
+ """
+ pass
def update_model_path(self, model_path: Optional[str] = None):
"""
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 66e85ce163e..574e71de89c 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -293,6 +293,12 @@ class FakeModelHandlerFailsOnInferenceArgs(FakeModelHandler):
'run_inference should not be called because error should already be '
'thrown from the validate_inference_args check.')
+ def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
+ if inference_args:
+ raise ValueError(
+ 'inference_args were provided, but should be None because this '
+ 'framework does not expect extra arguments on inferences.')
+
class FakeModelHandlerExpectedInferenceArgs(FakeModelHandler):
def run_inference(self, batch, unused_model, inference_args=None):
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index f73eeff808c..affbcd977f5 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -342,9 +342,6 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
"""
return 'BeamML_PyTorch'
- def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
- pass
-
def batch_elements_kwargs(self):
return self._batching_kwargs
@@ -590,9 +587,6 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[dict[str, torch.Tensor],
"""
return 'BeamML_PyTorch'
- def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
- pass
-
def batch_elements_kwargs(self):
return self._batching_kwargs
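As a hedged sketch of the effect for PyTorch users (assuming the default
PyTorch inference function forwards inference_args as keyword arguments to the
model call): the ScaledLinear module, the state_dict path, and the 'scale'
argument below are illustrative only.

    import torch
    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor


    class ScaledLinear(torch.nn.Module):
      """Toy model whose forward() accepts an extra keyword argument."""
      def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

      def forward(self, batch, scale=1.0):
        return self.linear(batch) * scale


    handler = PytorchModelHandlerTensor(
        state_dict_path='gs://my-bucket/scaled_linear.pt',  # placeholder path
        model_class=ScaledLinear,
        model_params={})

    with beam.Pipeline() as pipeline:
      _ = (
          pipeline
          | beam.Create([torch.tensor([1.0]), torch.tensor([2.0])])
          | RunInference(handler, inference_args={'scale': 2.0}))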
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index 1e5962ba64c..84947bec3df 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -73,9 +73,10 @@ def _default_numpy_inference_fn(
model: BaseEstimator,
batch: Sequence[numpy.ndarray],
inference_args: Optional[dict[str, Any]] = None) -> Any:
+ inference_args = {} if not inference_args else inference_args
# vectorize data for better performance
vectorized_batch = numpy.stack(batch, axis=0)
- return model.predict(vectorized_batch)
+ return model.predict(vectorized_batch, **inference_args)
class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
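To illustrate the sklearn change above, a small sketch where keyword arguments
supplied as inference_args are forwarded to model.predict. The
ThresholdEstimator class, the local pickle path, and the 'threshold' argument
are hypothetical; they only stand in for an estimator whose predict() accepts
extra keyword arguments.

    import pickle

    import numpy
    import apache_beam as beam
    from sklearn.base import BaseEstimator
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy


    class ThresholdEstimator(BaseEstimator):
      """Toy estimator whose predict() takes an extra keyword argument."""
      def predict(self, X, threshold=0.5):
        return (X > threshold).astype(int)


    # Persist the toy model so the handler can load it (placeholder path).
    with open('/tmp/threshold_model.pkl', 'wb') as f:
      pickle.dump(ThresholdEstimator(), f)

    handler = SklearnModelHandlerNumpy(model_uri='/tmp/threshold_model.pkl')

    with beam.Pipeline() as pipeline:
      _ = (
          pipeline
          | beam.Create([numpy.array([0.2]), numpy.array([0.9])])
          | RunInference(handler, inference_args={'threshold': 0.8}))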
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
index d13ea53cf1b..5ce293a06ac 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
@@ -219,9 +219,6 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
"""
return 'BeamML_TF_Numpy'
- def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
- pass
-
def batch_elements_kwargs(self):
return self._batching_kwargs
@@ -360,9 +357,6 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
"""
return 'BeamML_TF_Tensor'
- def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
- pass
-
def batch_elements_kwargs(self):
return self._batching_kwargs
diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
index 1b11bd9f39e..b575dfa849d 100644
--- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
@@ -341,3 +341,13 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
def model_copies(self) -> int:
return self._model_copies
+
+ def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
+ """
+ Currently, this model handler does not support inference args. Given that,
+ we will throw if any are passed in.
+ """
+ if inference_args:
+ raise ValueError(
+ 'inference_args were provided, but should be None because this '
+ 'framework does not expect extra arguments on inferences.')
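A short sketch of the TensorRT behavior after this change (assumes the
TensorRT extras are installed; the engine path and batch sizes are
placeholders): passing any inference_args now fails validation before any
inference runs.

    from apache_beam.ml.inference.tensorrt_inference import TensorRTEngineHandlerNumPy

    handler = TensorRTEngineHandlerNumPy(
        min_batch_size=1,
        max_batch_size=4,
        engine_path='gs://my-bucket/model.trt')  # placeholder path

    try:
      handler.validate_inference_args({'scale': 2.0})
    except ValueError as e:
      # Raised because this handler explicitly disallows inference_args.
      print(e)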
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
index 471f2379cfb..9858b59039c 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
@@ -207,8 +207,5 @@ class VertexAIModelHandlerJSON(RemoteModelHandler[Any,
return utils._convert_to_result(
batch, prediction.predictions, prediction.deployed_model_id)
- def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
- pass
-
def batch_elements_kwargs(self) -> Mapping[str, Any]:
return self._batching_kwargs
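Finally, a hedged sketch of the Vertex AI handler, which keeps accepting
inference_args (assumed here to be forwarded with the prediction request); the
endpoint, project, and 'temperature' parameter are placeholders, and running
this would require a real Vertex AI endpoint plus credentials.

    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON

    handler = VertexAIModelHandlerJSON(
        endpoint_id='1234567890',  # placeholder endpoint
        project='my-project',      # placeholder project
        location='us-central1')

    with beam.Pipeline() as pipeline:
      _ = (
          pipeline
          | beam.Create([{'prompt': 'hello'}])
          | RunInference(handler, inference_args={'temperature': 0.2}))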