This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2ca6ec8a5d [Relax][PyTorch] Sort.default (#17852)
2ca6ec8a5d is described below

commit 2ca6ec8a5d1cd22dbf428b3b5cd9f899d058ea3c
Author: Hugo Latendresse <[email protected]>
AuthorDate: Mon Apr 21 15:17:16 2025 -0400

    [Relax][PyTorch] Sort.default (#17852)
    
    Add support for sort.default in the exported program translator.
    There was an existing _sort() function in base_fx_graph_translator.py,
    but it returned only the sorted values. PyTorch's sort returns a tuple of
    (values, indices), so _sort() now returns both.
---
 .../frontend/torch/base_fx_graph_translator.py     |  7 +++++-
 .../frontend/torch/exported_program_translator.py  |  1 +
 tests/python/relax/test_from_exported_to_cuda.py   | 25 ++++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 17 +++++++++++----
 4 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 13d13ff24c..20556167c1 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1278,10 +1278,15 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         return self.block_builder.emit(relax.op.scatter_elements(x, index, src, axis=dim))
 
     def _sort(self, node: fx.Node) -> relax.Var:
+        # torch.sort() returns a tuple of values and indices
+        # we use argsort to get indices and gather_elements to get values
         x = self.env[node.args[0]]
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
         descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False)
-        return self.block_builder.emit(relax.op.sort(x, dim, descending))
+
+        indices = self.block_builder.emit(relax.op.argsort(x, dim, descending))
+        values = self.block_builder.emit(relax.op.gather_elements(x, indices, axis=dim))
+        return self.block_builder.emit(relax.Tuple([values, indices]))
 
     def _split(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index ed6740a25e..f38f353a9e 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -431,6 +431,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "roll.default": self._roll,
             "select.int": self._select,
             "slice.Tensor": self._slice,
+            "sort.default": self._sort,
             "split.Tensor": self._split,
             "split_with_sizes.default": self._split,
             "squeeze.default": self._squeeze,
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 76a4bb2039..6bb35b50b1 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -208,6 +208,31 @@ def test_ones(target, dev):
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_sort(target, dev):
+    raw_data = np.array([[4, 1, 13], [-30, 1, 3], [4, 0, 10]]).astype("float32")
+
+    # Test values
+    class SortModelValues(nn.Module):
+        def forward(self, x):
+            A, _ = torch.sort(x, dim=0, descending=True)
+            B, _ = torch.sort(x, dim=1, descending=False)
+            return A + B
+
+    torch_module = SortModelValues().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+    # Test indices
+    class SortModelIndices(nn.Module):
+        def forward(self, x):
+            _, A = torch.sort(x, dim=0, descending=True)
+            _, B = torch.sort(x, dim=1, descending=False)
+            return A + B
+
+    torch_module = SortModelIndices().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_tensor_clamp(target, dev):
     class ClampBothTensor(torch.nn.Module):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index e8db6af347..2d27fa1f59 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4749,11 +4749,20 @@ def test_sort():
     class Expected:
         @R.function
         def main(
-            inp_0: R.Tensor((5, 3), dtype="float32"),
-        ) -> R.Tensor((5, 3), dtype="float32"):
+            inp_0: R.Tensor((5, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")):
             with R.dataflow():
-                lv: R.Tensor((5, 3), dtype="float32") = R.sort(inp_0, axis=1, descending=True)
-                gv: R.Tensor((5, 3), dtype="float32") = lv
+                lv: R.Tensor((5, 3), dtype="int32") = R.argsort(
+                    inp_0, axis=1, descending=True, dtype="int32"
+                )
+                lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(inp_0, lv, axis=1)
+                lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")) = (
+                    lv1,
+                    lv,
+                )
+                gv: R.Tuple(
+                    R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")
+                ) = lv2
                 R.output(gv)
             return gv
 

Reply via email to