This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e812a219f6 [Unity][FX] Add support for PT2.0 scaled_dot_product_attention (#14841)
e812a219f6 is described below
commit e812a219f6d0bc89a521c12402384fccce3e8196
Author: masahi <[email protected]>
AuthorDate: Tue May 16 11:07:08 2023 +0900
[Unity][FX] Add support for PT2.0 scaled_dot_product_attention (#14841)
* add converter for PT 2.0 scaled_dot_product_attention
* remove requires_gpu in FX test
* add test
* more black
* support float mask for attn_mask input
* remove local import
* more clean
---
python/tvm/relax/frontend/torch/fx_translator.py | 18 +
tests/python/relax/test_frontend_from_fx.py | 522 +++++------------------
2 files changed, 124 insertions(+), 416 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 99bc8f73f3..a29070a325 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1002,6 +1002,23 @@ class TorchFXImporter:
)
)
+ def _scaled_dot_product_attention(self, node: fx.node.Node) -> relax.Var:
+ assert len(node.args) <= 4, "Dropout, and causal masking are not supported."
+ transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
+ query = transpose_S_H(self.env[node.args[0]])
+ key = transpose_S_H(self.env[node.args[1]])
+ value = transpose_S_H(self.env[node.args[2]])
+
+ if len(node.args) == 4:
+ mask = self.env[node.args[3]]
+ msg = "Only a float mask is supported for the attn_mask input."
+ assert "float" in mask.struct_info.dtype, msg
+ attn = relax.op.nn.attention(query, key, value, bias=mask)
+ else:
+ attn = relax.op.nn.attention(query, key, value)
+
+ return self.block_builder.emit(attn)
+
########## Others ##########
def _size(self, node: fx.node.Node) -> relax.Expr:
@@ -1185,6 +1202,7 @@ class TorchFXImporter:
"neg": self._neg,
"max": self._max,
"cross_entropy": self._cross_entropy,
+ "scaled_dot_product_attention": self._scaled_dot_product_attention,
}
def from_fx(
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 6d20abe16d..40b9519386 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -15,6 +15,10 @@
# specific language governing permissions and limitations
# under the License.
import pytest
+import torch
+import torch.nn.functional as F
+from torch import fx
+from torch.nn import Module
import tvm
from tvm import relax
@@ -22,28 +26,20 @@ import tvm.testing
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
+from tvm.relax.frontend import detach_params
+from tvm.relax.frontend.torch import from_fx
def verify_model(torch_model, input_info, binding, expected):
- import torch
- from torch import fx
- from tvm.relax.frontend.torch import from_fx
-
graph_model = fx.symbolic_trace(torch_model)
- mod = from_fx(graph_model, input_info)
+ with torch.no_grad():
+ mod = from_fx(graph_model, input_info)
binding = {k: tvm.nd.array(v) for k, v in binding.items()}
expected = relax.transform.BindParams("main", binding)(expected)
tvm.ir.assert_structural_equal(mod, expected)
-@tvm.testing.requires_gpu
def test_conv1d():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
class Conv1D1(Module):
def __init__(self):
super().__init__()
@@ -114,22 +110,15 @@ def test_conv1d():
input_info = [([1, 3, 10], "float32")]
model = Conv1D1()
- binding = {"w1": model.conv.weight.numpy(), "w2": model.conv.bias.numpy()}
+ binding = {"w1": model.conv.weight.detach().numpy(), "w2":
model.conv.bias.detach().numpy()}
verify_model(model, input_info, binding, expected1)
model = Conv1D2()
- binding = {"w1": model.conv.weight.numpy()}
+ binding = {"w1": model.conv.weight.detach().numpy()}
verify_model(model, input_info, binding, expected2)
-@tvm.testing.requires_gpu
def test_conv2d():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
class Conv2D1(Module):
def __init__(self):
super().__init__()
@@ -200,22 +189,15 @@ def test_conv2d():
input_info = [([1, 3, 10, 10], "float32")]
model = Conv2D1()
- binding = {"w1": model.conv.weight.numpy(), "w2": model.conv.bias.numpy()}
+ binding = {"w1": model.conv.weight.detach().numpy(), "w2":
model.conv.bias.detach().numpy()}
verify_model(model, input_info, binding, expected1)
model = Conv2D2()
- binding = {"w1": model.conv.weight.numpy()}
+ binding = {"w1": model.conv.weight.detach().numpy()}
verify_model(model, input_info, binding, expected2)
-@tvm.testing.requires_gpu
def test_linear():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
# nn.Linear
class Dense1(Module):
def __init__(self):
@@ -272,11 +254,11 @@ def test_linear():
input_info = [([1, 3, 10, 10], "float32")]
model = Dense1()
- binding = {"w1": model.linear.weight.numpy(), "w2":
model.linear.bias.numpy()}
+ binding = {"w1": model.linear.weight.detach().numpy(), "w2":
model.linear.bias.detach().numpy()}
verify_model(model, input_info, binding, expected1)
model = Dense2()
- binding = {"w1": model.linear.weight.numpy()}
+ binding = {"w1": model.linear.weight.detach().numpy()}
verify_model(model, input_info, binding, expected2)
# matmul
@@ -311,14 +293,7 @@ def test_linear():
)
-@tvm.testing.requires_gpu
def test_bmm():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
class BMM(Module):
def __init__(self):
super().__init__()
@@ -350,14 +325,7 @@ def test_bmm():
)
-@tvm.testing.requires_gpu
def test_baddbmm():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
class BAddBMM1(Module):
def __init__(self):
super().__init__()
@@ -419,13 +387,7 @@ def test_baddbmm():
)
-@tvm.testing.requires_gpu
def test_relu():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
-
class ReLU0(Module):
def __init__(self):
super().__init__()
@@ -456,13 +418,7 @@ def test_relu():
verify_model(ReLU1(), input_info, {}, expected)
-@tvm.testing.requires_gpu
def test_relu6():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
-
class ReLU6(Module):
def __init__(self):
super().__init__()
@@ -486,14 +442,7 @@ def test_relu6():
verify_model(ReLU6(), input_info, {}, expected)
-@tvm.testing.requires_gpu
def test_maxpool2d():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class MaxPool2d(Module):
@@ -588,14 +537,7 @@ def test_maxpool2d():
verify_model(MaxPool2d3(), input_info, {}, expected3)
-@tvm.testing.requires_gpu
def test_avgpool2d():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class AvgPool2d(Module):
@@ -665,14 +607,7 @@ def test_avgpool2d():
verify_model(AvgPool2d3(), input_info, {}, expected2)
-@tvm.testing.requires_gpu
def test_adaptive_avgpool2d():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class AdaptiveAvgPool2d0(Module):
@@ -706,14 +641,7 @@ def test_adaptive_avgpool2d():
verify_model(AdaptiveAvgPool2d1(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_flatten():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class Flatten(Module):
@@ -743,14 +671,7 @@ def test_flatten():
verify_model(torch.nn.Flatten(2, -1), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_batchnorm2d():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class BatchNorm2d(Module):
@@ -795,22 +716,15 @@ def test_batchnorm2d():
model = BatchNorm2d()
binding = {
- "w1": model.bn.weight.numpy(),
- "w2": model.bn.bias.numpy(),
- "w3": model.bn.running_mean.numpy(),
- "w4": model.bn.running_var.numpy(),
+ "w1": model.bn.weight.detach().numpy(),
+ "w2": model.bn.bias.detach().numpy(),
+ "w3": model.bn.running_mean.detach().numpy(),
+ "w4": model.bn.running_var.detach().numpy(),
}
verify_model(BatchNorm2d(), input_info, binding, expected1)
-@tvm.testing.requires_gpu
def test_embedding():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([4], "int64")]
class Embedding(Module):
@@ -836,18 +750,11 @@ def test_embedding():
return gv
model = Embedding()
- binding = {"w1": model.embedding.weight.numpy()}
+ binding = {"w1": model.embedding.weight.detach().numpy()}
verify_model(model, input_info, binding, expected1)
-@tvm.testing.requires_gpu
def test_dropout():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class Dropout1(Module):
@@ -878,14 +785,7 @@ def test_dropout():
verify_model(Dropout2(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_layernorm():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class LayerNorm(Module):
@@ -921,20 +821,13 @@ def test_layernorm():
model = LayerNorm()
binding = {
- "w1": model.ln.weight.numpy(),
- "w2": model.ln.bias.numpy(),
+ "w1": model.ln.weight.detach().numpy(),
+ "w2": model.ln.bias.detach().numpy(),
}
verify_model(LayerNorm(), input_info, binding, expected1)
-@tvm.testing.requires_gpu
def test_functional_layernorm():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class LayerNorm(Module):
@@ -973,20 +866,13 @@ def test_functional_layernorm():
model = LayerNorm((10, 10))
binding = {
- "w1": model.weight.numpy(),
- "w2": model.bias.numpy(),
+ "w1": model.weight.detach().numpy(),
+ "w2": model.bias.detach().numpy(),
}
verify_model(model, input_info, binding, expected1)
-@tvm.testing.requires_gpu
def test_cross_entropy():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([3, 2], "float32"), ([3], "int32")]
class CrossEntropy1(Module):
@@ -1067,19 +953,12 @@ def test_cross_entropy():
verify_model(CrossEntropy1(), input_info, {}, expected1)
model = CrossEntropy2()
- binding = {"w1": model.loss.weight.numpy()}
+ binding = {"w1": model.loss.weight.detach().numpy()}
verify_model(model, input_info, binding, expected2)
verify_model(CrossEntropy3(), input_info, {}, expected3)
-@tvm.testing.requires_gpu
def test_functional_cross_entropy():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([3, 10], "float32"), ([3], "int32")]
class CrossEntropy(Module):
@@ -1105,14 +984,7 @@ def test_functional_cross_entropy():
verify_model(model, input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_silu():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class SiLU(Module):
@@ -1144,7 +1016,6 @@ def test_silu():
verify_model(SiLU2(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_groupnorm():
import torch
from torch.nn import Module
@@ -1188,20 +1059,13 @@ def test_groupnorm():
model = GroupNorm()
binding = {
- "w1": model.gn.weight.numpy(),
- "w2": model.gn.bias.numpy(),
+ "w1": model.gn.weight.detach().numpy(),
+ "w2": model.gn.bias.detach().numpy(),
}
verify_model(model, input_info, binding, expected1)
-@tvm.testing.requires_gpu
def test_softmax():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class Softmax(Module):
@@ -1228,14 +1092,7 @@ def test_softmax():
verify_model(Softmax(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_binary():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
input_info2 = [([1, 3, 10, 10], "float32")]
@@ -1513,14 +1370,7 @@ def test_binary():
verify_model(LT2(), input_info2, {}, expected14)
-@tvm.testing.requires_gpu
def test_size():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class Size(Module):
@@ -1540,14 +1390,7 @@ def test_size():
verify_model(Size(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_squeeze():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([3, 1, 4, 1], "float32")]
class Squeeze1(Module):
@@ -1586,14 +1429,7 @@ def test_squeeze():
verify_model(Squeeze2(), input_info, {}, Expected2)
-@tvm.testing.requires_gpu
def test_unsqueeze():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class Unsqueeze1(Module):
@@ -1634,14 +1470,7 @@ def test_unsqueeze():
verify_model(Unsqueeze2(), input_info, {}, expected2)
-@tvm.testing.requires_gpu
def test_getattr():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class GetAttr1(Module):
@@ -1661,14 +1490,7 @@ def test_getattr():
verify_model(GetAttr1(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_getitem():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
class Slice1(Module):
def forward(self, x):
return x[0, 1::2, :, :3]
@@ -1718,14 +1540,7 @@ def test_getitem():
verify_model(Slice2(), [([8, 16], "float32")], {}, expected2)
-@tvm.testing.requires_gpu
def test_unary():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
# sin
@@ -1849,14 +1664,7 @@ def test_unary():
verify_model(Round(), input_info, {}, expected5)
-@tvm.testing.requires_gpu
def test_gelu():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class Gelu(Module):
@@ -1879,14 +1687,7 @@ def test_gelu():
verify_model(Gelu(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_tanh():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class Tanh(Module):
@@ -1909,15 +1710,7 @@ def test_tanh():
verify_model(Tanh(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_clamp():
- import torch
- from torch import fx
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class Clamp(Module):
@@ -1964,14 +1757,7 @@ def test_clamp():
from_fx(gm, input_info)
-@tvm.testing.requires_gpu
def test_interpolate():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class Interpolate(Module):
@@ -2006,14 +1792,7 @@ def test_interpolate():
verify_model(Interpolate(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_addmm():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [
([10, 10], "float32"),
([10, 10], "float32"),
@@ -2043,14 +1822,7 @@ def test_addmm():
verify_model(Addmm(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_split():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class Split(Module):
@@ -2085,14 +1857,7 @@ def test_split():
verify_model(Split(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_cumsum():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 2, 3, 4], "float32")]
class Cumsum(Module):
@@ -2115,14 +1880,7 @@ def test_cumsum():
verify_model(Cumsum(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_chunk():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 3, 10, 10], "float32")]
class Chunk(Module):
@@ -2157,14 +1915,7 @@ def test_chunk():
verify_model(Chunk(), input_info, {}, Expected)
-@tvm.testing.requires_gpu
def test_inplace_fill():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
class InplaceFill(Module):
def forward(self, input):
input.fill_(1.5)
@@ -2185,13 +1936,8 @@ def test_inplace_fill():
verify_model(InplaceFill(), [([10, 10], "float32")], {}, Expected)
-@tvm.testing.requires_gpu
def test_arange():
import numpy as np
- import torch
- from torch import fx
- from torch.nn import Module
- from tvm.relax.frontend.torch import from_fx
torch.set_grad_enabled(False)
torch.random.manual_seed(0)
@@ -2206,20 +1952,12 @@ def test_arange():
assert len(mod["main"].body.blocks[0].bindings) == 1
assert isinstance(mod["main"].body.blocks[0].bindings[0].value,
relax.Constant)
tvm.testing.assert_allclose(
- mod["main"].body.blocks[0].bindings[0].value.data.numpy(),
np.arange(0, 20, dtype="int32")
+ mod["main"].body.blocks[0].bindings[0].value.data.numpy(),
+ np.arange(0, 20, dtype="int32"),
)
-@tvm.testing.requires_gpu
def test_empty():
- import torch
- from torch import fx
- from torch.nn import Module
- from tvm.relax.frontend.torch import from_fx
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
class Empty(Module):
def forward(self, input):
return torch.empty((10, 10), dtype=torch.float32)
@@ -2233,16 +1971,7 @@ def test_empty():
assert mod["main"].body.blocks[0].bindings[0].value.data.dtype == "float32"
-@tvm.testing.requires_gpu
def test_tensor():
- import torch
- from torch import fx
- from torch.nn import Module
- from tvm.relax.frontend.torch import from_fx
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
class Empty1(Module):
def forward(self, input):
return torch.tensor(3, dtype=torch.float32)
@@ -2268,14 +1997,7 @@ def test_tensor():
assert mod2["main"].body.blocks[0].bindings[0].value.data.dtype == "int64"
-@tvm.testing.requires_gpu
def test_tril():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([10, 10], "float32")]
class Tril(Module):
@@ -2304,14 +2026,7 @@ def test_tril():
verify_model(InplaceTril(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_triu():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([10, 10], "float32")]
class Triu(Module):
@@ -2340,14 +2055,7 @@ def test_triu():
verify_model(InplaceTriu(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_new_ones():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 2, 3], "float32")]
class NewOnes(Module):
@@ -2370,14 +2078,7 @@ def test_new_ones():
verify_model(NewOnes(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_expand():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 2, 3, 4], "float32")]
class Expand(Module):
@@ -2400,14 +2101,7 @@ def test_expand():
verify_model(Expand(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_reduce():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 2, 3, 4], "float32")]
# sum
@@ -2431,14 +2125,7 @@ def test_reduce():
verify_model(Sum(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_datatype():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 2, 3, 4], "float32")]
# float
@@ -2501,14 +2188,7 @@ def test_datatype():
verify_model(AsType(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_permute():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 2, 3, 4], "float32")]
class Permute(Module):
@@ -2531,14 +2211,7 @@ def test_permute():
verify_model(Permute(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_reshape():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 2, 3, 4], "float32")]
class Reshape(Module):
@@ -2559,14 +2232,7 @@ def test_reshape():
verify_model(Reshape(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_transpose():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 2, 3, 4], "float32")]
class Transpose(Module):
@@ -2589,14 +2255,7 @@ def test_transpose():
verify_model(Transpose(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_view():
- import torch
- from torch.nn import Module
-
- torch.set_grad_enabled(False)
- torch.random.manual_seed(0)
-
input_info = [([1, 2, 3, 4], "float32")]
class View(Module):
@@ -2617,14 +2276,7 @@ def test_view():
verify_model(View(), input_info, {}, expected1)
-@tvm.testing.requires_gpu
def test_keep_params():
- import torch
- from torch import fx
- from torch.nn import Module
- from tvm.relax.frontend import detach_params
- from tvm.relax.frontend.torch import from_fx
-
class Conv2D1(Module):
def __init__(self):
super().__init__()
@@ -2674,16 +2326,11 @@ def test_keep_params():
assert tuple(x.value for x in param_var.struct_info.shape.values) ==
param_ndarray.shape
assert param_var.struct_info.dtype == param_ndarray.dtype
- tvm.testing.assert_allclose(params[0].numpy(),
model.conv.bias.detach().numpy())
- tvm.testing.assert_allclose(params[1].numpy(),
model.conv.weight.detach().numpy())
+ tvm.testing.assert_allclose(params[0].numpy(),
model.conv.bias.detach().detach().numpy())
+ tvm.testing.assert_allclose(params[1].numpy(),
model.conv.weight.detach().detach().numpy())
-@tvm.testing.requires_gpu
def test_unwrap_unit_return_tuple():
- import torch.fx as fx
- from torch.nn import Module
- from tvm.relax.frontend.torch import from_fx
-
class Identity(Module):
def __init__(self):
super().__init__()
@@ -2707,12 +2354,7 @@ def test_unwrap_unit_return_tuple():
tvm.ir.assert_structural_equal(mod, Expected)
-@tvm.testing.requires_gpu
def test_no_bind_return_tuple():
- import torch.fx as fx
- from torch.nn import Module
- from tvm.relax.frontend.torch import from_fx
-
class Identity(Module):
def __init__(self):
super().__init__()
@@ -2740,11 +2382,7 @@ def test_no_bind_return_tuple():
tvm.ir.assert_structural_equal(mod, Expected)
-@tvm.testing.requires_gpu
def test_argmax():
- import torch
- from torch.nn import Module
-
class Argmax1(Module):
def __init__(self) -> None:
super().__init__()
@@ -2783,11 +2421,7 @@ def test_argmax():
verify_model(Argmax2(), [([256, 256], "float32")], {}, Expected2)
-@tvm.testing.requires_gpu
def test_argmin():
- import torch
- from torch.nn import Module
-
class Argmin1(Module):
def __init__(self) -> None:
super().__init__()
@@ -2826,11 +2460,7 @@ def test_argmin():
verify_model(Argmin2(), [([256, 256], "float32")], {}, Expected2)
-@tvm.testing.requires_gpu
def test_to():
- import torch
- from torch.nn import Module
-
class To1(Module):
def forward(self, input):
return input.to(torch.float16)
@@ -2866,11 +2496,7 @@ def test_to():
verify_model(To2(), [([256, 256], "float32")], {}, Expected2)
-@tvm.testing.requires_gpu
def test_mean():
- import torch
- from torch.nn import Module
-
class Mean(Module):
def forward(self, input):
return input.mean(-1)
@@ -2905,11 +2531,7 @@ def test_mean():
verify_model(MeanKeepDim(), [([256, 256], "float32")], {}, Expected2)
-@tvm.testing.requires_gpu
def test_rsqrt():
- import torch
- from torch.nn import Module
-
class Rsqrt(Module):
def forward(self, input):
return torch.rsqrt(input)
@@ -2929,11 +2551,7 @@ def test_rsqrt():
verify_model(Rsqrt(), [([256, 256], "float32")], {}, Expected1)
-@tvm.testing.requires_gpu
def test_neg():
- import torch
- from torch.nn import Module
-
class Neg(Module):
def forward(self, input):
return -input
@@ -2953,11 +2571,7 @@ def test_neg():
verify_model(Neg(), [([256, 256], "float32")], {}, Expected1)
-@tvm.testing.requires_gpu
def test_max():
- import torch
- from torch.nn import Module
-
class Max(Module):
def forward(self, x, y):
return torch.max(x, y)
@@ -2978,5 +2592,81 @@ def test_max():
verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")],
{}, Expected1)
+def test_attention():
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ ) -> R.Tensor((32, 128, 8, 64), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.permute_dims(
+ inp_0, axes=[0, 2, 1, 3]
+ )
+ lv1: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.permute_dims(
+ inp_1, axes=[0, 2, 1, 3]
+ )
+ lv2: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.permute_dims(
+ inp_2, axes=[0, 2, 1, 3]
+ )
+ lv3: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.nn.attention(
+ lv, lv1, lv2, scale=None
+ )
+ gv: R.Tensor((32, 128, 8, 64), dtype="float32") = lv3
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ inp_3: R.Tensor((32, 8, 128, 128), dtype="float32"),
+ ) -> R.Tensor((32, 128, 8, 64), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.permute_dims(
+ inp_0, axes=[0, 2, 1, 3]
+ )
+ lv1: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.permute_dims(
+ inp_1, axes=[0, 2, 1, 3]
+ )
+ lv2: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.permute_dims(
+ inp_2, axes=[0, 2, 1, 3]
+ )
+ lv3: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.nn.attention(
+ lv, lv1, lv2, inp_3, scale=None
+ )
+ gv: R.Tensor((32, 128, 8, 64), dtype="float32") = lv3
+ R.output(gv)
+ return gv
+
+ verify_model(
+ lambda q, k, v: F.scaled_dot_product_attention(q, k, v),
+ [
+ ([32, 8, 128, 64], "float32"),
+ ([32, 8, 128, 64], "float32"),
+ ([32, 8, 128, 64], "float32"),
+ ],
+ {},
+ Expected1,
+ )
+
+ verify_model(
+ lambda q, k, v, mask: F.scaled_dot_product_attention(q, k, v, mask),
+ [
+ ([32, 8, 128, 64], "float32"),
+ ([32, 8, 128, 64], "float32"),
+ ([32, 8, 128, 64], "float32"),
+ ([32, 8, 128, 128], "float32"),
+ ],
+ {},
+ Expected2,
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()