This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d3e3d43b71 [Relax][PyTorch] CrossEntropyLoss (#17863)
d3e3d43b71 is described below
commit d3e3d43b7139415aa92c42f6c88c418dbffa62ca
Author: Hugo Latendresse <[email protected]>
AuthorDate: Thu May 8 12:04:49 2025 -0400
[Relax][PyTorch] CrossEntropyLoss (#17863)
* tests for add'l modules
* no sort
* cross entropy test passes
* cleanup
* fix expand
* remove new e2e tests
* remove new e2e tests
* convert e2e test to unit test
* unit test
* restore tests
* move
* add new tests
* add new tests from 17862
* whitespace
* print statements
* all tests pass
* cleanup - all tests still pass
* cleanup. All nightly tests pass
---
python/tvm/dlight/gpu/general_reduction.py | 17 ++
.../frontend/torch/base_fx_graph_translator.py | 19 ++
.../frontend/torch/exported_program_translator.py | 11 +-
python/tvm/relax/frontend/torch/fx_translator.py | 17 +-
.../test_nnapi}/test_from_exported_to_cuda.py | 328 +++++++++++++++++++++
.../relax/test_frontend_from_exported_program.py | 32 ++
6 files changed, 413 insertions(+), 11 deletions(-)
diff --git a/python/tvm/dlight/gpu/general_reduction.py b/python/tvm/dlight/gpu/general_reduction.py
index d3979ce0e4..b1564bf61f 100644
--- a/python/tvm/dlight/gpu/general_reduction.py
+++ b/python/tvm/dlight/gpu/general_reduction.py
@@ -61,6 +61,23 @@ class GeneralReduction(GPUScheduleRule):
# Align the number of block iters of the last block.
num_last_block_iter = len(block_infos[-1].dom_kind())
if num_last_block_iter < len(dom_kind):
+ # If the last block is a scalar value, there is nothing left to
+ # tile/parallelise, and `iters` is an empty tuple.
+ # Add a unit thread loop so the final write happens inside a valid
+ # GPU thread environment.
+ if num_last_block_iter == 0:
+ # Put every block (both the running reductions and the final
+ # scalar write) inside a trivial GPU thread. The very first block
+ # gets a `blockIdx.x` wrapper so that kernels still have a unique
+ # block scope.
+ for i, info in enumerate(block_infos):
+ loop_rv = sch.add_unit_loop(info.block_rv)
+ if i == 0:
+ sch.bind(loop_rv, "blockIdx.x")
+ else:
+ sch.bind(loop_rv, "threadIdx.x")
+
+ return sch
def f_layout_mapping(*iters):
analyzer = arith.Analyzer()
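
The hunk above handles the case where the epilogue block is a scalar (zero loop
iters): each block is wrapped in a unit loop and bound to a trivial GPU thread.
A standalone sketch of the same add_unit_loop/bind pattern on a toy PrimFunc
(illustrative only, not part of this patch):

import tvm
from tvm.script import tir as T

@T.prim_func
def scalar_write(a: T.handle, c: T.handle):
    A = T.match_buffer(a, (1,), "float32")
    C = T.match_buffer(c, (1,), "float32")
    with T.block("write"):
        C[0] = A[0]

sch = tvm.tir.Schedule(scalar_write)
# The block has no loops to tile, so create a unit loop and bind it to a
# GPU thread axis, mirroring what GeneralReduction now does for its blocks.
loop_rv = sch.add_unit_loop(sch.get_block("write"))
sch.bind(loop_rv, "threadIdx.x")
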
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index f683e62d24..a76cdc3db4 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -837,6 +837,25 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
groups=groups,
)
+ def _cross_entropy_loss(
+ self,
+ preds: relax.Expr,
+ targets: relax.Expr,
+ weights: Optional[relax.Expr],
+ reduction: str,
+ ignore_index: int,
+ ) -> relax.Expr:
+ log_probs = relax.op.nn.log_softmax(preds)
+ return self.block_builder.emit(
+ relax.op.nn.nll_loss(
+ log_probs,
+ targets,
+ weights,
+ reduction,
+ ignore_index,
+ )
+ )
+
def _einsum(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
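
The new _cross_entropy_loss helper emits the standard decomposition of
cross-entropy into log-softmax followed by negative log-likelihood. A quick
PyTorch-side check of that equivalence (values are illustrative, not from this
patch):

import torch
import torch.nn.functional as F

preds = torch.randn(4, 3)              # (batch, num_classes) logits
targets = torch.tensor([0, 1, 2, 1])   # class indices

# F.cross_entropy is nll_loss(log_softmax(preds)), which is exactly the
# lowering emitted above via relax.op.nn.log_softmax and relax.op.nn.nll_loss.
reference = F.cross_entropy(preds, targets)
decomposed = F.nll_loss(F.log_softmax(preds, dim=1), targets)
assert torch.allclose(reference, decomposed)
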
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0600dfa552..9036b2941d 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -66,7 +66,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
########## Neural Network ##########
- def _batch_norm(self, node: fx.Node, training) -> relax.Var:
+ def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
import numpy as np
x = self.env[node.args[0]]
@@ -113,6 +113,14 @@ class ExportedProgramImporter(BaseFXGraphImporter):
training = False
return self._batch_norm(node, training)
+ def _cross_entropy_default(self, node: fx.Node) -> relax.Expr:
+ preds = self.env[node.args[0]]
+ targets = self.env[node.args[1]]
+ weight = self.env.get(node.args[2], None) if len(node.args) > 2 else None
+ reduction = node.kwargs.get("reduction", "mean")
+ ignore_index = node.kwargs.get("ignore_index", -100)
+ return self._cross_entropy_loss(preds, targets, weight, reduction, ignore_index)
+
def _group_norm(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
num_groups = node.args[1]
@@ -401,6 +409,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"conv1d.default": self._conv1d,
"conv2d.default": self._conv2d,
"conv3d.default": self._conv3d,
+ "cross_entropy_loss.default": self._cross_entropy_default,
"einsum.default": self._einsum,
"embedding.default": lambda node: self._embedding_impl(
self.env[node.args[1]], self.env[node.args[0]]
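
With cross_entropy_loss.default registered, the path is exercised by exporting
a module that uses nn.CrossEntropyLoss and importing the result; a condensed
sketch mirroring the new tests (shapes and values are illustrative):

import torch
from torch import nn
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class CrossEntropyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.target = torch.tensor([0, 1, 2, 1])

    def forward(self, x):
        return self.criterion(x, self.target)

# Export the eager module, then translate it to Relax; the
# cross_entropy_loss.default node lowers to log_softmax + nll_loss.
example_args = (torch.randn(4, 3, dtype=torch.float32),)
exported = export(CrossEntropyModule().eval(), example_args)
mod = from_exported_program(exported)
mod.show()
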
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 0a94c679f5..7e03bd4280 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -308,12 +308,7 @@ class TorchFXImporter(BaseFXGraphImporter):
weights = self.env.get(node.kwargs["weight"], None)
reduction = node.kwargs["reduction"]
ignore_index = node.kwargs["ignore_index"]
-
- return self.block_builder.emit(
- relax.op.nn.nll_loss(
- relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index
- )
- )
+ return self._cross_entropy_loss(preds, targets, weights, reduction, ignore_index)
def _cross_entropy_module(self, node: fx.Node) -> relax.Expr:
preds = self.env[node.args[0]]
@@ -330,10 +325,12 @@ class TorchFXImporter(BaseFXGraphImporter):
reduction = module.reduction
ignore_index = module.ignore_index
- return self.block_builder.emit(
- relax.op.nn.nll_loss(
- relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index
- )
+ return self._cross_entropy_loss(
+ preds,
+ targets,
+ weights,
+ reduction,
+ ignore_index,
)
def _embedding_module(self, node: fx.Node) -> relax.Var:
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py
similarity index 70%
rename from tests/python/relax/test_from_exported_to_cuda.py
rename to tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py
index 6bb35b50b1..3f0964cfa8 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py
@@ -21,6 +21,7 @@ import tvm.testing
import numpy as np
import torch
from torch import nn
+from torch.nn import functional as F
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program
from torch.nn import Softmax, Upsample
@@ -742,5 +743,332 @@ def test_concat(target, dev):
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
[email protected]_targets("cuda")
+def test_leakyrelu_module(target, dev):
+ class LeakyReLUModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.act = nn.LeakyReLU(negative_slope=0.1)
+
+ def forward(self, x):
+ return self.act(x)
+
+ raw_data = np.random.randn(2, 3).astype(np.float32)
+ torch_module = LeakyReLUModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_log_softmax_module(target, dev):
+ class LogSoftmaxModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.logsoftmax = nn.LogSoftmax(dim=1)
+
+ def forward(self, x):
+ return self.logsoftmax(x)
+
+ raw_data = np.random.randn(4, 5).astype(np.float32)
+ torch_module = LogSoftmaxModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_softmax_module(target, dev):
+ class SoftmaxModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.softmax = nn.Softmax(dim=1)
+
+ def forward(self, x):
+ return self.softmax(x)
+
+ raw_data = np.random.randn(4, 5).astype(np.float32)
+ torch_module = SoftmaxModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_adaptive_avg_pool2d_module(target, dev):
+ class AdaptiveAvgPool2dModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = nn.AdaptiveAvgPool2d((1, 1))
+
+ def forward(self, x):
+ return self.pool(x)
+
+ raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32)
+ torch_module = AdaptiveAvgPool2dModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_avg_pool2d_module(target, dev):
+ class AvgPool2dModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = nn.AvgPool2d(kernel_size=2)
+
+ def forward(self, x):
+ return self.pool(x)
+
+ raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32)
+ torch_module = AvgPool2dModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_conv1d_module(target, dev):
+ class Conv1dModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = nn.Conv1d(in_channels=3, out_channels=4, kernel_size=3)
+
+ def forward(self, x):
+ return self.conv(x)
+
+ raw_data = np.random.randn(2, 3, 10).astype(np.float32)
+ torch_module = Conv1dModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_conv2d_module(target, dev):
+ class Conv2dModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3)
+
+ def forward(self, x):
+ return self.conv(x)
+
+ raw_data = np.random.randn(2, 3, 10, 10).astype(np.float32)
+ torch_module = Conv2dModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_conv3d_module(target, dev):
+ class Conv3dModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = nn.Conv3d(in_channels=2, out_channels=3, kernel_size=3)
+
+ def forward(self, x):
+ return self.conv(x)
+
+ raw_data = np.random.randn(1, 2, 8, 8, 8).astype(np.float32)
+ torch_module = Conv3dModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_group_norm_module(target, dev):
+ class GroupNormModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.gn = nn.GroupNorm(num_groups=1, num_channels=4)
+
+ def forward(self, x):
+ return self.gn(x)
+
+ raw_data = np.random.randn(2, 4, 8, 8).astype(np.float32)
+ torch_module = GroupNormModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_layer_norm_module(target, dev):
+ class LayerNormModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.ln = nn.LayerNorm(normalized_shape=8)
+
+ def forward(self, x):
+ return self.ln(x)
+
+ raw_data = np.random.randn(2, 4, 8).astype(np.float32)
+ torch_module = LayerNormModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_linear_module(target, dev):
+ class LinearModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = nn.Linear(10, 5)
+
+ def forward(self, x):
+ return self.linear(x)
+
+ raw_data = np.random.randn(4, 10).astype(np.float32)
+ torch_module = LinearModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_max_pool2d_module(target, dev):
+ class MaxPool2dModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = nn.MaxPool2d(kernel_size=2)
+
+ def forward(self, x):
+ return self.pool(x)
+
+ raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32)
+ torch_module = MaxPool2dModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_embedding_module(target, dev):
+ class EmbeddingModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.embed = nn.Embedding(num_embeddings=10, embedding_dim=3)
+
+ def forward(self, x):
+ return self.embed(x)
+
+ raw_data = np.random.randint(0, 10, (2, 4)).astype(np.int64)
+ torch_module = EmbeddingModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_flatten_module(target, dev):
+ class FlattenModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.flatten = nn.Flatten()
+
+ def forward(self, x):
+ return self.flatten(x)
+
+ raw_data = np.random.randn(2, 3, 4, 5).astype(np.float32)
+ torch_module = FlattenModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_numel(target, dev):
+ class NumelModule(nn.Module):
+ def forward(self, x):
+ return torch.tensor(x.numel())
+
+ raw_data = np.random.randn(2, 3, 4).astype(np.float32)
+ torch_module = NumelModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_size(target, dev):
+ class SizeModule(nn.Module):
+ def forward(self, x):
+ return torch.tensor(x.size(0))
+
+ raw_data = np.random.randn(5, 4).astype(np.float32)
+ torch_module = SizeModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_tensor(target, dev):
+ class TensorModule(nn.Module):
+ def forward(self, x):
+ return torch.tensor([1, 2, 3])
+
+ raw_data = np.zeros((1,)).astype(np.float32)
+ torch_module = TensorModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_type(target, dev):
+ class TypeModule(nn.Module):
+ def forward(self, x):
+ return x.type(torch.float16)
+
+ raw_data = np.random.randn(2, 3).astype(np.float32)
+ torch_module = TypeModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_float(target, dev):
+ class FloatModule(nn.Module):
+ def forward(self, x):
+ return x.float()
+
+ raw_data = np.random.randn(2, 3).astype(np.float32)
+ torch_module = FloatModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_half(target, dev):
+ class HalfModule(nn.Module):
+ def forward(self, x):
+ return x.half()
+
+ raw_data = np.random.randn(2, 3).astype(np.float32)
+ torch_module = HalfModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_getattr(target, dev):
+ class GetAttrModule(nn.Module):
+ def forward(self, x):
+ # Use getattr to call the ndimension method.
+ return torch.tensor(getattr(x, "ndimension")())
+
+ raw_data = np.random.randn(2, 3, 4).astype(np.float32)
+ torch_module = GetAttrModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_sym_size_int(target, dev):
+ class SymSizeIntModule(nn.Module):
+ def forward(self, x):
+ return torch.tensor(x.shape[1])
+
+ raw_data = np.random.randn(2, 3, 4).astype(np.float32)
+ torch_module = SymSizeIntModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_interpolate(target, dev):
+ class InterpolateModule(nn.Module):
+ def forward(self, x):
+ # Upsample to a fixed size.
+ return F.interpolate(x, size=(16, 16), mode="nearest")
+
+ raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32)
+ torch_module = InterpolateModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
[email protected]_targets("cuda")
+def test_cross_entropy_module(target, dev):
+ class CrossEntropyModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.criterion = nn.CrossEntropyLoss()
+ self.target = torch.tensor([0, 1, 2, 1])
+
+ def forward(self, x):
+ return self.criterion(x, self.target)
+
+ raw_data = np.random.randn(4, 3).astype(np.float32)
+ torch_module = CrossEntropyModule().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index bcd96369d4..dd0cc7c9b1 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -17,6 +17,7 @@
import operator
import pytest
import torch
+from torch import nn
from torch.nn import Module
from torch.export import export
@@ -5294,6 +5295,36 @@ def test_eye():
verify_model(Eye2(), example_args2, {}, Expected2)
+def test_cross_entropy():
+ class CrossEntropyModule(Module):
+ def __init__(self):
+ super().__init__()
+ self.criterion = nn.CrossEntropyLoss()
+ self.target = torch.tensor([0, 1, 2, 1])
+
+ def forward(self, x):
+ return self.criterion(x, self.target)
+
+ @tvm.script.ir_module
+ class Expected1:
+ @R.function
+ def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1)
+ lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss(
+ lv,
+ targets=R.const([0, 1, 2, 1], dtype="int64"),
+ reduction="mean",
+ ignore_index=-100,
+ )
+ gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ example_args1 = (torch.randn(4, 3, dtype=torch.float32),)
+ verify_model(CrossEntropyModule(), example_args1, {}, Expected1)
+
+
def test_linspace():
class Linspace(Module):
def forward(self, input):
@@ -5354,3 +5385,4 @@ def test_dtypes(torch_dtype, relax_dtype):
if __name__ == "__main__":
tvm.testing.main()
+1