This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new d5dcabf2b2 [Unity][Frontend][NN] Add diffusers style Attention layer
(#15609)
d5dcabf2b2 is described below
commit d5dcabf2b2687ef4a03d0becc7d3e043b93a6f90
Author: Josh Fromm <[email protected]>
AuthorDate: Fri Aug 25 22:00:32 2023 -0700
[Unity][Frontend][NN] Add diffusers style Attention layer (#15609)
This PR adds support for the `Attention` layer in the nn module API. This
layer mimics the behavior of the [Attention layer used in huggingface
Diffusers](https://github.com/huggingface/diffusers/blob/80871ac5971fe7e708befa3b553463c4e61b22ab/src/diffusers/models/attention_processor.py#L36).
Under the hood it uses scaled dot product attention. Notably, there are still
some missing features. For example I didn't add support for attention masks yet.
I also am assuming 3 dimensional inputs [...]
---
python/tvm/relax/frontend/nn/core.py | 3 +
python/tvm/relax/frontend/nn/modules.py | 129 ++++++++++++++++++++++++-
python/tvm/relax/frontend/nn/op.py | 52 +++++++++-
python/tvm/relax/frontend/nn/spec.py | 21 ++--
tests/python/relax/test_frontend_nn_modules.py | 82 ++++++++++++++++
tests/python/relax/test_frontend_nn_op.py | 36 +++++++
tests/python/relax/test_frontend_onnx.py | 13 ++-
7 files changed, 318 insertions(+), 18 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/core.py
b/python/tvm/relax/frontend/nn/core.py
index c282815160..ea3744b6ca 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -486,6 +486,9 @@ class ModuleList(Module):
def __len__(self):
return len(self.modules)
+ def append(self, module):
+ self.modules.append(module)
+
def to(self, dtype: Optional[str] = None) -> None: # pylint:
disable=invalid-name
for module in self.modules:
module.to(dtype=dtype)
diff --git a/python/tvm/relax/frontend/nn/modules.py
b/python/tvm/relax/frontend/nn/modules.py
index fde18473ee..6df0957398 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=too-many-arguments,invalid-name,protected-access
+# pylint:
disable=too-many-arguments,invalid-name,protected-access,unused-argument
"""Builtin Modules."""
from typing import List, Optional, Sequence, Union
@@ -24,7 +24,7 @@ from tvm._ffi import register_func
from tvm.runtime import NDArray
from . import op
-from .core import Effect, Module, Parameter, Tensor, get_default_dtype
+from .core import Effect, Module, Parameter, Tensor, get_default_dtype,
ModuleList
class IOEffect(Effect):
@@ -344,7 +344,7 @@ class GroupNorm(Module):
self.weight = None
self.bias = None
- def forward(self, x: Tensor):
+ def forward(self, x: Tensor, channel_axis: int = 1, axes:
Optional[List[int]] = None):
"""
Forward method for group norm layer.
@@ -352,13 +352,20 @@ class GroupNorm(Module):
----------
x : Tensor
The input tensor.
+ channel_axis : int
+ Channel axis of the input data.
+ axes : Optional[List[int]]
+ Optional list of axes to compute norm over, if not specified,
+ assumes that the first two axes should be left alone.
Returns
-------
ret : Tensor
The output tensor for the group norm layer.
"""
- return op.group_norm(x, self.num_groups, self.weight, self.bias,
self.eps)
+ return op.group_norm(
+ x, self.num_groups, self.weight, self.bias, self.eps,
channel_axis, axes
+ )
class KVCache(Effect):
@@ -621,3 +628,117 @@ class Timesteps(Module):
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
+
+
+class Attention(Module):
+ """
+ A cross attention layer.
+
+ Parameters
+ ----------
+ query_dim : int
+ The number of channels in the query.
+ cross_attention_dim : Optional[int]
+ The number of channels in the encoder_hidden_states.
+ If not given, defaults to `query_dim`.
+ heads : int
+ The number of heads to use for multi-head attention.
+ dim_head : int
+ The number of channels in each head.
+ bias : bool
+ Set to `True` for the query, key, and value linear layers to
contain a bias parameter.
+ norm_num_groups : Optional[int]
+ When set, group norm is applied to the input using this number of
groups.
+ out_bias : bool
+ Set to `True` to apply a bias to the output linear layer.
+ scale_qk : bool
+ Whether to apply scaling to query and key tensors.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ bias: bool = False,
+ norm_num_groups: Optional[int] = None,
+ out_bias: bool = True,
+ scale_qk: bool = True,
+ ):
+ self.query_dim = query_dim
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim
else query_dim
+ self.heads = heads
+ self.dim_head = dim_head
+ self.bias = bias
+ self.norm_num_groups = norm_num_groups
+ self.out_bias = out_bias
+ self.scale_qk = scale_qk
+
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+ self.inner_dim = dim_head * heads
+
+ self.to_q = Linear(self.query_dim, self.inner_dim, bias=self.bias)
+ self.to_k = Linear(self.cross_attention_dim, self.inner_dim,
bias=self.bias)
+ self.to_v = Linear(self.cross_attention_dim, self.inner_dim,
bias=self.bias)
+
+ if self.norm_num_groups is not None:
+ self.group_norm = GroupNorm(
+ num_channels=self.query_dim, num_groups=self.norm_num_groups,
affine=True
+ )
+ else:
+ self.group_norm = None
+
+ self.to_out = ModuleList([Linear(self.inner_dim, self.query_dim,
bias=self.out_bias)])
+
+ def forward(
+ self,
+ hidden_states: Tensor,
+ encoder_hidden_states: Optional[Tensor] = None,
+ attention_mask: Optional[Tensor] = None,
+ **cross_attention_kwargs,
+ ):
+ """
+ Forward method for Attention layer.
+
+ Parameters
+ ----------
+ hidden_states : Tensor
+ The input sample tensor.
+ encoder_hidden_states : Optional[Tensor]
+ Previous hidden step hidden states.
+ attention_mask : Optional[Tensor]
+ Mask tensor for attention, currently not supported.
+
+ Returns
+ -------
+ ret : Tensor
+ The output tensor for the attention layer.
+ """
+ # This implementation assumes use of torch 2.0 scaled_dot_product
attention.
+ assert attention_mask is None, "Attention mask not yet supported."
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states, channel_axis=2,
axes=[1])
+
+ query = self.to_q(hidden_states)
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+ head_dim = int(self.inner_dim // self.heads)
+
+ query = op.reshape(query, [0, -1, self.heads, head_dim])
+ key = op.reshape(key, [0, -1, self.heads, head_dim])
+ value = op.reshape(value, [0, -1, self.heads, head_dim])
+
+ hidden_states = op.scaled_dot_product_attention(query, key, value,
is_causal=False)
+
+ # Return to proper shape.
+ hidden_states = op.reshape(hidden_states, (0, -1, self.heads *
head_dim))
+
+ # Linear projection
+ hidden_states = self.to_out[0](hidden_states)
+
+ return hidden_states
diff --git a/python/tvm/relax/frontend/nn/op.py
b/python/tvm/relax/frontend/nn/op.py
index b3959cd95f..4ef02797c2 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -759,6 +759,8 @@ def group_norm(
weight: Optional[Tensor],
bias: Optional[Tensor],
eps: float = 1e-5,
+ channel_axis: int = 1,
+ axes: Optional[List[int]] = None,
name: str = "group_norm",
) -> Tensor:
r"""
@@ -785,6 +787,13 @@ def group_norm(
epsilon : float
Small float added to square mean to avoid dividing by zero.
+ channel_axis: int
+ The channel axis of the data.
+
+ axes : Optional[List[int]]
+ Which axes to compute the groupnorm over. If None, assumes first
+ two channels should be ignored.
+
name : str
Name hint.
@@ -798,9 +807,11 @@ def group_norm(
if bias is not None:
bias = bias._expr
dim = len(x._expr.struct_info.shape)
+ if axes is None:
+ axes = list(range(2, dim))
return _wrap_nested(
_op.nn.group_norm(
- x._expr, weight, bias, num_groups, channel_axis=1,
axes=list(range(2, dim)), epsilon=eps
+ x._expr, weight, bias, num_groups, channel_axis=channel_axis,
axes=axes, epsilon=eps
),
name,
)
@@ -955,6 +966,45 @@ def get_timestep_embedding(
return _wrap_nested(emb, name)
+def scaled_dot_product_attention(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ attn_mask: Optional[Tensor] = None,
+ is_causal: Optional[bool] = False,
+ scale: Optional[float] = None,
+ name: str = "scaled_dot_product_attention",
+):
+ """
+ Computes a scaled dot product attention on provided attention
+ query, key, and values. Compliant with the functional torch implementation.
+
+ Parameters
+ ----------
+ query : Tensor
+ Tensor representing current attention lookup.
+ key : Tensor
+ Tensor representing cross attention mapping.
+ value : Tensor
+ Tensor representing embedded attention values.
+ attn_mask : Optional[Tensor]
+ Optional mask for attention, not yet supported.
+ is_causal : Optional[bool]
+ If set, uses a causal attention mask.
+ scale : Optional[float]
+ Optional extra scaling argument applied to attention.
+ name : str
+ Name hint for this function.
+ """
+ assert attn_mask is None, "attn_mask not yet supported."
+ causal_mask = "TopLeft" if is_causal else None
+
+ attn = _op.nn.attention(
+ query._expr, key._expr, value._expr, causal_mask=causal_mask,
scale=scale
+ )
+ return _wrap_nested(attn, name)
+
+
def tensor_expr_op(
tensor_expr_func: Callable,
name_hint: str,
diff --git a/python/tvm/relax/frontend/nn/spec.py
b/python/tvm/relax/frontend/nn/spec.py
index 95772f2f94..983e0adc52 100644
--- a/python/tvm/relax/frontend/nn/spec.py
+++ b/python/tvm/relax/frontend/nn/spec.py
@@ -106,16 +106,17 @@ class MethodSpec:
arg_names = list(method_signature.parameters.keys())
arg_specs = []
for arg_name in arg_names:
- arg_spec = spec[arg_name]
- if arg_spec is Int or arg_spec is int:
- arg_spec = Int()
- elif isinstance(arg_spec, str) and arg_spec == "int":
- arg_spec = Int()
- elif isinstance(arg_spec, (Int, Tensor)):
- pass
- else:
- raise TypeError(f"Invalid spec for argument {arg_name}:
{arg_spec}")
- arg_specs.append(arg_spec)
+ if arg_name in spec:
+ arg_spec = spec[arg_name]
+ if arg_spec is Int or arg_spec is int:
+ arg_spec = Int()
+ elif isinstance(arg_spec, str) and arg_spec == "int":
+ arg_spec = Int()
+ elif isinstance(arg_spec, (Int, Tensor)):
+ pass
+ else:
+ raise TypeError(f"Invalid spec for argument {arg_name}:
{arg_spec}")
+ arg_specs.append(arg_spec)
return MethodSpec(method, arg_names, arg_specs)
@staticmethod
diff --git a/tests/python/relax/test_frontend_nn_modules.py
b/tests/python/relax/test_frontend_nn_modules.py
index 68b03c5a21..dba4178f65 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -335,5 +335,87 @@ def test_kv_cache():
assert_structural_equal(tvm_mod, Module, True)
+def test_attention():
+ @R.function
+ def forward(
+ hidden_states: R.Tensor((2, 4096, 640), dtype="float32"),
+ encoder_hidden_states: R.Tensor((2, 77, 2048), dtype="float32"),
+ to_q_weight: R.Tensor((640, 640), dtype="float32"),
+ to_k_weight: R.Tensor((640, 2048), dtype="float32"),
+ to_v_weight: R.Tensor((640, 2048), dtype="float32"),
+ group_norm_weight: R.Tensor((640,), dtype="float32"),
+ group_norm_bias: R.Tensor((640,), dtype="float32"),
+ to_out_0_weight: R.Tensor((640, 640), dtype="float32"),
+ to_out_0_bias: R.Tensor((640,), dtype="float32"),
+ _io: R.Object,
+ ) -> R.Tuple(R.Tensor((2, 4096, 640), dtype="float32"), R.Tuple(R.Object)):
+ with R.dataflow():
+ group_norm: R.Tensor((2, 4096, 640), dtype="float32") =
R.nn.group_norm(
+ hidden_states,
+ group_norm_weight,
+ group_norm_bias,
+ num_groups=8,
+ channel_axis=2,
+ axes=[1],
+ epsilon=1.0000000000000001e-05,
+ center=True,
+ scale=True,
+ )
+ permute_dims: R.Tensor((640, 640), dtype="float32") =
R.permute_dims(
+ to_q_weight, axes=None
+ )
+ matmul: R.Tensor((2, 4096, 640), dtype="float32") = R.matmul(
+ group_norm, permute_dims, out_dtype="void"
+ )
+ permute_dims1: R.Tensor((2048, 640), dtype="float32") =
R.permute_dims(
+ to_k_weight, axes=None
+ )
+ matmul1: R.Tensor((2, 77, 640), dtype="float32") = R.matmul(
+ encoder_hidden_states, permute_dims1, out_dtype="void"
+ )
+ permute_dims2: R.Tensor((2048, 640), dtype="float32") =
R.permute_dims(
+ to_v_weight, axes=None
+ )
+ matmul2: R.Tensor((2, 77, 640), dtype="float32") = R.matmul(
+ encoder_hidden_states, permute_dims2, out_dtype="void"
+ )
+ reshape: R.Tensor((2, 4096, 10, 64), dtype="float32") = R.reshape(
+ matmul, R.shape([2, 4096, 10, 64])
+ )
+ reshape1: R.Tensor((2, 77, 10, 64), dtype="float32") = R.reshape(
+ matmul1, R.shape([2, 77, 10, 64])
+ )
+ reshape2: R.Tensor((2, 77, 10, 64), dtype="float32") = R.reshape(
+ matmul2, R.shape([2, 77, 10, 64])
+ )
+ scaled_dot_product_attention: R.Tensor(
+ (2, 4096, 10, 64), dtype="float32"
+ ) = R.nn.attention(reshape, reshape1, reshape2, scale=None,
causal_mask=None)
+ reshape3: R.Tensor((2, 4096, 640), dtype="float32") = R.reshape(
+ scaled_dot_product_attention, R.shape([2, 4096, 640])
+ )
+ permute_dims3: R.Tensor((640, 640), dtype="float32") =
R.permute_dims(
+ to_out_0_weight, axes=None
+ )
+ matmul3: R.Tensor((2, 4096, 640), dtype="float32") = R.matmul(
+ reshape3, permute_dims3, out_dtype="void"
+ )
+ add: R.Tensor((2, 4096, 640), dtype="float32") = R.add(matmul3,
to_out_0_bias)
+ gv1: R.Tuple(R.Tensor((2, 4096, 640), dtype="float32"),
R.Tuple(R.Object)) = add, (_io,)
+ R.output(gv1)
+ return gv1
+
+ mod = modules.Attention(query_dim=640, cross_attention_dim=2048, heads=10,
norm_num_groups=8)
+ tvm_mod, _ = mod.export_tvm(
+ spec={
+ "forward": {
+ "hidden_states": spec.Tensor((2, 4096, 640), "float32"),
+ "encoder_hidden_states": spec.Tensor((2, 77, 2048), "float32"),
+ }
+ }
+ )
+ assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_nn_op.py
b/tests/python/relax/test_frontend_nn_op.py
index c404c18a68..6b59d0419a 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -298,6 +298,42 @@ def test_timestep_embedding():
tvm.ir.assert_structural_equal(irmodule["test"], test)
+def test_scaled_dot_product_attention():
+ class Model(Module):
+ def test(self, query: Tensor, key: Tensor, value: Tensor):
+ scaled_dot_product_attention =
op.scaled_dot_product_attention(query, key, value)
+ return scaled_dot_product_attention
+
+ @R.function
+ def test(
+ query: R.Tensor((1, 32, 32, 32), dtype="float32"),
+ key: R.Tensor((1, 32, 32, 32), dtype="float32"),
+ value: R.Tensor((1, 32, 32, 32), dtype="float32"),
+ _io: R.Object,
+ ) -> R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"),
R.Tuple(R.Object)):
+ with R.dataflow():
+ scaled_dot_product_attention: R.Tensor(
+ (1, 32, 32, 32), dtype="float32"
+ ) = R.nn.attention(query, key, value, scale=None, causal_mask=None)
+ gv1: R.Tuple(
+ R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)
+ ) = scaled_dot_product_attention, (_io,)
+ R.output(gv1)
+ return gv1
+
+ m = Model()
+ irmodule, _ = m.export_tvm(
+ spec={
+ "test": {
+ "query": spec.Tensor([1, 32, 32, 32], "float32"),
+ "key": spec.Tensor([1, 32, 32, 32], "float32"),
+ "value": spec.Tensor([1, 32, 32, 32], "float32"),
+ }
+ }
+ )
+ tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
def test_tensor_expr_op():
class Model(Module):
def test(self, x: Tensor):
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 7eb57a02f9..d587d70636 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -72,7 +72,10 @@ def generate_random_inputs(
def check_correctness(
- model: ModelProto, inputs: Optional[Dict[str, np.ndarray]] = None, opset:
int = None
+ model: ModelProto,
+ inputs: Optional[Dict[str, np.ndarray]] = None,
+ opset: int = None,
+ atol: float = 1e-5,
) -> None:
"""Run an onnx model in both onnxruntime and TVM through our importer
confirm that the results match. Otherwise, an exception will be raised.
@@ -85,6 +88,9 @@ def check_correctness(
An optional dictionary containing values for each input in the onnx
model.
opset: int
The opset version to use for the onnx importer.
+ atol: float
+ Set the tolerance of correctness checking. Some ops may show more
+ arithmetic variance than others.
"""
if opset is not None:
model.opset_import[0].version = opset
@@ -143,7 +149,7 @@ def check_correctness(
# TODO Allow configurable tolerance.
# Sometimes None is used to indicate an unused output.
if ort_out is not None:
- tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, atol=1e-5)
+ tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, atol=atol)
@pytest.mark.parametrize(
@@ -933,7 +939,8 @@ def test_all_reduce_funcs(func, dynamic):
model = helper.make_model(graph, producer_name="reduce_test")
inputs_dict = {"x": data}
- check_correctness(model, inputs_dict, opset=11)
+ # Reduction ops accumulate arithmetic errors, so we use a higher
tolerance.
+ check_correctness(model, inputs_dict, opset=11, atol=1e-4)
for keepdims in [True, False]:
verify_reduce_func(