This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e812a219f6 [Unity][FX] Add support for PT2.0 scaled_dot_product_attention (#14841)
e812a219f6 is described below

commit e812a219f6d0bc89a521c12402384fccce3e8196
Author: masahi <[email protected]>
AuthorDate: Tue May 16 11:07:08 2023 +0900

    [Unity][FX] Add support for PT2.0 scaled_dot_product_attention (#14841)
    
    * add converter for PT 2.0 scaled_dot_product_attention
    
    * remove requires_gpu in FX test
    
    * add test
    
    * more black
    
    * support float mask for attn_mask input
    
    * remove local import
    
    * more clean
---
 python/tvm/relax/frontend/torch/fx_translator.py |  18 +
 tests/python/relax/test_frontend_from_fx.py      | 522 +++++------------------
 2 files changed, 124 insertions(+), 416 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 99bc8f73f3..a29070a325 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1002,6 +1002,23 @@ class TorchFXImporter:
             )
         )
 
+    def _scaled_dot_product_attention(self, node: fx.node.Node) -> relax.Var:
+        assert len(node.args) <= 4, "Dropout, and causal masking are not supported."
+        transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
+        query = transpose_S_H(self.env[node.args[0]])
+        key = transpose_S_H(self.env[node.args[1]])
+        value = transpose_S_H(self.env[node.args[2]])
+
+        if len(node.args) == 4:
+            mask = self.env[node.args[3]]
+            msg = "Only a float mask is supported for the attn_mask input."
+            assert "float" in mask.struct_info.dtype, msg
+            attn = relax.op.nn.attention(query, key, value, bias=mask)
+        else:
+            attn = relax.op.nn.attention(query, key, value)
+
+        return self.block_builder.emit(attn)
+
     ########## Others ##########
 
     def _size(self, node: fx.node.Node) -> relax.Expr:
@@ -1185,6 +1202,7 @@ class TorchFXImporter:
             "neg": self._neg,
             "max": self._max,
             "cross_entropy": self._cross_entropy,
+            "scaled_dot_product_attention": self._scaled_dot_product_attention,
         }
 
     def from_fx(
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 6d20abe16d..40b9519386 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -15,6 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 import pytest
+import torch
+import torch.nn.functional as F
+from torch import fx
+from torch.nn import Module
 
 import tvm
 from tvm import relax
@@ -22,28 +26,20 @@ import tvm.testing
 from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tir as T
+from tvm.relax.frontend import detach_params
+from tvm.relax.frontend.torch import from_fx
 
 
 def verify_model(torch_model, input_info, binding, expected):
-    import torch
-    from torch import fx
-    from tvm.relax.frontend.torch import from_fx
-
     graph_model = fx.symbolic_trace(torch_model)
-    mod = from_fx(graph_model, input_info)
+    with torch.no_grad():
+        mod = from_fx(graph_model, input_info)
     binding = {k: tvm.nd.array(v) for k, v in binding.items()}
     expected = relax.transform.BindParams("main", binding)(expected)
     tvm.ir.assert_structural_equal(mod, expected)
 
 
-@tvm.testing.requires_gpu
 def test_conv1d():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     class Conv1D1(Module):
         def __init__(self):
             super().__init__()
@@ -114,22 +110,15 @@ def test_conv1d():
     input_info = [([1, 3, 10], "float32")]
 
     model = Conv1D1()
-    binding = {"w1": model.conv.weight.numpy(), "w2": model.conv.bias.numpy()}
+    binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
     verify_model(model, input_info, binding, expected1)
 
     model = Conv1D2()
-    binding = {"w1": model.conv.weight.numpy()}
+    binding = {"w1": model.conv.weight.detach().numpy()}
     verify_model(model, input_info, binding, expected2)
 
 
-@tvm.testing.requires_gpu
 def test_conv2d():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     class Conv2D1(Module):
         def __init__(self):
             super().__init__()
@@ -200,22 +189,15 @@ def test_conv2d():
     input_info = [([1, 3, 10, 10], "float32")]
 
     model = Conv2D1()
-    binding = {"w1": model.conv.weight.numpy(), "w2": model.conv.bias.numpy()}
+    binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
     verify_model(model, input_info, binding, expected1)
 
     model = Conv2D2()
-    binding = {"w1": model.conv.weight.numpy()}
+    binding = {"w1": model.conv.weight.detach().numpy()}
     verify_model(model, input_info, binding, expected2)
 
 
-@tvm.testing.requires_gpu
 def test_linear():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     # nn.Linear
     class Dense1(Module):
         def __init__(self):
@@ -272,11 +254,11 @@ def test_linear():
     input_info = [([1, 3, 10, 10], "float32")]
 
     model = Dense1()
-    binding = {"w1": model.linear.weight.numpy(), "w2": model.linear.bias.numpy()}
+    binding = {"w1": model.linear.weight.detach().numpy(), "w2": model.linear.bias.detach().numpy()}
     verify_model(model, input_info, binding, expected1)
 
     model = Dense2()
-    binding = {"w1": model.linear.weight.numpy()}
+    binding = {"w1": model.linear.weight.detach().numpy()}
     verify_model(model, input_info, binding, expected2)
 
     # matmul
@@ -311,14 +293,7 @@ def test_linear():
     )
 
 
