---
libavfilter/dnn/dnn_backend_torch.cpp | 38 +++++++++++++++++++++------
1 file changed, 30 insertions(+), 8 deletions(-)
diff --git a/libavfilter/dnn/dnn_backend_torch.cpp b/libavfilter/dnn/dnn_backend_torch.cpp
index 4f7ae17aab..73eadc6b7e 100644
--- a/libavfilter/dnn/dnn_backend_torch.cpp
+++ b/libavfilter/dnn/dnn_backend_torch.cpp
@@ -255,16 +255,30 @@ static int th_start_inference(void *args)
LastLevelTaskItem *lltask = request->lltask;
TaskItem *task = lltask->task;
THModel *th_model = (THModel *)task->model;
+ DnnContext *ctx = th_model->ctx;
std::vector<torch::jit::IValue> inputs;
-
torch::jit::setGraphExecutorOptimize(!!th_model->ctx->torch_option.optimize);
+ torch::jit::setGraphExecutorOptimize(!!ctx->torch_option.optimize);
+
+ if (!infer_request->input_tensor || !infer_request->output) {
+ av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n");
+ return DNN_GENERIC_ERROR;
+ }
+
+ const char *device_name = ctx->device ? ctx->device : "cpu";
+ c10::Device device = c10::Device(device_name);
- c10::Device device = (*th_model->jit_model->parameters().begin()).device();
if (infer_request->input_tensor->device() != device)
*infer_request->input_tensor = infer_request->input_tensor->to(device);
inputs.push_back(*infer_request->input_tensor);
- *infer_request->output = th_model->jit_model->forward(inputs).toTensor();
+
+ try {
+ *infer_request->output =
th_model->jit_model->forward(inputs).toTensor();
+ } catch (const c10::Error& e) {
+ av_log(ctx, AV_LOG_ERROR, "Torch forward pass failed: %s\n", e.what());
+ return DNN_GENERIC_ERROR;
+ }
return 0;
}
@@ -415,14 +429,23 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
{
THModel *th_model = av_mallocz(sizeof(THModel));
THRequestItem *item = NULL;
+ const char *device_name = ctx->device ? ctx->device : "cpu";
if (!th_model)
return NULL;
th_model->ctx = ctx;
- th_model->jit_model = new torch::jit::Module;
- // Commit 1 uses the simplest loading logic
- *th_model->jit_model = torch::jit::load(ctx->model_filename);
+
+ // Robustness: Wrap model loading and device movement in try-catch
+ try {
+ c10::Device device = c10::Device(device_name);
+ th_model->jit_model = new torch::jit::Module;
+ (*th_model->jit_model) = torch::jit::load(ctx->model_filename);
+ th_model->jit_model->to(device);
+ } catch (const c10::Error& e) {
+ av_log(ctx, AV_LOG_ERROR, "Failed to load torch model: %s\n",
e.what());
+ goto fail;
+ }
th_model->request_queue = ff_safe_queue_create();
if (!th_model->request_queue)
@@ -436,7 +459,6 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
if (!item->infer_request)
goto fail;
- // Infrastructure setup for Async Module
item->exec_module.start_inference = &th_start_inference;
item->exec_module.callback = &infer_completion_callback;
item->exec_module.args = item;
@@ -463,7 +485,7 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
fail:
if (item)
destroy_request_item(&item);
- // Passing the address of the model pointer
+
DNNModel *temp_model = &th_model->model;
dnn_free_model_th(&temp_model);
return NULL;
--
2.51.0
_______________________________________________
ffmpeg-devel mailing list -- ffmpeg-devel@ffmpeg.org
To unsubscribe send an email to ffmpeg-devel-leave@ffmpeg.org