This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ec39199edb [PyTorch] [Relay] Add l1 and mse loss function for pytorch frontend (#11978)
ec39199edb is described below
commit ec39199edb72dfe93747249d6a060c1832a8e38f
Author: Yuanjing Shi <[email protected]>
AuthorDate: Thu Jun 30 17:07:43 2022 -0700
[PyTorch] [Relay] Add l1 and mse loss function for pytorch frontend (#11978)
* add l1 and mse loss function for pytorch frontend
* fix CI
---
python/tvm/relay/frontend/pytorch.py | 33 +++++++++++++++++++++++-
python/tvm/topi/nn/softmax.py | 4 +--
tests/python/frontend/pytorch/test_forward.py | 36 +++++++++++++++++++++++++++
3 files changed, 70 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 6fe8c89e3c..123b029983 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -932,6 +932,35 @@ class PyTorchOpConverter:
assert weights is None, "weight not supported in cross_entropy_loss"
return _op.nn.cross_entropy_with_logits(_op.nn.log_softmax(input), target)
+ def l1_loss(self, inputs, input_types):
+ assert len(inputs) == 3
+ [predictions, targets, reduction] = inputs
+ delta = _op.abs(_op.subtract(predictions, targets))
+ if reduction == 0:
+ # reduction = "none"
+ return delta
+ elif reduction == 1:
+ # reduction = "mean"
+ return _op.mean(delta)
+ else:
+ # reduction = "sum"
+ return _op.sum(delta)
+
+ def mse_loss(self, inputs, input_types):
+ assert len(inputs) == 3
+ [predictions, targets, reduction] = inputs
+ delta = _op.subtract(predictions, targets)
+ delta = _op.power(delta, _expr.const(2, input_types[0]))
+ if reduction == 0:
+ # reduction = "none"
+ return delta
+ elif reduction == 1:
+ # reduction = "mean"
+ return _op.mean(delta)
+ else:
+ # reduction = "sum"
+ return _op.sum(delta)
+
def hard_sigmoid(self, inputs, input_types):
def _relu6(x):
return _op.tensor.clip(x, 0.0, 6.0)
@@ -3200,7 +3229,6 @@ class PyTorchOpConverter:
"aten::silu": self.silu,
"aten::glu": self.glu,
"aten::log_sigmoid": self.log_sigmoid,
- "aten::cross_entropy_loss": self.cross_entropy_loss_with_logits,
"aten::adaptive_avg_pool1d": functools.partial(
self.adaptive_avg_pool, _op.nn.adaptive_avg_pool1d
),
@@ -3374,6 +3402,9 @@ class PyTorchOpConverter:
"aten::nll_loss": self.nll_loss,
"aten::nll_loss2d": self.nll_loss,
"aten::nll_loss_nd": self.nll_loss,
+ "aten::cross_entropy_loss": self.cross_entropy_loss_with_logits,
+ "aten::l1_loss": self.l1_loss,
+ "aten::mse_loss": self.mse_loss,
"aten::flip": self.flip,
"aten::gru": self.gru,
"aten::lstm": self.lstm,
diff --git a/python/tvm/topi/nn/softmax.py b/python/tvm/topi/nn/softmax.py
index 2d6921b26d..83a4995744 100644
--- a/python/tvm/topi/nn/softmax.py
+++ b/python/tvm/topi/nn/softmax.py
@@ -129,12 +129,12 @@ def log_softmax(x, axis=-1):
Parameters
----------
data : tvm.te.Tensor
- 2-D input data
+ N-D input data
Returns
-------
output : tvm.te.Tensor
- 2-D output with same shape
+ N-D output with same shape
"""
shape = x.shape
if axis < 0:
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index d411d9c874..4f42c183b6 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4177,6 +4177,42 @@ def test_cross_entropy_loss():
verify_model(torch.nn.CrossEntropyLoss().eval(), input_data=[predictions, targets])
+def test_forward_l1_loss():
+ torch.set_grad_enabled(False)
+ N, C = 10, 3
+ predictions = torch.rand((N, C)).float()
+ targets = torch.rand((N, C)).float()
+ verify_model(torch.nn.L1Loss().eval(), input_data=[predictions, targets])
+ verify_model(torch.nn.L1Loss(reduction="sum").eval(), input_data=[predictions, targets])
+ verify_model(torch.nn.L1Loss(reduction="none").eval(), input_data=[predictions, targets])
+
+ # multidimension l1 loss
+ d1, d2 = 2, 3
+ predictions = torch.rand((N, C, d1, d2)).float()
+ targets = torch.rand((N, C, d1, d2)).float()
+ verify_model(torch.nn.L1Loss().eval(), input_data=[predictions, targets])
+ verify_model(torch.nn.L1Loss(reduction="sum").eval(), input_data=[predictions, targets])
+ verify_model(torch.nn.L1Loss(reduction="none").eval(), input_data=[predictions, targets])
+
+
+def test_forward_mse_loss():
+ torch.set_grad_enabled(False)
+ N, C = 10, 3
+ predictions = torch.rand((N, C)).float()
+ targets = torch.rand((N, C)).float()
+ verify_model(torch.nn.MSELoss().eval(), input_data=[predictions, targets])
+ verify_model(torch.nn.MSELoss(reduction="sum").eval(), input_data=[predictions, targets])
+ verify_model(torch.nn.MSELoss(reduction="none").eval(), input_data=[predictions, targets])
+
+ # multidimension mse loss
+ d1, d2 = 2, 3
+ predictions = torch.rand((N, C, d1, d2)).float()
+ targets = torch.rand((N, C, d1, d2)).float()
+ verify_model(torch.nn.MSELoss().eval(), input_data=[predictions, targets])
+ verify_model(torch.nn.MSELoss(reduction="sum").eval(), input_data=[predictions, targets])
+ verify_model(torch.nn.MSELoss(reduction="none").eval(), input_data=[predictions, targets])
+
+
@tvm.testing.uses_gpu
def test_forward_flip():
torch.set_grad_enabled(False)