-@tvm.testing.requires_gpu
 def test_bmm():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     class BMM(Module):
         def __init__(self):
             super().__init__()
@@ -350,14 +325,7 @@ def test_bmm():
     )
 
 
-@tvm.testing.requires_gpu
 def test_baddbmm():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     class BAddBMM1(Module):
         def __init__(self):
             super().__init__()
@@ -419,13 +387,7 @@ def test_baddbmm():
     )
 
 
-@tvm.testing.requires_gpu
 def test_relu():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-
     class ReLU0(Module):
         def __init__(self):
             super().__init__()
@@ -456,13 +418,7 @@ def test_relu():
     verify_model(ReLU1(), input_info, {}, expected)
 
 
-@tvm.testing.requires_gpu
 def test_relu6():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-
     class ReLU6(Module):
         def __init__(self):
             super().__init__()
@@ -486,14 +442,7 @@ def test_relu6():
     verify_model(ReLU6(), input_info, {}, expected)
 
 
-@tvm.testing.requires_gpu
 def test_maxpool2d():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class MaxPool2d(Module):
@@ -588,14 +537,7 @@ def test_maxpool2d():
     verify_model(MaxPool2d3(), input_info, {}, expected3)
 
 
-@tvm.testing.requires_gpu
 def test_avgpool2d():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class AvgPool2d(Module):
@@ -665,14 +607,7 @@ def test_avgpool2d():
     verify_model(AvgPool2d3(), input_info, {}, expected2)
 
 
-@tvm.testing.requires_gpu
 def test_adaptive_avgpool2d():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class AdaptiveAvgPool2d0(Module):
