---
 libavfilter/dnn/dnn_backend_torch.cpp | 354 ++++++++------------------
 1 file changed, 113 insertions(+), 241 deletions(-)
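
Note: fill_model_input_th() and th_free_request() below use two THInferRequest members,
input_data and input_data_size, whose declaration is not part of the hunks shown here.
A rough sketch of what the struct is assumed to end up looking like with those fields
(field names taken from the code below; exact types and ordering are an assumption):

    typedef struct THInferRequest {
        torch::Tensor *output;
        torch::Tensor *input_tensor;
        void          *input_data;      // reusable staging buffer for the input frame (assumed)
        size_t         input_data_size; // allocated size of input_data in bytes (assumed)
    } THInferRequest;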

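For context on the buffer reuse: torch::from_blob() wraps existing memory without copying,
and the deleter passed to it runs when the tensor's storage is released, so the reused
input_data buffer must stay valid for the lifetime of the tensor built from it. A small
standalone sketch of that call pattern (illustrative values only, not code from this patch):

    #include <torch/torch.h>

    static void noop_deleter(void *arg)
    {
        // no-op: the buffer is owned by the request, not by the tensor
    }

    static void from_blob_sketch(void)
    {
        static float buf[1 * 3 * 4 * 4] = { 0 };
        // Wraps buf as a 1x3x4x4 float tensor; no copy is made.
        torch::Tensor t = torch::from_blob(buf, {1, 3, 4, 4}, noop_deleter, torch::kFloat32);
    }
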
diff --git a/libavfilter/dnn/dnn_backend_torch.cpp b/libavfilter/dnn/dnn_backend_torch.cpp
index 33809bf983..4c781cc0b6 100644
--- a/libavfilter/dnn/dnn_backend_torch.cpp
+++ b/libavfilter/dnn/dnn_backend_torch.cpp
@@ -25,10 +25,6 @@
 
 #include <torch/torch.h>
 #include <torch/script.h>
-#include <thread>
-#include <mutex>
-#include <condition_variable>
-#include <atomic>
 
 extern "C" {
 #include "dnn_io_proc.h"
@@ -46,11 +42,6 @@ typedef struct THModel {
     SafeQueue *request_queue;
     Queue *task_queue;
     Queue *lltask_queue;
-    SafeQueue *pending_queue;       ///< requests waiting for inference
-    std::thread *worker_thread;     ///< background worker thread
-    std::mutex *mutex;              ///< mutex for the condition variable
-    std::condition_variable *cond;  ///< condition variable for worker wakeup
-    std::atomic<bool> worker_stop;  ///< signal for thread exit
 } THModel;
 
 typedef struct THInferRequest {
@@ -64,7 +55,6 @@ typedef struct THRequestItem {
     DNNAsyncExecModule exec_module;
 } THRequestItem;
 
-
 #define OFFSET(x) offsetof(THOptions, x)
 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM
 static const AVOption dnn_th_options[] = {
@@ -104,15 +94,17 @@ static void th_free_request(THInferRequest *request)
         delete(request->input_tensor);
         request->input_tensor = NULL;
     }
-    return;
+    if (request->input_data) {
+        av_freep(&request->input_data);
+        request->input_data_size = 0;
+    }
 }
 
 static inline void destroy_request_item(THRequestItem **arg)
 {
     THRequestItem *item;
-    if (!arg || !*arg) {
+    if (!arg || !*arg)
         return;
-    }
     item = *arg;
     th_free_request(item->infer_request);
     av_freep(&item->infer_request);
@@ -129,38 +121,6 @@ static void dnn_free_model_th(DNNModel **model)
 
     th_model = (THModel *)(*model);
 
-    /* 1. Stop and join the worker thread if it exists */
-    if (th_model->worker_thread) {
-        {
-            std::lock_guard<std::mutex> lock(*th_model->mutex);
-            th_model->worker_stop = true;
-        }
-        th_model->cond->notify_all();
-        th_model->worker_thread->join();
-        delete th_model->worker_thread;
-        th_model->worker_thread = NULL;
-    }
-
-    /* 2. Safely delete C++ synchronization objects */
-    if (th_model->mutex) {
-        delete th_model->mutex;
-        th_model->mutex = NULL;
-    }
-    if (th_model->cond) {
-        delete th_model->cond;
-        th_model->cond = NULL;
-    }
-
-    /* 3. Clean up the pending queue */
-    if (th_model->pending_queue) {
-        while (ff_safe_queue_size(th_model->pending_queue) > 0) {
-            THRequestItem *item = (THRequestItem *)ff_safe_queue_pop_front(th_model->pending_queue);
-            destroy_request_item(&item);
-        }
-        ff_safe_queue_destroy(th_model->pending_queue);
-    }
-
-    /* 4. Clean up standard backend queues */
     if (th_model->request_queue) {
         while (ff_safe_queue_size(th_model->request_queue) != 0) {
             THRequestItem *item = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue);
@@ -187,7 +147,6 @@ static void dnn_free_model_th(DNNModel **model)
         ff_queue_destroy(th_model->task_queue);
     }
 
-    /* 5. Final model cleanup */
     if (th_model->jit_model)
         delete th_model->jit_model;
 
@@ -214,37 +173,55 @@ static void deleter(void *arg)
 
 static int fill_model_input_th(THModel *th_model, THRequestItem *request)
 {
-    LastLevelTaskItem *lltask = NULL;
-    TaskItem *task = NULL;
     THInferRequest *infer_request = NULL;
+    TaskItem *task = NULL;
+    LastLevelTaskItem *lltask = NULL;
     DNNData input = { 0 };
     DnnContext *ctx = th_model->ctx;
     int ret, width_idx, height_idx, channel_idx;
+    size_t cur_size;
 
     lltask = (LastLevelTaskItem *)ff_queue_pop_front(th_model->lltask_queue);
-    if (!lltask) {
-        ret = AVERROR(EINVAL);
-        goto err;
-    }
+    if (!lltask)
+        return AVERROR(EINVAL);
+
     request->lltask = lltask;
     task = lltask->task;
     infer_request = request->infer_request;
 
     ret = get_input_th(&th_model->model, &input, NULL);
-    if ( ret != 0) {
-        goto err;
-    }
+    if (ret)
+        return ret;
+
     width_idx = dnn_get_width_idx_by_layout(input.layout);
     height_idx = dnn_get_height_idx_by_layout(input.layout);
     channel_idx = dnn_get_channel_idx_by_layout(input.layout);
     input.dims[height_idx] = task->in_frame->height;
     input.dims[width_idx] = task->in_frame->width;
-    input.data = av_malloc(input.dims[height_idx] * input.dims[width_idx] *
-                           input.dims[channel_idx] * sizeof(float));
-    if (!input.data)
-        return AVERROR(ENOMEM);
-    infer_request->input_tensor = new torch::Tensor();
-    infer_request->output = new torch::Tensor();
+
+    // Calculate required size for the current frame
+    cur_size = input.dims[height_idx] * input.dims[width_idx] *
+               input.dims[channel_idx] * sizeof(float);
+
+    /**
+     * Dynamic resizing logic:
+     * reallocate the staging buffer only when it does not exist yet or is
+     * too small for the current frame, otherwise reuse it as-is.
+     */
+    if (!infer_request->input_data || infer_request->input_data_size < cur_size) {
+        av_freep(&infer_request->input_data);
+        infer_request->input_data = av_malloc(cur_size);
+        if (!infer_request->input_data)
+            return AVERROR(ENOMEM);
+        infer_request->input_data_size = cur_size;
+    }
+
+    input.data = infer_request->input_data;
+
+    if (!infer_request->input_tensor)
+        infer_request->input_tensor = new torch::Tensor();
+    if (!infer_request->output)
+        infer_request->output = new torch::Tensor();
 
     switch (th_model->model.func_type) {
     case DFT_PROCESS_FRAME:
@@ -261,52 +238,30 @@ static int fill_model_input_th(THModel *th_model, THRequestItem *request)
         avpriv_report_missing_feature(NULL, "model function type %d", th_model->model.func_type);
         break;
     }
+
     *infer_request->input_tensor = torch::from_blob(input.data,
         {1, input.dims[channel_idx], input.dims[height_idx], input.dims[width_idx]},
         deleter, torch::kFloat32);
-    return 0;
 
-err:
-    th_free_request(infer_request);
-    return ret;
+    return 0;
 }
 
 static int th_start_inference(void *args)
 {
     THRequestItem *request = (THRequestItem *)args;
-    THInferRequest *infer_request = NULL;
-    LastLevelTaskItem *lltask = NULL;
-    TaskItem *task = NULL;
-    THModel *th_model = NULL;
-    DnnContext *ctx = NULL;
+    THInferRequest *infer_request = request->infer_request;
+    LastLevelTaskItem *lltask = request->lltask;
+    TaskItem *task = lltask->task;
+    THModel *th_model = (THModel *)task->model;
     std::vector<torch::jit::IValue> inputs;
-    torch::NoGradGuard no_grad;
-
-    if (!request) {
-        av_log(NULL, AV_LOG_ERROR, "THRequestItem is NULL\n");
-        return AVERROR(EINVAL);
-    }
-    infer_request = request->infer_request;
-    lltask = request->lltask;
-    task = lltask->task;
-    th_model = (THModel *)task->model;
-    ctx = th_model->ctx;
 
-    if (ctx->torch_option.optimize)
-        torch::jit::setGraphExecutorOptimize(true);
-    else
-        torch::jit::setGraphExecutorOptimize(false);
+    torch::jit::setGraphExecutorOptimize(!!th_model->ctx->torch_option.optimize);
 
-    if (!infer_request->input_tensor || !infer_request->output) {
-        av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n");
-        return DNN_GENERIC_ERROR;
-    }
-    // Transfer tensor to the same device as model
     c10::Device device = (*th_model->jit_model->parameters().begin()).device();
     if (infer_request->input_tensor->device() != device)
         *infer_request->input_tensor = infer_request->input_tensor->to(device);
-    inputs.push_back(*infer_request->input_tensor);
 
+    inputs.push_back(*infer_request->input_tensor);
     *infer_request->output = th_model->jit_model->forward(inputs).toTensor();
 
     return 0;
@@ -325,13 +280,12 @@ static void infer_completion_callback(void *args) {
     outputs.order = DCO_RGB;
     outputs.layout = DL_NCHW;
     outputs.dt = DNN_FLOAT;
+
     if (sizes.size() == 4) {
-        // 4 dimensions: [batch_size, channel, height, width]
-        // this format of data is normally used for video frame SR
-        outputs.dims[0] = sizes.at(0); // N
-        outputs.dims[1] = sizes.at(1); // C
-        outputs.dims[2] = sizes.at(2); // H
-        outputs.dims[3] = sizes.at(3); // W
+        outputs.dims[0] = sizes.at(0);
+        outputs.dims[1] = sizes.at(1);
+        outputs.dims[2] = sizes.at(2);
+        outputs.dims[3] = sizes.at(3);
     } else {
         avpriv_report_missing_feature(th_model->ctx, "Support of this kind of model");
         goto err;
@@ -340,7 +294,6 @@ static void infer_completion_callback(void *args) {
     switch (th_model->model.func_type) {
     case DFT_PROCESS_FRAME:
         if (task->do_ioproc) {
-            // Post process can only deal with CPU memory.
             if (output->device() != torch::kCPU)
                 *output = output->to(torch::kCPU);
             outputs.scale = 255;
@@ -361,35 +314,11 @@ static void infer_completion_callback(void *args) {
     }
     task->inference_done++;
     av_freep(&request->lltask);
+
 err:
     th_free_request(infer_request);
-
-    if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) {
+    if (ff_safe_queue_push_back(th_model->request_queue, request) < 0)
         destroy_request_item(&request);
-        av_log(th_model->ctx, AV_LOG_ERROR, "Unable to push back request_queue 
when failed to start inference.\n");
-    }
-}
-
-static void th_worker_thread(THModel *th_model) {
-    while (true) {
-        THRequestItem *request = NULL;
-        {
-            std::unique_lock<std::mutex> lock(*th_model->mutex);
-            th_model->cond->wait(lock, [&]{
-                return th_model->worker_stop || ff_safe_queue_size(th_model->pending_queue) > 0;
-            });
-
-            if (th_model->worker_stop && ff_safe_queue_size(th_model->pending_queue) == 0)
-                break;
-
-            request = (THRequestItem *)ff_safe_queue_pop_front(th_model->pending_queue);
-        }
-
-        if (request) {
-            th_start_inference(request);
-            infer_completion_callback(request);
-        }
-    }
 }
 
 static int execute_model_th(THRequestItem *request, Queue *lltask_queue)
@@ -405,32 +334,27 @@ static int execute_model_th(THRequestItem *request, Queue *lltask_queue)
     }
 
     lltask = (LastLevelTaskItem *)ff_queue_peek_front(lltask_queue);
-    if (lltask == NULL) {
-        av_log(NULL, AV_LOG_ERROR, "Failed to get LastLevelTaskItem\n");
-        ret = AVERROR(EINVAL);
-        goto err;
+    if (!lltask) {
+        destroy_request_item(&request);
+        return AVERROR(EINVAL);
     }
+
     task = lltask->task;
     th_model = (THModel *)task->model;
 
     ret = fill_model_input_th(th_model, request);
-    if ( ret != 0) {
-        goto err;
-    }
-    if (task->async) {
-        std::lock_guard<std::mutex> lock(*th_model->mutex);
-        if (ff_safe_queue_push_back(th_model->pending_queue, request) < 0) {
-            return AVERROR(ENOMEM);
-        }
-        th_model->cond->notify_one();
-        return 0;
+    if (ret) {
+        th_free_request(request->infer_request);
+        if (ff_safe_queue_push_back(th_model->request_queue, request) < 0)
+            destroy_request_item(&request);
+        return ret;
     }
 
-err:
-    th_free_request(request->infer_request);
-    if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) {
-        destroy_request_item(&request);
-    }
+    if (task->async)
+        return ff_dnn_async_module_submit(&request->exec_module);
+
+    ret = th_start_inference(request);
+    infer_completion_callback(request);
     return ret;
 }
 
@@ -449,29 +373,29 @@ static int get_output_th(DNNModel *model, const char *input_name, int input_widt
         .in_frame       = NULL,
         .out_frame      = NULL,
     };
+
     ret = ff_dnn_fill_gettingoutput_task(&task, &exec_params, th_model, input_height, input_width, ctx);
-    if ( ret != 0) {
-        goto err;
-    }
+    if (ret)
+        return ret;
 
     ret = extract_lltask_from_task(&task, th_model->lltask_queue);
-    if ( ret != 0) {
-        av_log(ctx, AV_LOG_ERROR, "unable to extract last level task from 
task.\n");
-        goto err;
+    if (ret) {
+        av_frame_free(&task.out_frame);
+        av_frame_free(&task.in_frame);
+        return ret;
     }
 
     request = (THRequestItem*) ff_safe_queue_pop_front(th_model->request_queue);
     if (!request) {
-        av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
-        ret = AVERROR(EINVAL);
-        goto err;
+        av_frame_free(&task.out_frame);
+        av_frame_free(&task.in_frame);
+        return AVERROR(EINVAL);
     }
 
     ret = execute_model_th(request, th_model->lltask_queue);
     *output_width = task.out_frame->width;
     *output_height = task.out_frame->height;
 
-err:
     av_frame_free(&task.out_frame);
     av_frame_free(&task.in_frame);
     return ret;
@@ -479,105 +403,67 @@ err:
 
 static THInferRequest *th_create_inference_request(void)
 {
-    THInferRequest *request = (THInferRequest *)av_malloc(sizeof(THInferRequest));
-    if (!request) {
+    THInferRequest *request = (THInferRequest *)av_mallocz(sizeof(THInferRequest));
+    if (!request)
         return NULL;
-    }
-    request->input_tensor = NULL;
-    request->output = NULL;
     return request;
 }
 
 static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx)
 {
-    DNNModel *model = NULL;
-    THModel *th_model = NULL;
+    THModel *th_model = (THModel *)av_mallocz(sizeof(THModel));
     THRequestItem *item = NULL;
-    const char *device_name = ctx->device ? ctx->device : "cpu";
 
-    th_model = (THModel *)av_mallocz(sizeof(THModel));
     if (!th_model)
         return NULL;
-    model = &th_model->model;
-    th_model->ctx = ctx;
-
-    c10::Device device = c10::Device(device_name);
-    if (device.is_xpu()) {
-        if (!at::hasXPU()) {
-            av_log(ctx, AV_LOG_ERROR, "No XPU device found\n");
-            goto fail;
-        }
-        at::detail::getXPUHooks().initXPU();
-    } else if (!device.is_cpu()) {
-        av_log(ctx, AV_LOG_ERROR, "Not supported device:\"%s\"\n", 
device_name);
-        goto fail;
-    }
 
-    try {
-        th_model->jit_model = new torch::jit::Module;
-        (*th_model->jit_model) = torch::jit::load(ctx->model_filename);
-        th_model->jit_model->to(device);
-    } catch (const c10::Error& e) {
-        av_log(ctx, AV_LOG_ERROR, "Failed to load torch model\n");
-        goto fail;
-    }
+    th_model->ctx = ctx;
+    th_model->jit_model = new torch::jit::Module;
+    // Commit 1 uses the simplest loading logic
+    *th_model->jit_model = torch::jit::load(ctx->model_filename);
 
     th_model->request_queue = ff_safe_queue_create();
-    if (!th_model->request_queue) {
+    if (!th_model->request_queue)
         goto fail;
-    }
 
-    item = (THRequestItem *)av_mallocz(sizeof(THRequestItem));
-    if (!item) {
+    item = (THRequestItem *)av_mallocz(sizeof(THRequestItem));
+    if (!item)
         goto fail;
-    }
-    item->lltask = NULL;
+
     item->infer_request = th_create_inference_request();
-    if (!item->infer_request) {
-        av_log(NULL, AV_LOG_ERROR, "Failed to allocate memory for Torch 
inference request\n");
+    if (!item->infer_request)
         goto fail;
-    }
+
+    // Hook the request up to the common async execution module
     item->exec_module.start_inference = &th_start_inference;
     item->exec_module.callback = &infer_completion_callback;
     item->exec_module.args = item;
 
-    if (ff_safe_queue_push_back(th_model->request_queue, item) < 0) {
+    if (ff_safe_queue_push_back(th_model->request_queue, item) < 0)
         goto fail;
-    }
     item = NULL;
 
     th_model->task_queue = ff_queue_create();
-    if (!th_model->task_queue) {
+    if (!th_model->task_queue)
         goto fail;
-    }
 
     th_model->lltask_queue = ff_queue_create();
-    if (!th_model->lltask_queue) {
-        goto fail;
-    }
-
-    th_model->pending_queue = ff_safe_queue_create();
-    if (!th_model->pending_queue) {
+    if (!th_model->lltask_queue)
         goto fail;
-    }
 
-    th_model->mutex = new std::mutex();
-    th_model->cond = new std::condition_variable();
-    th_model->worker_stop = false;
-    th_model->worker_thread = new std::thread(th_worker_thread, th_model);
+    th_model->model.get_input = &get_input_th;
+    th_model->model.get_output = &get_output_th;
+    th_model->model.filter_ctx = filter_ctx;
+    th_model->model.func_type = func_type;
 
-    model->get_input = &get_input_th;
-    model->get_output = &get_output_th;
-    model->filter_ctx = filter_ctx;
-    model->func_type = func_type;
-    return model;
+    return &th_model->model;
 
 fail:
-    if (item) {
+    if (item)
         destroy_request_item(&item);
-        av_freep(&item);
-    }
-    dnn_free_model_th(&model);
+    // dnn_free_model_th() takes a DNNModel **, so free via a temporary pointer to the embedded model
+    DNNModel *temp_model = &th_model->model;
+    dnn_free_model_th(&temp_model);
     return NULL;
 }
 
@@ -590,42 +476,31 @@ static int dnn_execute_model_th(const DNNModel *model, DNNExecBaseParams *exec_p
     int ret = 0;
 
     ret = ff_check_exec_params(ctx, DNN_TH, model->func_type, exec_params);
-    if (ret != 0) {
-        av_log(ctx, AV_LOG_ERROR, "exec parameter checking fail.\n");
+    if (ret)
         return ret;
-    }
 
-    task = (TaskItem *)av_malloc(sizeof(TaskItem));
-    if (!task) {
-        av_log(ctx, AV_LOG_ERROR, "unable to alloc memory for task item.\n");
+    task = (TaskItem *)av_mallocz(sizeof(TaskItem));
+    if (!task)
         return AVERROR(ENOMEM);
-    }
 
     ret = ff_dnn_fill_task(task, exec_params, th_model, 0, 1);
-    if (ret != 0) {
+    if (ret) {
         av_freep(&task);
-        av_log(ctx, AV_LOG_ERROR, "unable to fill task.\n");
         return ret;
     }
 
-    ret = ff_queue_push_back(th_model->task_queue, task);
-    if (ret < 0) {
+    if (ff_queue_push_back(th_model->task_queue, task) < 0) {
         av_freep(&task);
-        av_log(ctx, AV_LOG_ERROR, "unable to push back task_queue.\n");
-        return ret;
+        return AVERROR(ENOMEM);
     }
 
     ret = extract_lltask_from_task(task, th_model->lltask_queue);
-    if (ret != 0) {
-        av_log(ctx, AV_LOG_ERROR, "unable to extract last level task from 
task.\n");
+    if (ret)
         return ret;
-    }
 
     request = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue);
-    if (!request) {
-        av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
+    if (!request)
         return AVERROR(EINVAL);
-    }
 
     return execute_model_th(request, th_model->lltask_queue);
 }
@@ -642,14 +517,11 @@ static int dnn_flush_th(const DNNModel *model)
     THRequestItem *request;
 
     if (ff_queue_size(th_model->lltask_queue) == 0)
-        // no pending task need to flush
         return 0;
 
     request = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue);
-    if (!request) {
-        av_log(th_model->ctx, AV_LOG_ERROR, "unable to get infer request.\n");
+    if (!request)
         return AVERROR(EINVAL);
-    }
 
     return execute_model_th(request, th_model->lltask_queue);
 }
@@ -662,4 +534,4 @@ extern const DNNModule ff_dnn_backend_torch = {
     .get_result     = dnn_get_result_th,
     .flush          = dnn_flush_th,
     .free_model     = dnn_free_model_th,
-};
+};
\ No newline at end of file
-- 
2.51.0
