This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5ec33bb558 [TOPI] Support non-batch cases for topi.nll_loss (#14060)
5ec33bb558 is described below

commit 5ec33bb5582de6574e9eb426c84950ec80b0f0c4
Author: Yixin Dong <[email protected]>
AuthorDate: Tue Feb 21 17:35:35 2023 +0800

    [TOPI] Support non-batch cases for topi.nll_loss (#14060)
    
    This PR supports the case when the input does not contain batches for 
`topi.nll_loss`.
    
    When there are no batches, the shape of the prediction parameter is `(C,)`,  
the shape of the target parameter is `()`, the shape of the weight parameter is 
`(C,)`, and the shape of the output is always `()` no matter which reduction 
method it uses.
---
 include/tvm/topi/nn.h                      | 28 ++++++++++++++++++++++++++++
 tests/python/topi/python/test_topi_loss.py | 11 +++++++++--
 2 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h
index 90c1c09a07..27c1043dde 100644
--- a/include/tvm/topi/nn.h
+++ b/include/tvm/topi/nn.h
@@ -660,6 +660,32 @@ inline tvm::te::Tensor batch_to_space_nd(const 
tvm::te::Tensor& data,
 inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const 
Tensor& weights,
                        std::string reduction = "mean", int ignore_index = -100,
                        const std::string name = "nll_loss", const std::string 
tag = kBroadcast) {
+  if (predictions.ndim() == 1) {
+    // corner case: no batch in shape
+    // prediction->shape = (C,), targets->shape = (), weights->shape = (C,)
+    auto T = tvm::te::compute(
+        {},
+        [&](const tvm::Array<tvm::tir::Var>& target_indices) {
+          auto c = targets();
+          return tvm::tir::Select(c != ignore_index, -predictions(c) * 
weights(c),
+                                  tvm::tir::make_const(predictions->dtype, 0));
+        },
+        name, tag);
+    if (reduction == "mean") {
+      auto W = tvm::te::compute(
+          {},
+          [&](const tvm::Array<tvm::tir::Var>& target_indices) {
+            auto c = targets();
+            return tvm::tir::Select(c != ignore_index, weights(c),
+                                    tvm::tir::make_const(predictions->dtype, 
0));
+          },
+          name, tag);
+      return topi::divide(T, W);
+    } else {
+      return T;
+    }
+  }
+
   auto T = tvm::te::compute(
       targets->shape,
       [&](const tvm::Array<tvm::tir::Var>& target_indices) {
@@ -674,6 +700,7 @@ inline Tensor nll_loss(const Tensor& predictions, const 
Tensor& targets, const T
                                 tvm::tir::make_const(predictions->dtype, 0));
       },
       name, tag);
+  ICHECK(T->shape.size() != 0);
   if (reduction == "mean") {
     auto W = tvm::te::compute(
         targets->shape,
@@ -690,6 +717,7 @@ inline Tensor nll_loss(const Tensor& predictions, const 
Tensor& targets, const T
     return T;
   }
 }
+
 }  // namespace topi
 }  // namespace tvm
 #endif  // TVM_TOPI_NN_H_
diff --git a/tests/python/topi/python/test_topi_loss.py 
b/tests/python/topi/python/test_topi_loss.py
index 53960139dd..969beb7d28 100644
--- a/tests/python/topi/python/test_topi_loss.py
+++ b/tests/python/topi/python/test_topi_loss.py
@@ -32,12 +32,19 @@ prediction_shape, reduction, ignore_index, dtype = 
tvm.testing.parameters(
     ((10, 5), "none", -100, "float32"),
     ((10, 5), "mean", 3, "float32"),
     ((10, 5), "mean", -100, "float64"),
+    ((5,), "mean", -100, "float32"),
+    ((5,), "mean", 3, "float32"),
+    ((5,), "none", -100, "float32"),
 )
 
 
 def test_nll_loss(target, dev, prediction_shape, reduction, ignore_index, 
dtype):
-    C = prediction_shape[1]
-    target_shape = prediction_shape[:1] + prediction_shape[2:]
+    if len(prediction_shape) == 1:
+        C = prediction_shape[0]
+        target_shape = []
+    else:
+        C = prediction_shape[1]
+        target_shape = prediction_shape[:1] + prediction_shape[2:]
     predictions = te.placeholder(shape=prediction_shape, name="predictions", 
dtype=dtype)
     targets = te.placeholder(shape=target_shape, name="targets", dtype="int32")
     weights = te.placeholder(shape=(C,), name="weights", dtype=dtype)

Reply via email to