@@ -706,14 +641,7 @@ def test_adaptive_avgpool2d():
     verify_model(AdaptiveAvgPool2d1(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_flatten():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class Flatten(Module):
@@ -743,14 +671,7 @@ def test_flatten():
     verify_model(torch.nn.Flatten(2, -1), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_batchnorm2d():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class BatchNorm2d(Module):
@@ -795,22 +716,15 @@ def test_batchnorm2d():
 
     model = BatchNorm2d()
     binding = {
-        "w1": model.bn.weight.numpy(),
-        "w2": model.bn.bias.numpy(),
-        "w3": model.bn.running_mean.numpy(),
-        "w4": model.bn.running_var.numpy(),
+        "w1": model.bn.weight.detach().numpy(),
+        "w2": model.bn.bias.detach().numpy(),
+        "w3": model.bn.running_mean.detach().numpy(),
+        "w4": model.bn.running_var.detach().numpy(),
     }
     verify_model(BatchNorm2d(), input_info, binding, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_embedding():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([4], "int64")]
 
     class Embedding(Module):
@@ -836,18 +750,11 @@ def test_embedding():
             return gv
 
     model = Embedding()
-    binding = {"w1": model.embedding.weight.numpy()}
+    binding = {"w1": model.embedding.weight.detach().numpy()}
     verify_model(model, input_info, binding, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_dropout():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class Dropout1(Module):
@@ -878,14 +785,7 @@ def test_dropout():
     verify_model(Dropout2(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_layernorm():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class LayerNorm(Module):
@@ -921,20 +821,13 @@ def test_layernorm():
 
     model = LayerNorm()
     binding = {
-        "w1": model.ln.weight.numpy(),
-        "w2": model.ln.bias.numpy(),
+        "w1": model.ln.weight.detach().numpy(),
+        "w2": model.ln.bias.detach().numpy(),
     }
     verify_model(LayerNorm(), input_info, binding, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_functional_layernorm():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class LayerNorm(Module):
@@ -973,20 +866,13 @@ def test_functional_layernorm():
 
     model = LayerNorm((10, 10))
     binding = {
-        "w1": model.weight.numpy(),
-        "w2": model.bias.numpy(),
+        "w1": model.weight.detach().numpy(),
+        "w2": model.bias.detach().numpy(),
     }
     verify_model(model, input_info, binding, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_cross_entropy():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([3, 2], "float32"), ([3], "int32")]
 
     class CrossEntropy1(Module):
@@ -1067,19 +953,12 @@ def test_cross_entropy():
 
     verify_model(CrossEntropy1(), input_info, {}, expected1)
     model = CrossEntropy2()
-    binding = {"w1": model.loss.weight.numpy()}
+    binding = {"w1": model.loss.weight.detach().numpy()}
     verify_model(model, input_info, binding, expected2)
     verify_model(CrossEntropy3(), input_info, {}, expected3)
 
 
-@tvm.testing.requires_gpu
 def test_functional_cross_entropy():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([3, 10], "float32"), ([3], "int32")]
 
     class CrossEntropy(Module):
@@ -1105,14 +984,7 @@ def test_functional_cross_entropy():
     verify_model(model, input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_silu():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class SiLU(Module):
@@ -1144,7 +1016,6 @@ def test_silu():
     verify_model(SiLU2(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_groupnorm():
     import torch
     from torch.nn import Module
@@ -1188,20 +1059,13 @@ def test_groupnorm():
 
     model = GroupNorm()
     binding = {
-        "w1": model.gn.weight.numpy(),
-        "w2": model.gn.bias.numpy(),
+        "w1": model.gn.weight.detach().numpy(),
+        "w2": model.gn.bias.detach().numpy(),
     }
     verify_model(model, input_info, binding, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_softmax():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class Softmax(Module):
@@ -1228,14 +1092,7 @@ def test_softmax():
     verify_model(Softmax(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_binary():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
     input_info2 = [([1, 3, 10, 10], "float32")]
 
@@ -1513,14 +1370,7 @@ def test_binary():
     verify_model(LT2(), input_info2, {}, expected14)
 
 
-@tvm.testing.requires_gpu
 def test_size():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class Size(Module):
@@ -1540,14 +1390,7 @@ def test_size():
     verify_model(Size(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_squeeze():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([3, 1, 4, 1], "float32")]
 
     class Squeeze1(Module):
@@ -1586,14 +1429,7 @@ def test_squeeze():
     verify_model(Squeeze2(), input_info, {}, Expected2)
 
 
-@tvm.testing.requires_gpu
 def test_unsqueeze():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class Unsqueeze1(Module):
@@ -1634,14 +1470,7 @@ def test_unsqueeze():
     verify_model(Unsqueeze2(), input_info, {}, expected2)
 
 
-@tvm.testing.requires_gpu
 def test_getattr():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class GetAttr1(Module):
@@ -1661,14 +1490,7 @@ def test_getattr():
     verify_model(GetAttr1(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_getitem():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     class Slice1(Module):
         def forward(self, x):
             return x[0, 1::2, :, :3]
@@ -1718,14 +1540,7 @@ def test_getitem():
     verify_model(Slice2(), [([8, 16], "float32")], {}, expected2)
 
 
-@tvm.testing.requires_gpu
 def test_unary():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     # sin
@@ -1849,14 +1664,7 @@ def test_unary():
     verify_model(Round(), input_info, {}, expected5)
 
 
-@tvm.testing.requires_gpu
 def test_gelu():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class Gelu(Module):
@@ -1879,14 +1687,7 @@ def test_gelu():
     verify_model(Gelu(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_tanh():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class Tanh(Module):
@@ -1909,15 +1710,7 @@ def test_tanh():
     verify_model(Tanh(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_clamp():
-    import torch
-    from torch import fx
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class Clamp(Module):
@@ -1964,14 +1757,7 @@ def test_clamp():
         from_fx(gm, input_info)
 
 
-@tvm.testing.requires_gpu
 def test_interpolate():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class Interpolate(Module):
@@ -2006,14 +1792,7 @@ def test_interpolate():
     verify_model(Interpolate(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_addmm():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [
         ([10, 10], "float32"),
         ([10, 10], "float32"),
@@ -2043,14 +1822,7 @@ def test_addmm():
     verify_model(Addmm(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_split():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class Split(Module):
@@ -2085,14 +1857,7 @@ def test_split():
     verify_model(Split(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_cumsum():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 2, 3, 4], "float32")]
 
     class Cumsum(Module):
@@ -2115,14 +1880,7 @@ def test_cumsum():
     verify_model(Cumsum(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_chunk():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 3, 10, 10], "float32")]
 
     class Chunk(Module):
@@ -2157,14 +1915,7 @@ def test_chunk():
     verify_model(Chunk(), input_info, {}, Expected)
 
 
-@tvm.testing.requires_gpu
 def test_inplace_fill():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     class InplaceFill(Module):
         def forward(self, input):
             input.fill_(1.5)
@@ -2185,13 +1936,8 @@ def test_inplace_fill():
     verify_model(InplaceFill(), [([10, 10], "float32")], {}, Expected)
 
 
-@tvm.testing.requires_gpu
 def test_arange():
     import numpy as np
-    import torch
-    from torch import fx
-    from torch.nn import Module
-    from tvm.relax.frontend.torch import from_fx
 
     torch.set_grad_enabled(False)
     torch.random.manual_seed(0)
@@ -2206,20 +1952,12 @@ def test_arange():
     assert len(mod["main"].body.blocks[0].bindings) == 1
     assert isinstance(mod["main"].body.blocks[0].bindings[0].value, relax.Constant)
     tvm.testing.assert_allclose(
-        mod["main"].body.blocks[0].bindings[0].value.data.numpy(), np.arange(0, 20, dtype="int32")
+        mod["main"].body.blocks[0].bindings[0].value.data.numpy(),
+        np.arange(0, 20, dtype="int32"),
     )
 
 
-@tvm.testing.requires_gpu
 def test_empty():
-    import torch
-    from torch import fx
-    from torch.nn import Module
-    from tvm.relax.frontend.torch import from_fx
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     class Empty(Module):
         def forward(self, input):
             return torch.empty((10, 10), dtype=torch.float32)
@@ -2233,16 +1971,7 @@ def test_empty():
     assert mod["main"].body.blocks[0].bindings[0].value.data.dtype == "float32"
 
 
-@tvm.testing.requires_gpu
 def test_tensor():
-    import torch
-    from torch import fx
-    from torch.nn import Module
-    from tvm.relax.frontend.torch import from_fx
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     class Empty1(Module):
         def forward(self, input):
             return torch.tensor(3, dtype=torch.float32)
@@ -2268,14 +1997,7 @@ def test_tensor():
     assert mod2["main"].body.blocks[0].bindings[0].value.data.dtype == "int64"
 
 
-@tvm.testing.requires_gpu
 def test_tril():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([10, 10], "float32")]
 
     class Tril(Module):
@@ -2304,14 +2026,7 @@ def test_tril():
     verify_model(InplaceTril(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_triu():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([10, 10], "float32")]
 
     class Triu(Module):
@@ -2340,14 +2055,7 @@ def test_triu():
     verify_model(InplaceTriu(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_new_ones():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 2, 3], "float32")]
 
     class NewOnes(Module):
@@ -2370,14 +2078,7 @@ def test_new_ones():
     verify_model(NewOnes(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_expand():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 2, 3, 4], "float32")]
 
     class Expand(Module):
@@ -2400,14 +2101,7 @@ def test_expand():
     verify_model(Expand(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_reduce():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 2, 3, 4], "float32")]
 
     # sum
@@ -2431,14 +2125,7 @@ def test_reduce():
     verify_model(Sum(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_datatype():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 2, 3, 4], "float32")]
 
     # float
@@ -2501,14 +2188,7 @@ def test_datatype():
     verify_model(AsType(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_permute():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 2, 3, 4], "float32")]
 
     class Permute(Module):
@@ -2531,14 +2211,7 @@ def test_permute():
     verify_model(Permute(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_reshape():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 2, 3, 4], "float32")]
 
     class Reshape(Module):
@@ -2559,14 +2232,7 @@ def test_reshape():
     verify_model(Reshape(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_transpose():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 2, 3, 4], "float32")]
 
     class Transpose(Module):
@@ -2589,14 +2255,7 @@ def test_transpose():
     verify_model(Transpose(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_view():
-    import torch
-    from torch.nn import Module
-
-    torch.set_grad_enabled(False)
-    torch.random.manual_seed(0)
-
     input_info = [([1, 2, 3, 4], "float32")]
 
     class View(Module):
@@ -2617,14 +2276,7 @@ def test_view():
     verify_model(View(), input_info, {}, expected1)
 
 
-@tvm.testing.requires_gpu
 def test_keep_params():
-    import torch
-    from torch import fx
-    from torch.nn import Module
-    from tvm.relax.frontend import detach_params
-    from tvm.relax.frontend.torch import from_fx
-
     class Conv2D1(Module):
         def __init__(self):
             super().__init__()
@@ -2674,16 +2326,11 @@ def test_keep_params():
         assert tuple(x.value for x in param_var.struct_info.shape.values) == param_ndarray.shape
         assert param_var.struct_info.dtype == param_ndarray.dtype
 
-    tvm.testing.assert_allclose(params[0].numpy(), model.conv.bias.detach().numpy())
-    tvm.testing.assert_allclose(params[1].numpy(), model.conv.weight.detach().numpy())
+    tvm.testing.assert_allclose(params[0].numpy(), model.conv.bias.detach().detach().numpy())
+    tvm.testing.assert_allclose(params[1].numpy(), model.conv.weight.detach().detach().numpy())
 
 
-@tvm.testing.requires_gpu
 def test_unwrap_unit_return_tuple():
-    import torch.fx as fx
-    from torch.nn import Module
-    from tvm.relax.frontend.torch import from_fx
-
     class Identity(Module):
         def __init__(self):
             super().__init__()
@@ -2707,12 +2354,7 @@ def test_unwrap_unit_return_tuple():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
-@tvm.testing.requires_gpu
 def test_no_bind_return_tuple():
-    import torch.fx as fx
-    from torch.nn import Module
-    from tvm.relax.frontend.torch import from_fx
-
     class Identity(Module):
         def __init__(self):
             super().__init__()
@@ -2740,11 +2382,7 @@ def test_no_bind_return_tuple():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
-@tvm.testing.requires_gpu
 def test_argmax():
-    import torch
-    from torch.nn import Module
-
     class Argmax1(Module):
         def __init__(self) -> None:
             super().__init__()
@@ -2783,11 +2421,7 @@ def test_argmax():
     verify_model(Argmax2(), [([256, 256], "float32")], {}, Expected2)
 
 
-@tvm.testing.requires_gpu
 def test_argmin():
-    import torch
-    from torch.nn import Module
-
     class Argmin1(Module):
         def __init__(self) -> None:
             super().__init__()
@@ -2826,11 +2460,7 @@ def test_argmin():
     verify_model(Argmin2(), [([256, 256], "float32")], {}, Expected2)
 
 
-@tvm.testing.requires_gpu
 def test_to():
-    import torch
-    from torch.nn import Module
-
     class To1(Module):
         def forward(self, input):
             return input.to(torch.float16)
@@ -2866,11 +2496,7 @@ def test_to():
     verify_model(To2(), [([256, 256], "float32")], {}, Expected2)
 
 
-@tvm.testing.requires_gpu
 def test_mean():
-    import torch
-    from torch.nn import Module
-
     class Mean(Module):
         def forward(self, input):
             return input.mean(-1)
@@ -2905,11 +2531,7 @@ def test_mean():
     verify_model(MeanKeepDim(), [([256, 256], "float32")], {}, Expected2)
 
 
-@tvm.testing.requires_gpu
 def test_rsqrt():
-    import torch
-    from torch.nn import Module
-
     class Rsqrt(Module):
         def forward(self, input):
             return torch.rsqrt(input)
@@ -2929,11 +2551,7 @@ def test_rsqrt():
     verify_model(Rsqrt(), [([256, 256], "float32")], {}, Expected1)
 
 
-@tvm.testing.requires_gpu
 def test_neg():
-    import torch
-    from torch.nn import Module
-
     class Neg(Module):
         def forward(self, input):
             return -input
@@ -2953,11 +2571,7 @@ def test_neg():
     verify_model(Neg(), [([256, 256], "float32")], {}, Expected1)
 
 
-@tvm.testing.requires_gpu
 def test_max():
-    import torch
-    from torch.nn import Module
-
     class Max(Module):
         def forward(self, x, y):
             return torch.max(x, y)
@@ -2978,5 +2592,81 @@ def test_max():
     verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1)
 
 
+def test_attention():
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
+        ) -> R.Tensor((32, 128, 8, 64), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_0, axes=[0, 2, 1, 3]
+                )
+                lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_1, axes=[0, 2, 1, 3]
+                )
+                lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_2, axes=[0, 2, 1, 3]
+                )
+                lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
+                    lv, lv1, lv2, scale=None
+                )
+                gv: R.Tensor((32, 128, 8, 64), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            inp_3: R.Tensor((32, 8, 128, 128), dtype="float32"),
+        ) -> R.Tensor((32, 128, 8, 64), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_0, axes=[0, 2, 1, 3]
+                )
+                lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_1, axes=[0, 2, 1, 3]
+                )
+                lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_2, axes=[0, 2, 1, 3]
+                )
+                lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
+                    lv, lv1, lv2, inp_3, scale=None
+                )
+                gv: R.Tensor((32, 128, 8, 64), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    verify_model(
+        lambda q, k, v: F.scaled_dot_product_attention(q, k, v),
+        [
+            ([32, 8, 128, 64], "float32"),
+            ([32, 8, 128, 64], "float32"),
+            ([32, 8, 128, 64], "float32"),
+        ],
+        {},
+        Expected1,
+    )
+
+    verify_model(
+        lambda q, k, v, mask: F.scaled_dot_product_attention(q, k, v, mask),
+        [
+            ([32, 8, 128, 64], "float32"),
+            ([32, 8, 128, 64], "float32"),
+            ([32, 8, 128, 64], "float32"),
+            ([32, 8, 128, 128], "float32"),
+        ],
+        {},
+        Expected2,
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to