AnandInguva commented on code in PR #24911:
URL: https://github.com/apache/beam/pull/24911#discussion_r1072234910
##########
sdks/python/apache_beam/ml/inference/onnx_inference.py:
##########
@@ -0,0 +1,145 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pickle
+import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import numpy
+import pandas
+import onnx
+import onnxruntime as ort
+
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+
+try:
+  import joblib
+except ImportError:
+  # joblib is an optional dependency.
+  pass
+
+__all__ = [
+    'OnnxModelHandlerNumpy'
+]
+
+NumpyInferenceFn = Callable[
+    [Sequence[numpy.ndarray], ort.InferenceSession, Optional[Dict[str, Any]]],
+    Iterable[PredictionResult]]
+
+
+def _convert_to_result(
+    batch: Iterable, predictions: Union[Iterable, Dict[Any, Iterable]]
+) -> Iterable[PredictionResult]:
+  if isinstance(predictions, dict):
+    # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
+    # key_type2: Iterable<val_type2>, ...} where each Iterable is of
+    # length batch_size, to a list of dictionaries:
+    # [{key_type1: value_type1, key_type2: value_type2}]
+    predictions_per_tensor = [
+        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
+    ]
+    return [
+        PredictionResult(x, y) for x, y in zip(batch, predictions_per_tensor)
+    ]
+  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+
+def default_numpy_inference_fn(
+    inference_session: ort.InferenceSession,
+    batch: Sequence[numpy.ndarray],
+    inference_args: Optional[Dict[str, Any]] = None) -> Any:
+  ort_inputs = {inference_session.get_inputs()[0].name: numpy.stack(batch, axis=0)}
+  ort_outs = inference_session.run(None, ort_inputs, inference_args)

Review Comment:
   I am not sure if passing inference_args like this is the right way. Let us assume the model's predict call accepts
   ```
   def predict(inputs, dropout=True):
       ....
       # some processing on inputs, calculate output_1, output_2
       if dropout:
           return output_1
       else:
           return output_2
   ```
   We pass the dropout argument via inference_args: `inference_args = {'dropout': True}`. How would we pass this extra input to the `inference_session`?
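   To make the concern concrete: as far as I can tell, the third positional argument of `ort.InferenceSession.run` is `run_options` (an `ort.RunOptions` object), not a dict of model inputs, so a `dropout`-style argument could not be passed that way. One possible alternative is to merge the extra arguments into the input feed, assuming they were exported as named graph inputs. A minimal sketch (the function name and the merging strategy are hypothetical, not part of this PR):
   ```python
   import numpy
   import onnxruntime as ort


   # Hypothetical custom inference fn: assumes any extra arguments
   # (e.g. 'dropout') were exported as named graph inputs of the model.
   def numpy_inference_fn_with_extra_inputs(
       inference_session: ort.InferenceSession,
       batch,
       inference_args=None):
     # Feed the stacked batch to the model's first declared input.
     ort_inputs = {
         inference_session.get_inputs()[0].name: numpy.stack(batch, axis=0)
     }
     if inference_args:
       # Merge extra named inputs into the feed dict instead of passing
       # them as run_options, which expects an ort.RunOptions object.
       ort_inputs.update(inference_args)
     return inference_session.run(None, ort_inputs)
   ```
   A user could then presumably supply such a function through whatever `inference_fn` hook the handler exposes (the `NumpyInferenceFn` type alias above suggests one exists), rather than changing the default.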

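As a side note for readers of this thread, the dict branch of `_convert_to_result` is easiest to follow with a tiny worked example; the tensor names and values below are made up for illustration:

```python
# Made-up two-output predictions for a batch of size 2: each output
# name maps to one value per example in the batch.
predictions = {'scores': [0.1, 0.9], 'labels': ['cat', 'dog']}

# zip(*predictions.values()) pairs up the i-th element of every output:
# [(0.1, 'cat'), (0.9, 'dog')]
per_example = [
    dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
]
assert per_example == [
    {'scores': 0.1, 'labels': 'cat'},
    {'scores': 0.9, 'labels': 'dog'},
]
```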