This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5ec33bb558 [TOPI] Support non-batch cases for topi.nll_loss (#14060)
5ec33bb558 is described below
commit 5ec33bb5582de6574e9eb426c84950ec80b0f0c4
Author: Yixin Dong <[email protected]>
AuthorDate: Tue Feb 21 17:35:35 2023 +0800
[TOPI] Support non-batch cases for topi.nll_loss (#14060)
This PR supports the cases when input does not contain batches for
`topi.nll_loss`.
When there are no batches, the shape of the prediction parameter is `(C,)`,
the shape of the target parameter is `()`, the shape of the weight parameter is
`(C,)`, and the shape of the output is always `()` no matter which reduction
method it uses.
---
include/tvm/topi/nn.h | 28 ++++++++++++++++++++++++++++
tests/python/topi/python/test_topi_loss.py | 11 +++++++++--
2 files changed, 37 insertions(+), 2 deletions(-)
diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h
index 90c1c09a07..27c1043dde 100644
--- a/include/tvm/topi/nn.h
+++ b/include/tvm/topi/nn.h
@@ -660,6 +660,32 @@ inline tvm::te::Tensor batch_to_space_nd(const
tvm::te::Tensor& data,
inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const
Tensor& weights,
std::string reduction = "mean", int ignore_index = -100,
const std::string name = "nll_loss", const std::string
tag = kBroadcast) {
+ if (predictions.ndim() == 1) {
+ // corner case: no batch in shape
+ // prediction->shape = (C,), targets->shape = (), weights->shape = (C,)
+ auto T = tvm::te::compute(
+ {},
+ [&](const tvm::Array<tvm::tir::Var>& target_indices) {
+ auto c = targets();
+ return tvm::tir::Select(c != ignore_index, -predictions(c) *
weights(c),
+ tvm::tir::make_const(predictions->dtype, 0));
+ },
+ name, tag);
+ if (reduction == "mean") {
+ auto W = tvm::te::compute(
+ {},
+ [&](const tvm::Array<tvm::tir::Var>& target_indices) {
+ auto c = targets();
+ return tvm::tir::Select(c != ignore_index, weights(c),
+ tvm::tir::make_const(predictions->dtype,
0));
+ },
+ name, tag);
+ return topi::divide(T, W);
+ } else {
+ return T;
+ }
+ }
+
auto T = tvm::te::compute(
targets->shape,
[&](const tvm::Array<tvm::tir::Var>& target_indices) {
@@ -674,6 +700,7 @@ inline Tensor nll_loss(const Tensor& predictions, const
Tensor& targets, const T
tvm::tir::make_const(predictions->dtype, 0));
},
name, tag);
+ ICHECK(T->shape.size() != 0);
if (reduction == "mean") {
auto W = tvm::te::compute(
targets->shape,
@@ -690,6 +717,7 @@ inline Tensor nll_loss(const Tensor& predictions, const
Tensor& targets, const T
return T;
}
}
+
} // namespace topi
} // namespace tvm
#endif // TVM_TOPI_NN_H_
diff --git a/tests/python/topi/python/test_topi_loss.py
b/tests/python/topi/python/test_topi_loss.py
index 53960139dd..969beb7d28 100644
--- a/tests/python/topi/python/test_topi_loss.py
+++ b/tests/python/topi/python/test_topi_loss.py
@@ -32,12 +32,19 @@ prediction_shape, reduction, ignore_index, dtype =
tvm.testing.parameters(
((10, 5), "none", -100, "float32"),
((10, 5), "mean", 3, "float32"),
((10, 5), "mean", -100, "float64"),
+ ((5,), "mean", -100, "float32"),
+ ((5,), "mean", 3, "float32"),
+ ((5,), "none", -100, "float32"),
)
def test_nll_loss(target, dev, prediction_shape, reduction, ignore_index,
dtype):
- C = prediction_shape[1]
- target_shape = prediction_shape[:1] + prediction_shape[2:]
+ if len(prediction_shape) == 1:
+ C = prediction_shape[0]
+ target_shape = []
+ else:
+ C = prediction_shape[1]
+ target_shape = prediction_shape[:1] + prediction_shape[2:]
predictions = te.placeholder(shape=prediction_shape, name="predictions",
dtype=dtype)
targets = te.placeholder(shape=target_shape, name="targets", dtype="int32")
weights = te.placeholder(shape=(C,), name="weights", dtype=dtype)