---
 libavfilter/dnn/dnn_backend_torch.cpp | 41 +++++++++++++++++++++++------
 1 file changed, 33 insertions(+), 8 deletions(-)

diff --git a/libavfilter/dnn/dnn_backend_torch.cpp b/libavfilter/dnn/dnn_backend_torch.cpp
index 4f7ae17aab..73eadc6b7e 100644
--- a/libavfilter/dnn/dnn_backend_torch.cpp
+++ b/libavfilter/dnn/dnn_backend_torch.cpp
@@ -255,16 +255,32 @@ static int th_start_inference(void *args)
     LastLevelTaskItem *lltask = request->lltask;
     TaskItem *task = lltask->task;
     THModel *th_model = (THModel *)task->model;
+    DnnContext *ctx = th_model->ctx;
     std::vector<torch::jit::IValue> inputs;
 
-    torch::jit::setGraphExecutorOptimize(!!th_model->ctx->torch_option.optimize);
+    torch::jit::setGraphExecutorOptimize(!!ctx->torch_option.optimize);
+
+    if (!infer_request->input_tensor || !infer_request->output) {
+        av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n");
+        return DNN_GENERIC_ERROR;
+    }
+
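+    // Fall back to the CPU device when no device name is configured.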
+    const char *device_name = ctx->device ? ctx->device : "cpu";
+    c10::Device device = c10::Device(device_name);
 
-    c10::Device device = (*th_model->jit_model->parameters().begin()).device();
     if (infer_request->input_tensor->device() != device)
         *infer_request->input_tensor = infer_request->input_tensor->to(device);
 
     inputs.push_back(*infer_request->input_tensor);
-    *infer_request->output = th_model->jit_model->forward(inputs).toTensor();
+
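+    // forward() may throw c10::Error (e.g. on a shape or dtype mismatch).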
+    try {
+        *infer_request->output = th_model->jit_model->forward(inputs).toTensor();
+    } catch (const c10::Error& e) {
+        av_log(ctx, AV_LOG_ERROR, "Torch forward pass failed: %s\n", e.what());
+        return DNN_GENERIC_ERROR;
+    }
 
     return 0;
 }
@@ -415,14 +431,24 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
 {
     THModel *th_model = av_mallocz(sizeof(THModel));
     THRequestItem *item = NULL;
+    const char *device_name = ctx->device ? ctx->device : "cpu";
 
     if (!th_model)
         return NULL;
 
     th_model->ctx = ctx;
-    th_model->jit_model = new torch::jit::Module;
-    // Commit 1 uses the simplest loading logic
-    *th_model->jit_model = torch::jit::load(ctx->model_filename);
+
+    // Loading the model and moving it to the device can both throw; fail gracefully.
+    try {
+        c10::Device device = c10::Device(device_name);
+        th_model->jit_model = new torch::jit::Module;
+        (*th_model->jit_model) = torch::jit::load(ctx->model_filename);
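+        // Moving the module now surfaces an invalid or unavailable device at load time.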
+        th_model->jit_model->to(device);
+    } catch (const c10::Error& e) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to load torch model: %s\n", e.what());
+        goto fail;
+    }
 
     th_model->request_queue = ff_safe_queue_create();
     if (!th_model->request_queue)
@@ -436,7 +462,6 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
     if (!item->infer_request)
         goto fail;
 
-    // Infrastructure setup for Async Module
     item->exec_module.start_inference = &th_start_inference;
     item->exec_module.callback = &infer_completion_callback;
     item->exec_module.args = item;
@@ -463,7 +488,7 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
 fail:
     if (item)
         destroy_request_item(&item);
-    // Passing the address of the model pointer
+
     DNNModel *temp_model = &th_model->model;
     dnn_free_model_th(&temp_model);
     return NULL;
-- 
2.51.0
