This is an automated email from the ASF dual-hosted git repository. damccorm pushed a commit to branch users/damccorm/env_vars in repository https://gitbox.apache.org/repos/asf/beam.git
commit a83f44e93d9474278df66fb7a541c913fd793a3f Author: Danny McCormick <[email protected]> AuthorDate: Mon Oct 30 12:22:14 2023 -0400 Don't assume env vars are set in model handler --- sdks/python/apache_beam/ml/inference/base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index 45c5078c13c..fc8ac59a1fb 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -512,7 +512,7 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT], 'postprocessing functions defined into a keyed model handler. All ' 'pre/postprocessing functions must be defined on the outer model' 'handler.') - self._env_vars = unkeyed._env_vars + self._env_vars = getattr(unkeyed, '_env_vars', {}) self._unkeyed = unkeyed return @@ -553,7 +553,7 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT], 'overriding the KeyedModelHandler.batch_elements_kwargs() method.', hints, batch_kwargs) - env_vars = mh._env_vars + env_vars = getattr(mh, '_env_vars', {}) if len(env_vars) > 0: logging.warning( 'mh %s defines the following _env_vars which will be ignored %s. ' @@ -816,7 +816,7 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT], 'pre/postprocessing functions must be defined on the outer model' 'handler.') self._unkeyed = unkeyed - self._env_vars = unkeyed._env_vars + self._env_vars = getattr(unkeyed, '_env_vars', {}) def load_model(self) -> ModelT: return self._unkeyed.load_model() @@ -895,7 +895,7 @@ class _PreProcessingModelHandler(Generic[ExampleT, preprocess_fn: the preprocessing function to use. """ self._base = base - self._env_vars = base._env_vars + self._env_vars = getattr(base, '_env_vars', {}) self._preprocess_fn = preprocess_fn def load_model(self) -> ModelT: @@ -951,7 +951,7 @@ class _PostProcessingModelHandler(Generic[ExampleT, postprocess_fn: the preprocessing function to use. """ self._base = base - self._env_vars = base._env_vars + self._env_vars = getattr(base, '_env_vars', {}) self._postprocess_fn = postprocess_fn def load_model(self) -> ModelT:
