This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new d5dcabf2b2 [Unity][Frontend][NN] Add diffusers style Attention layer 
(#15609)
d5dcabf2b2 is described below

commit d5dcabf2b2687ef4a03d0becc7d3e043b93a6f90
Author: Josh Fromm <[email protected]>
AuthorDate: Fri Aug 25 22:00:32 2023 -0700

    [Unity][Frontend][NN] Add diffusers style Attention layer (#15609)
    
    This PR adds support for the `Attention` layer in the nn module API. This 
layer mimics the behavior of the [Attention layer used in huggingface 
Diffusers](https://github.com/huggingface/diffusers/blob/80871ac5971fe7e708befa3b553463c4e61b22ab/src/diffusers/models/attention_processor.py#L36).
 Under the hood it uses scaled dot product attention. Notably, there are still 
some missing features. For example, I didn't add support for attention masks 
yet. I also am assuming 3 dimensional inputs [...]
---
 python/tvm/relax/frontend/nn/core.py           |   3 +
 python/tvm/relax/frontend/nn/modules.py        | 129 ++++++++++++++++++++++++-
 python/tvm/relax/frontend/nn/op.py             |  52 +++++++++-
 python/tvm/relax/frontend/nn/spec.py           |  21 ++--
 tests/python/relax/test_frontend_nn_modules.py |  82 ++++++++++++++++
 tests/python/relax/test_frontend_nn_op.py      |  36 +++++++
 tests/python/relax/test_frontend_onnx.py       |  13 ++-
 7 files changed, 318 insertions(+), 18 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/core.py 
b/python/tvm/relax/frontend/nn/core.py
index c282815160..ea3744b6ca 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -486,6 +486,9 @@ class ModuleList(Module):
     def __len__(self):
         return len(self.modules)
 
+    def append(self, module):
+        self.modules.append(module)
+
     def to(self, dtype: Optional[str] = None) -> None:  # pylint: 
disable=invalid-name
         for module in self.modules:
             module.to(dtype=dtype)
diff --git a/python/tvm/relax/frontend/nn/modules.py 
b/python/tvm/relax/frontend/nn/modules.py
index fde18473ee..6df0957398 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=too-many-arguments,invalid-name,protected-access
+# pylint: 
disable=too-many-arguments,invalid-name,protected-access,unused-argument
 """Builtin Modules."""
 from typing import List, Optional, Sequence, Union
 
@@ -24,7 +24,7 @@ from tvm._ffi import register_func
 from tvm.runtime import NDArray
 
 from . import op
-from .core import Effect, Module, Parameter, Tensor, get_default_dtype
+from .core import Effect, Module, Parameter, Tensor, get_default_dtype, 
ModuleList
 
 
 class IOEffect(Effect):
@@ -344,7 +344,7 @@ class GroupNorm(Module):
             self.weight = None
             self.bias = None
 
-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor, channel_axis: int = 1, axes: 
Optional[List[int]] = None):
         """
         Forward method for group norm layer.
 
@@ -352,13 +352,20 @@ class GroupNorm(Module):
         ----------
         x : Tensor
             The input tensor.
+        channel_axis : int
+            Channel axis of the input data.
+        axes : Optional[List[int]]
+            Optional list of axes to compute norm over, if not specified,
+            assumes that the first two axes should be left alone.
 
         Returns
         -------
         ret : Tensor
             The output tensor for the group norm layer.
         """
-        return op.group_norm(x, self.num_groups, self.weight, self.bias, 
self.eps)
+        return op.group_norm(
+            x, self.num_groups, self.weight, self.bias, self.eps, 
channel_axis, axes
+        )
 
 
 class KVCache(Effect):
@@ -621,3 +628,117 @@ class Timesteps(Module):
             flip_sin_to_cos=self.flip_sin_to_cos,
             downscale_freq_shift=self.downscale_freq_shift,
         )
+
+
+class Attention(Module):
+    """
+    A cross attention layer.
+
+    Parameters
+    ----------
+        query_dim : int
+            The number of channels in the query.
+        cross_attention_dim : Optional[int]
+            The number of channels in the encoder_hidden_states.
+            If not given, defaults to `query_dim`.
+        heads : int
+            The number of heads to use for multi-head attention.
+        dim_head : int
+            The number of channels in each head.
+        bias : bool
+            Set to `True` for the query, key, and value linear layers to 
contain a bias parameter.
+        norm_num_groups : Optional[int]
+            When set, group norm is applied to the input using this number of 
groups.
+        out_bias : bool
+            Set to `True` to apply a bias to the output linear layer.
+        scale_qk : bool
+            Whether to apply scaling to query and key tensors.
+    """
+
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        heads: int = 8,
+        dim_head: int = 64,
+        bias: bool = False,
+        norm_num_groups: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
+    ):
+        self.query_dim = query_dim
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim 
else query_dim
+        self.heads = heads
+        self.dim_head = dim_head
+        self.bias = bias
+        self.norm_num_groups = norm_num_groups
+        self.out_bias = out_bias
+        self.scale_qk = scale_qk
+
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+        self.inner_dim = dim_head * heads
+
+        self.to_q = Linear(self.query_dim, self.inner_dim, bias=self.bias)
+        self.to_k = Linear(self.cross_attention_dim, self.inner_dim, 
bias=self.bias)
+        self.to_v = Linear(self.cross_attention_dim, self.inner_dim, 
bias=self.bias)
+
+        if self.norm_num_groups is not None:
+            self.group_norm = GroupNorm(
+                num_channels=self.query_dim, num_groups=self.norm_num_groups, 
affine=True
+            )
+        else:
+            self.group_norm = None
+
+        self.to_out = ModuleList([Linear(self.inner_dim, self.query_dim, 
bias=self.out_bias)])
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        encoder_hidden_states: Optional[Tensor] = None,
+        attention_mask: Optional[Tensor] = None,
+        **cross_attention_kwargs,
+    ):
+        """
+        Forward method for Attention layer.
+
+        Parameters
+        ----------
+        hidden_states : Tensor
+            The input sample tensor.
+        encoder_hidden_states : Optional[Tensor]
+            Previous hidden step hidden states.
+        attention_mask : Optional[Tensor]
+            Mask tensor for attention, currently not supported.
+
+        Returns
+        -------
+        ret : Tensor
+            The output tensor for the attention layer.
+        """
+        # This implementation assumes use of torch 2.0 scaled_dot_product 
attention.
+        assert attention_mask is None, "Attention mask not yet supported."
+
+        if self.group_norm is not None:
+            hidden_states = self.group_norm(hidden_states, channel_axis=2, 
axes=[1])
+
+        query = self.to_q(hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        key = self.to_k(encoder_hidden_states)
+        value = self.to_v(encoder_hidden_states)
+        head_dim = int(self.inner_dim // self.heads)
+
+        query = op.reshape(query, [0, -1, self.heads, head_dim])
+        key = op.reshape(key, [0, -1, self.heads, head_dim])
+        value = op.reshape(value, [0, -1, self.heads, head_dim])
+
+        hidden_states = op.scaled_dot_product_attention(query, key, value, 
is_causal=False)
+
+        # Return to proper shape.
+        hidden_states = op.reshape(hidden_states, (0, -1, self.heads * 
head_dim))
+
+        # Linear projection
+        hidden_states = self.to_out[0](hidden_states)
+
+        return hidden_states
diff --git a/python/tvm/relax/frontend/nn/op.py 
b/python/tvm/relax/frontend/nn/op.py
index b3959cd95f..4ef02797c2 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -759,6 +759,8 @@ def group_norm(
     weight: Optional[Tensor],
     bias: Optional[Tensor],
     eps: float = 1e-5,
+    channel_axis: int = 1,
+    axes: Optional[List[int]] = None,
     name: str = "group_norm",
 ) -> Tensor:
     r"""
@@ -785,6 +787,13 @@ def group_norm(
     epsilon : float
         Small float added to square mean to avoid dividing by zero.
 
+    channel_axis: int
+        The channel axis of the data.
+
+    axes : Optional[List[int]]
+        Which axes to compute the groupnorm over. If None, assumes the first
+        two axes should be ignored.
+
     name : str
         Name hint.
 
@@ -798,9 +807,11 @@ def group_norm(
     if bias is not None:
         bias = bias._expr
     dim = len(x._expr.struct_info.shape)
+    if axes is None:
+        axes = list(range(2, dim))
     return _wrap_nested(
         _op.nn.group_norm(
-            x._expr, weight, bias, num_groups, channel_axis=1, 
axes=list(range(2, dim)), epsilon=eps
+            x._expr, weight, bias, num_groups, channel_axis=channel_axis, 
axes=axes, epsilon=eps
         ),
         name,
     )
@@ -955,6 +966,45 @@ def get_timestep_embedding(
     return _wrap_nested(emb, name)
 
 
+def scaled_dot_product_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    attn_mask: Optional[Tensor] = None,
+    is_causal: Optional[bool] = False,
+    scale: Optional[float] = None,
+    name: str = "scaled_dot_product_attention",
+):
+    """
+    Computes a scaled dot product attention on provided attention
+    query, key, and values. Compliant with the functional torch implementation.
+
+    Parameters
+    ----------
+    query : Tensor
+        Tensor representing current attention lookup.
+    key : Tensor
+        Tensor representing cross attention mapping.
+    value : Tensor
+        Tensor representing embedded attention values.
+    attn_mask : Optional[Tensor]
+        Optional mask for attention, not yet supported.
+    is_causal : Optional[bool]
+        If set, uses a causal attention mask.
+    scale : Optional[float]
+        Optional extra scaling argument applied to attention.
+    name : str
+        Name hint for this function.
+    """
+    assert attn_mask is None, "attn_mask not yet supported."
+    causal_mask = "TopLeft" if is_causal else None
+
+    attn = _op.nn.attention(
+        query._expr, key._expr, value._expr, causal_mask=causal_mask, 
scale=scale
+    )
+    return _wrap_nested(attn, name)
+
+
 def tensor_expr_op(
     tensor_expr_func: Callable,
     name_hint: str,
diff --git a/python/tvm/relax/frontend/nn/spec.py 
b/python/tvm/relax/frontend/nn/spec.py
index 95772f2f94..983e0adc52 100644
--- a/python/tvm/relax/frontend/nn/spec.py
+++ b/python/tvm/relax/frontend/nn/spec.py
@@ -106,16 +106,17 @@ class MethodSpec:
         arg_names = list(method_signature.parameters.keys())
         arg_specs = []
         for arg_name in arg_names:
-            arg_spec = spec[arg_name]
-            if arg_spec is Int or arg_spec is int:
-                arg_spec = Int()
-            elif isinstance(arg_spec, str) and arg_spec == "int":
-                arg_spec = Int()
-            elif isinstance(arg_spec, (Int, Tensor)):
-                pass
-            else:
-                raise TypeError(f"Invalid spec for argument {arg_name}: 
{arg_spec}")
-            arg_specs.append(arg_spec)
+            if arg_name in spec:
+                arg_spec = spec[arg_name]
+                if arg_spec is Int or arg_spec is int:
+                    arg_spec = Int()
+                elif isinstance(arg_spec, str) and arg_spec == "int":
+                    arg_spec = Int()
+                elif isinstance(arg_spec, (Int, Tensor)):
+                    pass
+                else:
+                    raise TypeError(f"Invalid spec for argument {arg_name}: 
{arg_spec}")
+                arg_specs.append(arg_spec)
         return MethodSpec(method, arg_names, arg_specs)
 
     @staticmethod
diff --git a/tests/python/relax/test_frontend_nn_modules.py 
b/tests/python/relax/test_frontend_nn_modules.py
index 68b03c5a21..dba4178f65 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -335,5 +335,87 @@ def test_kv_cache():
     assert_structural_equal(tvm_mod, Module, True)
 
 
+def test_attention():
+    @R.function
+    def forward(
+        hidden_states: R.Tensor((2, 4096, 640), dtype="float32"),
+        encoder_hidden_states: R.Tensor((2, 77, 2048), dtype="float32"),
+        to_q_weight: R.Tensor((640, 640), dtype="float32"),
+        to_k_weight: R.Tensor((640, 2048), dtype="float32"),
+        to_v_weight: R.Tensor((640, 2048), dtype="float32"),
+        group_norm_weight: R.Tensor((640,), dtype="float32"),
+        group_norm_bias: R.Tensor((640,), dtype="float32"),
+        to_out_0_weight: R.Tensor((640, 640), dtype="float32"),
+        to_out_0_bias: R.Tensor((640,), dtype="float32"),
+        _io: R.Object,
+    ) -> R.Tuple(R.Tensor((2, 4096, 640), dtype="float32"), R.Tuple(R.Object)):
+        with R.dataflow():
+            group_norm: R.Tensor((2, 4096, 640), dtype="float32") = 
R.nn.group_norm(
+                hidden_states,
+                group_norm_weight,
+                group_norm_bias,
+                num_groups=8,
+                channel_axis=2,
+                axes=[1],
+                epsilon=1.0000000000000001e-05,
+                center=True,
+                scale=True,
+            )
+            permute_dims: R.Tensor((640, 640), dtype="float32") = 
R.permute_dims(
+                to_q_weight, axes=None
+            )
+            matmul: R.Tensor((2, 4096, 640), dtype="float32") = R.matmul(
+                group_norm, permute_dims, out_dtype="void"
+            )
+            permute_dims1: R.Tensor((2048, 640), dtype="float32") = 
R.permute_dims(
+                to_k_weight, axes=None
+            )
+            matmul1: R.Tensor((2, 77, 640), dtype="float32") = R.matmul(
+                encoder_hidden_states, permute_dims1, out_dtype="void"
+            )
+            permute_dims2: R.Tensor((2048, 640), dtype="float32") = 
R.permute_dims(
+                to_v_weight, axes=None
+            )
+            matmul2: R.Tensor((2, 77, 640), dtype="float32") = R.matmul(
+                encoder_hidden_states, permute_dims2, out_dtype="void"
+            )
+            reshape: R.Tensor((2, 4096, 10, 64), dtype="float32") = R.reshape(
+                matmul, R.shape([2, 4096, 10, 64])
+            )
+            reshape1: R.Tensor((2, 77, 10, 64), dtype="float32") = R.reshape(
+                matmul1, R.shape([2, 77, 10, 64])
+            )
+            reshape2: R.Tensor((2, 77, 10, 64), dtype="float32") = R.reshape(
+                matmul2, R.shape([2, 77, 10, 64])
+            )
+            scaled_dot_product_attention: R.Tensor(
+                (2, 4096, 10, 64), dtype="float32"
+            ) = R.nn.attention(reshape, reshape1, reshape2, scale=None, 
causal_mask=None)
+            reshape3: R.Tensor((2, 4096, 640), dtype="float32") = R.reshape(
+                scaled_dot_product_attention, R.shape([2, 4096, 640])
+            )
+            permute_dims3: R.Tensor((640, 640), dtype="float32") = 
R.permute_dims(
+                to_out_0_weight, axes=None
+            )
+            matmul3: R.Tensor((2, 4096, 640), dtype="float32") = R.matmul(
+                reshape3, permute_dims3, out_dtype="void"
+            )
+            add: R.Tensor((2, 4096, 640), dtype="float32") = R.add(matmul3, 
to_out_0_bias)
+            gv1: R.Tuple(R.Tensor((2, 4096, 640), dtype="float32"), 
R.Tuple(R.Object)) = add, (_io,)
+            R.output(gv1)
+        return gv1
+
+    mod = modules.Attention(query_dim=640, cross_attention_dim=2048, heads=10, 
norm_num_groups=8)
+    tvm_mod, _ = mod.export_tvm(
+        spec={
+            "forward": {
+                "hidden_states": spec.Tensor((2, 4096, 640), "float32"),
+                "encoder_hidden_states": spec.Tensor((2, 77, 2048), "float32"),
+            }
+        }
+    )
+    assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_nn_op.py 
b/tests/python/relax/test_frontend_nn_op.py
index c404c18a68..6b59d0419a 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -298,6 +298,42 @@ def test_timestep_embedding():
     tvm.ir.assert_structural_equal(irmodule["test"], test)
 
 
+def test_scaled_dot_product_attention():
+    class Model(Module):
+        def test(self, query: Tensor, key: Tensor, value: Tensor):
+            scaled_dot_product_attention = 
op.scaled_dot_product_attention(query, key, value)
+            return scaled_dot_product_attention
+
+    @R.function
+    def test(
+        query: R.Tensor((1, 32, 32, 32), dtype="float32"),
+        key: R.Tensor((1, 32, 32, 32), dtype="float32"),
+        value: R.Tensor((1, 32, 32, 32), dtype="float32"),
+        _io: R.Object,
+    ) -> R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), 
R.Tuple(R.Object)):
+        with R.dataflow():
+            scaled_dot_product_attention: R.Tensor(
+                (1, 32, 32, 32), dtype="float32"
+            ) = R.nn.attention(query, key, value, scale=None, causal_mask=None)
+            gv1: R.Tuple(
+                R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)
+            ) = scaled_dot_product_attention, (_io,)
+            R.output(gv1)
+        return gv1
+
+    m = Model()
+    irmodule, _ = m.export_tvm(
+        spec={
+            "test": {
+                "query": spec.Tensor([1, 32, 32, 32], "float32"),
+                "key": spec.Tensor([1, 32, 32, 32], "float32"),
+                "value": spec.Tensor([1, 32, 32, 32], "float32"),
+            }
+        }
+    )
+    tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
 def test_tensor_expr_op():
     class Model(Module):
         def test(self, x: Tensor):
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 7eb57a02f9..d587d70636 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -72,7 +72,10 @@ def generate_random_inputs(
 
 
 def check_correctness(
-    model: ModelProto, inputs: Optional[Dict[str, np.ndarray]] = None, opset: 
int = None
+    model: ModelProto,
+    inputs: Optional[Dict[str, np.ndarray]] = None,
+    opset: int = None,
+    atol: float = 1e-5,
 ) -> None:
     """Run an onnx model in both onnxruntime and TVM through our importer
        confirm that the results match. Otherwise, an exception will be raised.
@@ -85,6 +88,9 @@ def check_correctness(
         An optional dictionary containing values for each input in the onnx 
model.
     opset: int
         The opset version to use for the onnx importer.
+    atol: float
+        Set the tolerance of correctness checking. Some ops may show more
+        arithmetic variance than others.
     """
     if opset is not None:
         model.opset_import[0].version = opset
@@ -143,7 +149,7 @@ def check_correctness(
         # TODO Allow configurable tolerance.
         # Sometimes None is used to indicate an unused output.
         if ort_out is not None:
-            tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, atol=1e-5)
+            tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, atol=atol)
 
 
 @pytest.mark.parametrize(
@@ -933,7 +939,8 @@ def test_all_reduce_funcs(func, dynamic):
         model = helper.make_model(graph, producer_name="reduce_test")
 
         inputs_dict = {"x": data}
-        check_correctness(model, inputs_dict, opset=11)
+        # Reduction ops accumulate arithmetic errors, so we use a higher 
tolerance.
+        check_correctness(model, inputs_dict, opset=11, atol=1e-4)
 
     for keepdims in [True, False]:
         verify_reduce_func(

Reply via email to