Signed-off-by: MaximilianKaindl <m.kaindl0...@gmail.com>
---
 libavfilter/dnn/dnn_backend_torch.cpp | 76 +++++++++++++++++++++++++++
 1 file changed, 76 insertions(+)

diff --git a/libavfilter/dnn/dnn_backend_torch.cpp b/libavfilter/dnn/dnn_backend_torch.cpp
index 3a0ef931f9..12ba2674b3 100644
--- a/libavfilter/dnn/dnn_backend_torch.cpp
+++ b/libavfilter/dnn/dnn_backend_torch.cpp
@@ -381,6 +381,82 @@ static int copy_softmax_units(THModel *th_model, const int *softmax_units, int s
     return 0;
 }

+/**
+ * Compute a scaled similarity matrix between two sets of embeddings.
+ *
+ * Every row of tensor1 is dot-multiplied with every row of tensor2 (cosine
+ * similarity when normalize is set) and the result is multiplied by
+ * logit_scale. The input tensors are not modified.
+ *
+ * @param tensor1     first embedding matrix; presumably 2-D (N1 x D) - TODO confirm with callers
+ * @param tensor2     second embedding matrix; presumably 2-D (N2 x D)
+ * @param normalize   if true, L2-normalize both inputs along their last dim first
+ * @param logit_scale multiplier applied to the raw similarities
+ * @param ctx         logging context
+ * @return similarity matrix (N1 x N2 for 2-D inputs), or an undefined
+ *         torch::Tensor() if libtorch raised an error
+ */
+static torch::Tensor calculate_similarity(const torch::Tensor &tensor1, const torch::Tensor &tensor2, bool normalize, float logit_scale, DnnContext *ctx)
+{
+    try {
+        // Local handles: rebinding these never touches the caller's tensors.
+        torch::Tensor lhs = tensor1;
+        torch::Tensor rhs = tensor2;
+
+        if (normalize) {
+            const auto opts = torch::nn::functional::NormalizeFuncOptions().p(2).dim(-1);
+            lhs = torch::nn::functional::normalize(lhs, opts);
+            rhs = torch::nn::functional::normalize(rhs, opts);
+        }
+
+        // lhs @ rhs^T equals (rhs @ lhs^T)^T but yields a contiguous tensor
+        // instead of a transposed view.
+        return logit_scale * torch::matmul(lhs, rhs.transpose(0, 1));
+    } catch (const std::exception &e) {
+        // c10::Error derives from std::exception; catching the base class also
+        // keeps any other C++ exception from unwinding into FFmpeg's C code.
+        av_log(ctx, AV_LOG_ERROR, "Similarity computation failed: %s\n", e.what());
+        return torch::Tensor();
+    }
+}
+
+/**
+ * Apply an optionally temperature-scaled softmax along dim 1 of a tensor.
+ *
+ * Without softmax_units, softmax is taken over the whole of dim 1. With
+ * softmax_units, dim 1 is split into consecutive segments of the given
+ * lengths, and softmax is applied to each segment independently. Columns
+ * past the last segment are passed through (temperature-scaled but not
+ * softmaxed) and a warning is logged.
+ *
+ * @param input_tensor        tensor with at least 2 dims; softmax runs over dim 1
+ * @param temperature         divisor applied first, only when > 0 and != 1
+ * @param softmax_units       lengths of consecutive segments along dim 1, or NULL
+ * @param softmax_units_count number of entries in softmax_units
+ * @param ctx                 logging context
+ * @return the softmaxed tensor, or input_tensor unchanged on any error
+ */
+static torch::Tensor apply_softmax(torch::Tensor input_tensor, float temperature, const int *softmax_units, int softmax_units_count, DnnContext *ctx)
+{
+    try {
+        // Reject empty tensors and tensors that have no dim 1 to softmax over.
+        if (input_tensor.numel() == 0 || input_tensor.dim() < 2) {
+            av_log(ctx, AV_LOG_ERROR, "Invalid input tensor for softmax\n");
+            return input_tensor;
+        }
+
+        // Temperature scaling; skipped when it would be a no-op or invalid.
+        torch::Tensor scaled_tensor = (temperature > 0.0f && temperature != 1.0f)
+                                      ? input_tensor / temperature
+                                      : input_tensor;
+        const int64_t total = scaled_tensor.size(1);
+
+        // No segmentation requested: softmax over the full dim 1.
+        if (!softmax_units || softmax_units_count <= 0)
+            return torch::nn::functional::softmax(scaled_tensor, torch::nn::functional::SoftmaxFuncOptions(1));
+
+        torch::Tensor result = torch::empty_like(scaled_tensor);
+        // 64-bit so that offset + length cannot overflow for large unit sizes.
+        int64_t offset = 0;
+
+        for (int i = 0; i < softmax_units_count; i++) {
+            const int64_t length = softmax_units[i];
+            if (length <= 0 || offset + length > total) {
+                av_log(ctx, AV_LOG_ERROR, "Invalid softmax units were given to softmax. Index invalid or out of bounds.\n");
+                return input_tensor;
+            }
+
+            // slice() is a view, so copy_() writes straight into result.
+            result.slice(1, offset, offset + length).copy_(
+                torch::nn::functional::softmax(scaled_tensor.slice(1, offset, offset + length),
+                                               torch::nn::functional::SoftmaxFuncOptions(1)));
+            offset += length;
+        }
+
+        // Columns not covered by softmax_units are passed through unmodified;
+        // this is not fatal, so it is reported as a warning.
+        if (offset < total) {
+            result.slice(1, offset, total).copy_(scaled_tensor.slice(1, offset, total));
+            av_log(ctx, AV_LOG_WARNING, "Some tensor elements (%lld to %lld) were not processed by softmax\n",
+                   static_cast<long long>(offset), static_cast<long long>(total - 1));
+        }
+
+        return result;
+    } catch (const std::exception &e) {
+        // c10::Error derives from std::exception, so libtorch errors land here too.
+        av_log(ctx, AV_LOG_ERROR, "Error applying softmax: %s\n", e.what());
+        return input_tensor; // Return original tensor on error
+    }
+}

 static int fill_model_input_th(THModel *th_model, THRequestItem *request)
 {
--
2.34.1


_______________________________________________
ffmpeg-devel mailing list
ffmpeg-devel@ffmpeg.org
https://ffmpeg.org/mailman/listinfo/ffmpeg-devel

To unsubscribe, visit link above, or email
ffmpeg-devel-requ...@ffmpeg.org with subject "unsubscribe".

Reply via email to