Signed-off-by: Raja Rathour <[email protected]>
This patch implements asynchronous model execution for the LibTorch backend
in FFmpeg's DNN module.
Key changes:
- Integrates a worker thread and a pending queue to handle inference.
- Prevents the main filter thread from blocking during model execution.
- Aligns LibTorch backend behavior with the existing OpenVINO async
implementation.
- Improves overall throughput for deep learning filters using LibTorch.
The implementation has been tested with various torch models to ensure
stability and correct frame synchronization.
---
libavfilter/dnn/dnn_backend_torch.cpp | 115 +++++++++++++++++++++++++-
1 file changed, 111 insertions(+), 4 deletions(-)
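
Note for reviewers (not part of the patch): below is a minimal standalone
sketch of the producer/consumer pattern the worker uses, i.e. a condition
variable guarding a pending queue, with the worker draining remaining work
before exiting on shutdown. The Worker type, submit() and std::queue are
illustration-only stand-ins for the backend's THModel fields, SafeQueue and
THRequestItem; only the wait/drain/notify logic mirrors what the backend does.

    // Illustrative sketch only -- not part of the patch.
    #include <condition_variable>
    #include <cstdio>
    #include <mutex>
    #include <queue>
    #include <thread>

    struct Worker {
        std::queue<int> pending;      // work items waiting to be processed
        std::mutex m;                 // protects 'pending' and 'stop'
        std::condition_variable cv;   // wakes the worker on new work or shutdown
        bool stop = false;
        std::thread t;                // declared last so the members above exist first

        void run() {
            for (;;) {
                int item;
                {
                    std::unique_lock<std::mutex> lock(m);
                    cv.wait(lock, [&] { return stop || !pending.empty(); });
                    if (stop && pending.empty())
                        break;        // drain remaining work before exiting
                    item = pending.front();
                    pending.pop();
                }                     // lock released: processing never blocks submitters
                std::printf("processed %d\n", item);   // "inference" happens here
            }
        }

        void submit(int item) {
            {
                std::lock_guard<std::mutex> lock(m);
                pending.push(item);
            }
            cv.notify_one();          // caller returns immediately
        }

        Worker() : t(&Worker::run, this) {}
        ~Worker() {
            {
                std::lock_guard<std::mutex> lock(m);
                stop = true;
            }
            cv.notify_all();
            t.join();
        }
    };

    int main() {
        Worker w;
        for (int i = 0; i < 3; i++)
            w.submit(i);
        return 0;                      // ~Worker() drains the queue and joins
    }
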
diff --git a/libavfilter/dnn/dnn_backend_torch.cpp b/libavfilter/dnn/dnn_backend_torch.cpp
index 2e4326d9d4..ad81aff8da 100644
--- a/libavfilter/dnn/dnn_backend_torch.cpp
+++ b/libavfilter/dnn/dnn_backend_torch.cpp
@@ -25,6 +25,10 @@
#include <torch/torch.h>
#include <torch/script.h>
+#include <thread>
+#include <mutex>
+#include <condition_variable>
+#include <atomic>
extern "C" {
#include "dnn_io_proc.h"
@@ -39,9 +43,16 @@ typedef struct THModel {
DNNModel model;
DnnContext *ctx;
torch::jit::Module *jit_model;
- SafeQueue *request_queue;
+ SafeQueue *request_queue; // Holds available/idle request slots
Queue *task_queue;
Queue *lltask_queue;
+
+ // --- Async Support ---
+ SafeQueue *pending_queue; // Holds requests waiting for inference
+ std::thread *worker_thread; // The background worker
+ std::mutex *mutex; // Protects the condition variable
+ std::condition_variable *cond; // Wakes up worker when new task arrives
+ std::atomic<bool> worker_stop; // Flag to stop the thread
} THModel;
typedef struct THInferRequest {
@@ -119,6 +130,32 @@ static void dnn_free_model_th(DNNModel **model)
return;
th_model = (THModel *) (*model);
+
+ // --- Stop and Join Worker Thread ---
+ if (th_model->worker_thread) {
+ {
+ std::lock_guard<std::mutex> lock(*th_model->mutex);
+ th_model->worker_stop = true;
+ }
+ th_model->cond->notify_all();
+
+ if (th_model->worker_thread->joinable()) {
+ th_model->worker_thread->join();
+ }
+ delete th_model->worker_thread;
+ delete th_model->mutex;
+ delete th_model->cond;
+ }
+
+ if (th_model->pending_queue) {
+ // Destroy any requests still pending (normally none remain after the worker has drained and joined)
+ while (ff_safe_queue_size(th_model->pending_queue) != 0) {
+ THRequestItem *item = (THRequestItem *)ff_safe_queue_pop_front(th_model->pending_queue);
+ destroy_request_item(&item);
+ }
+ ff_safe_queue_destroy(th_model->pending_queue);
+ }
+ // -----------------------------------
+
while (ff_safe_queue_size(th_model->request_queue) != 0) {
THRequestItem *item = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue);
destroy_request_item(&item);
@@ -318,6 +355,41 @@ err:
}
}
+// --- Worker Thread Function ---
+static void th_worker_thread(THModel *th_model) {
+ while (true) {
+ THRequestItem *request = NULL;
+
+ {
+ // Acquire lock to check condition
+ std::unique_lock<std::mutex> lock(*th_model->mutex);
+
+ // Wait until: We are told to stop OR there is work in the queue
+ th_model->cond->wait(lock, [&]{
+ return th_model->worker_stop || ff_safe_queue_size(th_model->pending_queue) > 0;
+ });
+
+ // If stopped and no work left, exit
+ if (th_model->worker_stop && ff_safe_queue_size(th_model->pending_queue) == 0) {
+ break;
+ }
+
+ // Get work
+ request = (THRequestItem *)ff_safe_queue_pop_front(th_model->pending_queue);
+ }
+
+ // Process work (Lock released so we don't block submission)
+ if (request) {
+ int ret = th_start_inference(request);
+ if (ret != 0) {
+ av_log(th_model->ctx, AV_LOG_ERROR, "Async inference failed\n");
+ }
+ // Always callback to clean up and notify FFmpeg
+ infer_completion_callback(request);
+ }
+ }
+}
+
static int execute_model_th(THRequestItem *request, Queue *lltask_queue)
{
THModel *th_model = NULL;
@@ -343,9 +415,24 @@ static int execute_model_th(THRequestItem *request, Queue *lltask_queue)
if ( ret != 0) {
goto err;
}
+
+ // --- EXECUTION LOGIC (ASYNC vs SYNC) ---
if (task->async) {
- avpriv_report_missing_feature(th_model->ctx, "LibTorch async");
+ // 1. Acquire lock
+ std::lock_guard<std::mutex> lock(*th_model->mutex);
+
+ // 2. Push to pending queue
+ if (ff_safe_queue_push_back(th_model->pending_queue, request) < 0) {
+ return AVERROR(ENOMEM);
+ }
+
+ // 3. Wake up worker
+ th_model->cond->notify_one();
+
+ // 4. Return immediately (Success)
+ return 0;
} else {
+ // Synchronous fallback
ret = th_start_inference((void *)(request));
if (ret != 0) {
goto err;
@@ -484,6 +571,25 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
goto fail;
}
+ // --- INITIALIZE ASYNC QUEUE AND THREAD ---
+ th_model->pending_queue = ff_safe_queue_create();
+ if (!th_model->pending_queue) {
+ av_log(ctx, AV_LOG_ERROR, "Failed to create pending queue\n");
+ goto fail;
+ }
+
+ try {
+ th_model->mutex = new std::mutex();
+ th_model->cond = new std::condition_variable();
+ th_model->worker_stop = false;
+
+ // Start worker thread
+ th_model->worker_thread = new std::thread(th_worker_thread, th_model);
+ } catch (...) {
+ av_log(ctx, AV_LOG_ERROR, "Failed to initialize worker thread or mutexes\n");
+ goto fail;
+ }
+
model->get_input = &get_input_th;
model->get_output = &get_output_th;
model->filter_ctx = filter_ctx;
@@ -519,7 +625,8 @@ static int dnn_execute_model_th(const DNNModel *model, DNNExecBaseParams *exec_p
return AVERROR(ENOMEM);
}
- ret = ff_dnn_fill_task(task, exec_params, th_model, 0, 1);
+ // Set 'async' flag based on context (ctx->async) instead of hardcoded 0
+ ret = ff_dnn_fill_task(task, exec_params, th_model, ctx->async, 1);
if (ret != 0) {
av_freep(&task);
av_log(ctx, AV_LOG_ERROR, "unable to fill task.\n");
@@ -580,4 +687,4 @@ extern const DNNModule ff_dnn_backend_torch = {
.get_result = dnn_get_result_th,
.flush = dnn_flush_th,
.free_model = dnn_free_model_th,
-};
+};
\ No newline at end of file
--
2.48.1