Signed-off-by: Raja Rathour <[email protected]>

This patch adds asynchronous model execution to the LibTorch backend of
FFmpeg's DNN module.

Key changes:
- Add a worker thread and a pending-request queue that run inference in the
  background.
- The filter thread no longer blocks while the model executes.
- Honour the DNN `async` option (ctx->async) instead of always running
  inference synchronously.
- Align the LibTorch backend's behavior with the existing OpenVINO async
  implementation.
- Improve overall throughput of deep-learning filters that use LibTorch.

The implementation has been tested with various Torch models to verify
stability and correct frame synchronization.

---
 libavfilter/dnn/dnn_backend_torch.cpp | 115 +++++++++++++++++++++++++-
 1 file changed, 111 insertions(+), 4 deletions(-)

diff --git a/libavfilter/dnn/dnn_backend_torch.cpp b/libavfilter/dnn/dnn_backend_torch.cpp
index 2e4326d9d4..ad81aff8da 100644
--- a/libavfilter/dnn/dnn_backend_torch.cpp
+++ b/libavfilter/dnn/dnn_backend_torch.cpp
@@ -25,6 +25,10 @@
 
 #include <torch/torch.h>
 #include <torch/script.h>
+#include <thread>
+#include <mutex>
+#include <condition_variable>
+#include <atomic>
 
 extern "C" {
 #include "dnn_io_proc.h"
@@ -39,9 +43,16 @@ typedef struct THModel {
     DNNModel model;
     DnnContext *ctx;
     torch::jit::Module *jit_model;
-    SafeQueue *request_queue;
+    SafeQueue *request_queue;       // Holds available/idle request slots
     Queue *task_queue;
     Queue *lltask_queue;
+    
+    // --- Async Support ---
+    SafeQueue *pending_queue;       // Holds requests waiting for inference
+    std::thread *worker_thread;     // The background worker
+    std::mutex *mutex;              // Protects the condition variable
+    std::condition_variable *cond;  // Wakes up worker when new task arrives
+    std::atomic<bool> worker_stop;  // Flag to stop the thread
 } THModel;
 
 typedef struct THInferRequest {
@@ -119,6 +130,32 @@ static void dnn_free_model_th(DNNModel **model)
         return;
 
     th_model = (THModel *) (*model);
+
+    // --- Stop and Join Worker Thread ---
+    if (th_model->worker_thread) {
+        {
+            std::lock_guard<std::mutex> lock(*th_model->mutex);
+            th_model->worker_stop = true;
+        }
+        th_model->cond->notify_all();
+        
+        if (th_model->worker_thread->joinable()) {
+            th_model->worker_thread->join();
+        }
+        delete th_model->worker_thread;
+        delete th_model->mutex;
+        delete th_model->cond;
+    }
+
+    if (th_model->pending_queue) {
+        // Clear remaining items (if any)
+        while (ff_safe_queue_size(th_model->pending_queue) != 0) {
+            ff_safe_queue_pop_front(th_model->pending_queue);
+        }
+        ff_safe_queue_destroy(th_model->pending_queue);
+    }
+    // -----------------------------------
+
     while (ff_safe_queue_size(th_model->request_queue) != 0) {
         THRequestItem *item = (THRequestItem 
*)ff_safe_queue_pop_front(th_model->request_queue);
         destroy_request_item(&item);
@@ -318,6 +355,41 @@ err:
     }
 }
 
+// Worker thread: runs requests from pending_queue until worker_stop is set and the
+// queue is drained. NOTE(review): infer_completion_callback() runs on this thread.
+static void th_worker_thread(THModel *th_model) {
+    while (true) {
+        THRequestItem *request = NULL;
+        bool ok = false;
+        {
+            std::unique_lock<std::mutex> lock(*th_model->mutex);
+            th_model->cond->wait(lock, [&]{
+                return th_model->worker_stop || ff_safe_queue_size(th_model->pending_queue) > 0;
+            });
+            // The predicate holds here, so an empty queue means a stop was requested.
+            if (ff_safe_queue_size(th_model->pending_queue) == 0)
+                break;
+            request = (THRequestItem *)ff_safe_queue_pop_front(th_model->pending_queue);
+        }
+        if (!request)
+            continue;
+        // An exception escaping a std::thread entry point calls std::terminate(),
+        // and LibTorch reports failures (e.g. inside forward()) by throwing.
+        try {
+            ok = th_start_inference(request) == 0;
+        } catch (const std::exception &e) {
+            av_log(th_model->ctx, AV_LOG_ERROR, "Async inference threw: %s\n", e.what());
+        }
+        if (!ok)
+            av_log(th_model->ctx, AV_LOG_ERROR, "Async inference failed\n");
+        try {
+            infer_completion_callback(request); // invoked even on failure, as before
+        } catch (const std::exception &e) {
+            av_log(th_model->ctx, AV_LOG_ERROR, "Async completion failed: %s\n", e.what());
+        }
+    }
+}
+
 static int execute_model_th(THRequestItem *request, Queue *lltask_queue)
 {
     THModel *th_model = NULL;
@@ -343,9 +415,24 @@ static int execute_model_th(THRequestItem *request, Queue *lltask_queue)
     if ( ret != 0) {
         goto err;
     }
+
+    // --- EXECUTION LOGIC (ASYNC vs SYNC) ---
     if (task->async) {
-        avpriv_report_missing_feature(th_model->ctx, "LibTorch async");
+        // 1. Acquire lock
+        std::lock_guard<std::mutex> lock(*th_model->mutex);
+        
+        // 2. Push to pending queue
+        if (ff_safe_queue_push_back(th_model->pending_queue, request) < 0) {
+            return AVERROR(ENOMEM);
+        }
+        
+        // 3. Wake up worker
+        th_model->cond->notify_one();
+        
+        // 4. Return immediately (Success)
+        return 0; 
     } else {
+        // Synchronous fallback
         ret = th_start_inference((void *)(request));
         if (ret != 0) {
             goto err;
@@ -484,6 +571,25 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
         goto fail;
     }
 
+    // --- INITIALIZE ASYNC QUEUE AND THREAD ---
+    th_model->pending_queue = ff_safe_queue_create();
+    if (!th_model->pending_queue) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to create pending queue\n");
+        goto fail;
+    }
+
+    try {
+        th_model->mutex = new std::mutex();
+        th_model->cond = new std::condition_variable();
+        th_model->worker_stop = false;
+        
+        // Start worker thread
+        th_model->worker_thread = new std::thread(th_worker_thread, th_model);
+    } catch (...) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to initialize worker thread or 
mutexes\n");
+        goto fail;
+    }
+
     model->get_input = &get_input_th;
     model->get_output = &get_output_th;
     model->filter_ctx = filter_ctx;
@@ -519,7 +625,8 @@ static int dnn_execute_model_th(const DNNModel *model, DNNExecBaseParams *exec_p
         return AVERROR(ENOMEM);
     }
 
-    ret = ff_dnn_fill_task(task, exec_params, th_model, 0, 1);
+    // Set 'async' flag based on context (ctx->async) instead of hardcoded 0
+    ret = ff_dnn_fill_task(task, exec_params, th_model, ctx->async, 1);
     if (ret != 0) {
         av_freep(&task);
         av_log(ctx, AV_LOG_ERROR, "unable to fill task.\n");
@@ -580,4 +687,4 @@ extern const DNNModule ff_dnn_backend_torch = {
     .get_result     = dnn_get_result_th,
     .flush          = dnn_flush_th,
     .free_model     = dnn_free_model_th,
-};
+};
\ No newline at end of file
-- 
2.48.1

_______________________________________________
ffmpeg-devel mailing list -- [email protected]
To unsubscribe send an email to [email protected]

Reply via email to