Rohanberiwal commented on issue #41702:
URL: https://github.com/apache/airflow/issues/41702#issuecomment-2309936325
Hi, I have worked on this issue for the past two days and came up with a
solution. I made certain changes to the existing code and added an execute
method inside the operator class that does the same work that your
run_onnx_inference() does. Please review this code and tell me whether it
matches your expectations.
Python
import onnxruntime as ort
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from airflow import DAG
from datetime import datetime
class ONNXInferenceOperator(BaseOperator):
    """Airflow operator that runs a single inference pass over an ONNX model.

    :param model_path: Filesystem path to the ``.onnx`` model file.
    :param input_data: Payload fed to the model's first declared input
        (presumably an array-like matching the model's input shape —
        confirm against the specific model).
    """

    # NOTE: ``@apply_defaults`` is deprecated since Airflow 2.0 (BaseOperator's
    # metaclass applies default_args automatically) and was removed in later
    # releases, so it is intentionally not used here.
    def __init__(self, model_path: str, input_data: dict, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_path = model_path
        self.input_data = input_data

    def execute(self, context):
        """Load the model, run one inference call, and return its result.

        :param context: Airflow task-execution context (unused here).
        :return: The list of output arrays produced by ``session.run``.
        """
        session = ort.InferenceSession(self.model_path)
        # Bind the payload to the model's first declared input name.
        input_name = session.get_inputs()[0].name
        result = session.run(None, {input_name: self.input_data})
        # Lazy %-formatting so the message is only built if the level is enabled.
        self.log.info("Inference result: %s", result)
        return result
# Example wiring: a one-shot DAG containing a single ONNX inference task.
with DAG(
    dag_id='onnx_inference_dag',
    start_date=datetime(2023, 1, 1),
    schedule_interval='@once',  # run exactly once after start_date
    catchup=False,              # do not backfill missed intervals
) as dag:
    inference_task = ONNXInferenceOperator(
        task_id='onnx_inference_task',
        model_path='/path/to/your/model.onnx',
        input_data={"your_input_key": [[1.0, 2.0, 3.0]]},
    )

    # Single task, so no upstream/downstream dependencies to declare.
    inference_task
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]