This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/env_vars
in repository https://gitbox.apache.org/repos/asf/beam.git

commit a83f44e93d9474278df66fb7a541c913fd793a3f
Author: Danny McCormick <[email protected]>
AuthorDate: Mon Oct 30 12:22:14 2023 -0400

    Don't assume env vars are set in model handler
---
 sdks/python/apache_beam/ml/inference/base.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 45c5078c13c..fc8ac59a1fb 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -512,7 +512,7 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
             'postprocessing functions defined into a keyed model handler. All '
             'pre/postprocessing functions must be defined on the outer model'
             'handler.')
-      self._env_vars = unkeyed._env_vars
+      self._env_vars = getattr(unkeyed, '_env_vars', {})
       self._unkeyed = unkeyed
       return
 
@@ -553,7 +553,7 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
             'overriding the KeyedModelHandler.batch_elements_kwargs() method.',
             hints,
             batch_kwargs)
-      env_vars = mh._env_vars
+      env_vars = getattr(mh, '_env_vars', {})
       if len(env_vars) > 0:
         logging.warning(
             'mh %s defines the following _env_vars which will be ignored %s. '
@@ -816,7 +816,7 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
           'pre/postprocessing functions must be defined on the outer model'
           'handler.')
     self._unkeyed = unkeyed
-    self._env_vars = unkeyed._env_vars
+    self._env_vars = getattr(unkeyed, '_env_vars', {})
 
   def load_model(self) -> ModelT:
     return self._unkeyed.load_model()
@@ -895,7 +895,7 @@ class _PreProcessingModelHandler(Generic[ExampleT,
       preprocess_fn: the preprocessing function to use.
     """
     self._base = base
-    self._env_vars = base._env_vars
+    self._env_vars = getattr(base, '_env_vars', {})
     self._preprocess_fn = preprocess_fn
 
   def load_model(self) -> ModelT:
@@ -951,7 +951,7 @@ class _PostProcessingModelHandler(Generic[ExampleT,
       postprocess_fn: the preprocessing function to use.
     """
     self._base = base
-    self._env_vars = base._env_vars
+    self._env_vars = getattr(base, '_env_vars', {})
     self._postprocess_fn = postprocess_fn
 
   def load_model(self) -> ModelT:

Reply via email to