This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d3e3d43b71 [Relax][PyTorch] CrossEntropyLoss (#17863)
d3e3d43b71 is described below

commit d3e3d43b7139415aa92c42f6c88c418dbffa62ca
Author: Hugo Latendresse <[email protected]>
AuthorDate: Thu May 8 12:04:49 2025 -0400

    [Relax][PyTorch] CrossEntropyLoss (#17863)
    
    * tests for add'l modules
    
    * no sort
    
    * cross entropy test passes
    
    * cleanup
    
    * fix expand
    
    * remove new e2e tests
    
    * remove new e2e tests
    
    * convert e2e test to unit test
    
    * unit test
    
    * restore tests
    
    * move
    
    * add new tests
    
    * add new tests from 17862
    
    * whitespace
    
    * print statements
    
    * all tests pass
    
    * cleanup - all tests still pass
    
    * cleanup. All nightly tests pass
---
 python/tvm/dlight/gpu/general_reduction.py         |  17 ++
 .../frontend/torch/base_fx_graph_translator.py     |  19 ++
 .../frontend/torch/exported_program_translator.py  |  11 +-
 python/tvm/relax/frontend/torch/fx_translator.py   |  17 +-
 .../test_nnapi}/test_from_exported_to_cuda.py      | 328 +++++++++++++++++++++
 .../relax/test_frontend_from_exported_program.py   |  32 ++
 6 files changed, 413 insertions(+), 11 deletions(-)

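For context, the user-visible effect of this commit: a torch.export-ed model that uses nn.CrossEntropyLoss can now be imported into Relax. A minimal sketch of that workflow, adapted from the new nightly test (the module, shapes, and the final show() call are illustrative, not part of the patch):

    import torch
    from torch import nn
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class CrossEntropyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.criterion = nn.CrossEntropyLoss()
            self.target = torch.tensor([0, 1, 2, 1])

        def forward(self, x):
            return self.criterion(x, self.target)

    # Export with (4, 3) logits matching the four class targets above, then import.
    exported_program = export(CrossEntropyModule().eval(), (torch.randn(4, 3),))
    mod = from_exported_program(exported_program)
    mod.show()  # the loss shows up as log_softmax followed by nll_loss
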
diff --git a/python/tvm/dlight/gpu/general_reduction.py b/python/tvm/dlight/gpu/general_reduction.py
index d3979ce0e4..b1564bf61f 100644
--- a/python/tvm/dlight/gpu/general_reduction.py
+++ b/python/tvm/dlight/gpu/general_reduction.py
@@ -61,6 +61,23 @@ class GeneralReduction(GPUScheduleRule):
         # Align the number of block iters of the last block.
         num_last_block_iter = len(block_infos[-1].dom_kind())
         if num_last_block_iter < len(dom_kind):
+            # If the last block is a scalar value, there is nothing left to
+            # tile/parallelise, and  `iters` is an empty tuple.
+            # Add a unit thread loop so the final write happens inside a valid
+            # GPU thread environment.
+            if num_last_block_iter == 0:
+                # Put every block (both the running reductions and the final
+                # scalar write) inside a trivial GPU thread. The very first block
+                # gets a `blockIdx.x` wrapper so that kernels still have a unique
+                # block scope.
+                for i, info in enumerate(block_infos):
+                    loop_rv = sch.add_unit_loop(info.block_rv)
+                    if i == 0:
+                        sch.bind(loop_rv, "blockIdx.x")
+                    else:
+                        sch.bind(loop_rv, "threadIdx.x")
+
+                return sch
 
             def f_layout_mapping(*iters):
                 analyzer = arith.Analyzer()
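
The dlight change above is what allows the scalar loss output (a block left with zero iterators after the mean reduction) to be scheduled for CUDA. The mechanism is tir.Schedule.add_unit_loop plus a thread binding; a standalone toy sketch of the same idea (hypothetical PrimFunc, not taken from the patch):

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def scalar_write(a: T.Buffer((1,), "float32"), b: T.Buffer((1,), "float32")):
        # A single block with no iterators -- the degenerate case handled above.
        with T.block("copy"):
            b[0] = a[0] + T.float32(1)

    sch = tvm.tir.Schedule(scalar_write)
    loop_rv = sch.add_unit_loop(sch.get_block("copy"))
    sch.bind(loop_rv, "threadIdx.x")  # the write now sits inside a GPU thread scope
    print(sch.mod.script())
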
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index f683e62d24..a76cdc3db4 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -837,6 +837,25 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             groups=groups,
         )
 
+    def _cross_entropy_loss(
+        self,
+        preds: relax.Expr,
+        targets: relax.Expr,
+        weights: Optional[relax.Expr],
+        reduction: str,
+        ignore_index: int,
+    ) -> relax.Expr:
+        log_probs = relax.op.nn.log_softmax(preds)
+        return self.block_builder.emit(
+            relax.op.nn.nll_loss(
+                log_probs,
+                targets,
+                weights,
+                reduction,
+                ignore_index,
+            )
+        )
+
     def _einsum(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0600dfa552..9036b2941d 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -66,7 +66,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
     ########## Neural Network ##########
 
-    def _batch_norm(self, node: fx.Node, training) -> relax.Var:
+    def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
         import numpy as np
 
         x = self.env[node.args[0]]
@@ -113,6 +113,14 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         training = False
         return self._batch_norm(node, training)
 
+    def _cross_entropy_default(self, node: fx.Node) -> relax.Expr:
+        preds = self.env[node.args[0]]
+        targets = self.env[node.args[1]]
+        weight = self.env.get(node.args[2], None) if len(node.args) > 2 else None
+        reduction = node.kwargs.get("reduction", "mean")
+        ignore_index = node.kwargs.get("ignore_index", -100)
+        return self._cross_entropy_loss(preds, targets, weight, reduction, ignore_index)
+
     def _group_norm(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         num_groups = node.args[1]
@@ -401,6 +409,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "conv1d.default": self._conv1d,
             "conv2d.default": self._conv2d,
             "conv3d.default": self._conv3d,
+            "cross_entropy_loss.default": self._cross_entropy_default,
             "einsum.default": self._einsum,
             "embedding.default": lambda node: self._embedding_impl(
                 self.env[node.args[1]], self.env[node.args[0]]
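
The new table entry keys the handler on the ATen overload name that torch.export records for cross-entropy calls. A quick standalone way to confirm which op names appear in an exported graph (illustrative module, not part of the patch):

    import torch
    from torch import nn
    from torch.export import export

    class M(nn.Module):
        def forward(self, logits, target):
            return nn.functional.cross_entropy(logits, target)

    ep = export(M(), (torch.randn(4, 3), torch.tensor([0, 1, 2, 1])))
    # cross_entropy_loss.default is expected among the call_function targets
    # on recent PyTorch versions.
    for node in ep.graph.nodes:
        if node.op == "call_function":
            print(node.target)
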
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 0a94c679f5..7e03bd4280 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -308,12 +308,7 @@ class TorchFXImporter(BaseFXGraphImporter):
         weights = self.env.get(node.kwargs["weight"], None)
         reduction = node.kwargs["reduction"]
         ignore_index = node.kwargs["ignore_index"]
-
-        return self.block_builder.emit(
-            relax.op.nn.nll_loss(
-                relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index
-            )
-        )
+        return self._cross_entropy_loss(preds, targets, weights, reduction, ignore_index)
 
     def _cross_entropy_module(self, node: fx.Node) -> relax.Expr:
         preds = self.env[node.args[0]]
@@ -330,10 +325,12 @@ class TorchFXImporter(BaseFXGraphImporter):
         reduction = module.reduction
         ignore_index = module.ignore_index
 
-        return self.block_builder.emit(
-            relax.op.nn.nll_loss(
-                relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index
-            )
+        return self._cross_entropy_loss(
+            preds,
+            targets,
+            weights,
+            reduction,
+            ignore_index,
         )
 
     def _embedding_module(self, node: fx.Node) -> relax.Var:
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py
similarity index 70%
rename from tests/python/relax/test_from_exported_to_cuda.py
rename to tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py
index 6bb35b50b1..3f0964cfa8 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py
@@ -21,6 +21,7 @@ import tvm.testing
 import numpy as np
 import torch
 from torch import nn
+from torch.nn import functional as F
 from torch.export import export
 from tvm.relax.frontend.torch import from_exported_program
 from torch.nn import Softmax, Upsample
@@ -742,5 +743,332 @@ def test_concat(target, dev):
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_leakyrelu_module(target, dev):
+    class LeakyReLUModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.act = nn.LeakyReLU(negative_slope=0.1)
+
+        def forward(self, x):
+            return self.act(x)
+
+    raw_data = np.random.randn(2, 3).astype(np.float32)
+    torch_module = LeakyReLUModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_log_softmax_module(target, dev):
+    class LogSoftmaxModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.logsoftmax = nn.LogSoftmax(dim=1)
+
+        def forward(self, x):
+            return self.logsoftmax(x)
+
+    raw_data = np.random.randn(4, 5).astype(np.float32)
+    torch_module = LogSoftmaxModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_softmax_module(target, dev):
+    class SoftmaxModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.softmax = nn.Softmax(dim=1)
+
+        def forward(self, x):
+            return self.softmax(x)
+
+    raw_data = np.random.randn(4, 5).astype(np.float32)
+    torch_module = SoftmaxModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_adaptive_avg_pool2d_module(target, dev):
+    class AdaptiveAvgPool2dModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = nn.AdaptiveAvgPool2d((1, 1))
+
+        def forward(self, x):
+            return self.pool(x)
+
+    raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32)
+    torch_module = AdaptiveAvgPool2dModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_avg_pool2d_module(target, dev):
+    class AvgPool2dModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = nn.AvgPool2d(kernel_size=2)
+
+        def forward(self, x):
+            return self.pool(x)
+
+    raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32)
+    torch_module = AvgPool2dModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_conv1d_module(target, dev):
+    class Conv1dModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv1d(in_channels=3, out_channels=4, kernel_size=3)
+
+        def forward(self, x):
+            return self.conv(x)
+
+    raw_data = np.random.randn(2, 3, 10).astype(np.float32)
+    torch_module = Conv1dModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_conv2d_module(target, dev):
+    class Conv2dModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3)
+
+        def forward(self, x):
+            return self.conv(x)
+
+    raw_data = np.random.randn(2, 3, 10, 10).astype(np.float32)
+    torch_module = Conv2dModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_conv3d_module(target, dev):
+    class Conv3dModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv3d(in_channels=2, out_channels=3, kernel_size=3)
+
+        def forward(self, x):
+            return self.conv(x)
+
+    raw_data = np.random.randn(1, 2, 8, 8, 8).astype(np.float32)
+    torch_module = Conv3dModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_group_norm_module(target, dev):
+    class GroupNormModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gn = nn.GroupNorm(num_groups=1, num_channels=4)
+
+        def forward(self, x):
+            return self.gn(x)
+
+    raw_data = np.random.randn(2, 4, 8, 8).astype(np.float32)
+    torch_module = GroupNormModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_layer_norm_module(target, dev):
+    class LayerNormModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.ln = nn.LayerNorm(normalized_shape=8)
+
+        def forward(self, x):
+            return self.ln(x)
+
+    raw_data = np.random.randn(2, 4, 8).astype(np.float32)
+    torch_module = LayerNormModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_linear_module(target, dev):
+    class LinearModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = nn.Linear(10, 5)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    raw_data = np.random.randn(4, 10).astype(np.float32)
+    torch_module = LinearModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_max_pool2d_module(target, dev):
+    class MaxPool2dModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = nn.MaxPool2d(kernel_size=2)
+
+        def forward(self, x):
+            return self.pool(x)
+
+    raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32)
+    torch_module = MaxPool2dModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_embedding_module(target, dev):
+    class EmbeddingModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.embed = nn.Embedding(num_embeddings=10, embedding_dim=3)
+
+        def forward(self, x):
+            return self.embed(x)
+
+    raw_data = np.random.randint(0, 10, (2, 4)).astype(np.int64)
+    torch_module = EmbeddingModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_flatten_module(target, dev):
+    class FlattenModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.flatten = nn.Flatten()
+
+        def forward(self, x):
+            return self.flatten(x)
+
+    raw_data = np.random.randn(2, 3, 4, 5).astype(np.float32)
+    torch_module = FlattenModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_numel(target, dev):
+    class NumelModule(nn.Module):
+        def forward(self, x):
+            return torch.tensor(x.numel())
+
+    raw_data = np.random.randn(2, 3, 4).astype(np.float32)
+    torch_module = NumelModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_size(target, dev):
+    class SizeModule(nn.Module):
+        def forward(self, x):
+            return torch.tensor(x.size(0))
+
+    raw_data = np.random.randn(5, 4).astype(np.float32)
+    torch_module = SizeModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_tensor(target, dev):
+    class TensorModule(nn.Module):
+        def forward(self, x):
+            return torch.tensor([1, 2, 3])
+
+    raw_data = np.zeros((1,)).astype(np.float32)
+    torch_module = TensorModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_type(target, dev):
+    class TypeModule(nn.Module):
+        def forward(self, x):
+            return x.type(torch.float16)
+
+    raw_data = np.random.randn(2, 3).astype(np.float32)
+    torch_module = TypeModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_float(target, dev):
+    class FloatModule(nn.Module):
+        def forward(self, x):
+            return x.float()
+
+    raw_data = np.random.randn(2, 3).astype(np.float32)
+    torch_module = FloatModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_half(target, dev):
+    class HalfModule(nn.Module):
+        def forward(self, x):
+            return x.half()
+
+    raw_data = np.random.randn(2, 3).astype(np.float32)
+    torch_module = HalfModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_getattr(target, dev):
+    class GetAttrModule(nn.Module):
+        def forward(self, x):
+            # Use getattr to call the ndimension method.
+            return torch.tensor(getattr(x, "ndimension")())
+
+    raw_data = np.random.randn(2, 3, 4).astype(np.float32)
+    torch_module = GetAttrModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_sym_size_int(target, dev):
+    class SymSizeIntModule(nn.Module):
+        def forward(self, x):
+            return torch.tensor(x.shape[1])
+
+    raw_data = np.random.randn(2, 3, 4).astype(np.float32)
+    torch_module = SymSizeIntModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_interpolate(target, dev):
+    class InterpolateModule(nn.Module):
+        def forward(self, x):
+            # Upsample to a fixed size.
+            return F.interpolate(x, size=(16, 16), mode="nearest")
+
+    raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32)
+    torch_module = InterpolateModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_cross_entropy_module(target, dev):
+    class CrossEntropyModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.criterion = nn.CrossEntropyLoss()
+            self.target = torch.tensor([0, 1, 2, 1])
+
+        def forward(self, x):
+            return self.criterion(x, self.target)
+
+    raw_data = np.random.randn(4, 3).astype(np.float32)
+    torch_module = CrossEntropyModule().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index bcd96369d4..dd0cc7c9b1 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -17,6 +17,7 @@
 import operator
 import pytest
 import torch
+from torch import nn
 from torch.nn import Module
 from torch.export import export
 
@@ -5294,6 +5295,36 @@ def test_eye():
     verify_model(Eye2(), example_args2, {}, Expected2)
 
 
+def test_cross_entropy():
+    class CrossEntropyModule(Module):
+        def __init__(self):
+            super().__init__()
+            self.criterion = nn.CrossEntropyLoss()
+            self.target = torch.tensor([0, 1, 2, 1])
+
+        def forward(self, x):
+            return self.criterion(x, self.target)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1)
+                lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss(
+                    lv,
+                    targets=R.const([0, 1, 2, 1], dtype="int64"),
+                    reduction="mean",
+                    ignore_index=-100,
+                )
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args1 = (torch.randn(4, 3, dtype=torch.float32),)
+    verify_model(CrossEntropyModule(), example_args1, {}, Expected1)
+
+
 def test_linspace():
     class Linspace(Module):
         def forward(self, input):
@@ -5354,3 +5385,4 @@ def test_dtypes(torch_dtype, relax_dtype):
 
 if __name__ == "__main__":
     tvm.testing.main()
+1
