This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 676def7c9d [Unity][Frontend][NN] Switch to always using fp32 for timestep embedding calculation (#15811)
676def7c9d is described below
commit 676def7c9d47fe54d5416ff1e2a787294a799b01
Author: Josh Fromm <[email protected]>
AuthorDate: Sat Sep 23 00:09:41 2023 -0700
[Unity][Frontend][NN] Switch to always using fp32 for timestep embedding
calculation (#15811)
It turns out that using datatypes besides `float32` when calculating
timestep embeddings leads to substantial errors. The diffusers library just
hard casts to float32 to avoid this and we should too.
---
python/tvm/relax/frontend/nn/op.py | 15 ++++++++++-----
tests/python/relax/test_frontend_nn_modules.py | 5 +++--
tests/python/relax/test_frontend_nn_op.py | 5 +++--
3 files changed, 16 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/op.py
b/python/tvm/relax/frontend/nn/op.py
index 8afeb2c118..0f4c5f6483 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1220,19 +1220,21 @@ def get_timestep_embedding(
[N x dim] Tensor of positional embeddings.
"""
dtype = get_default_dtype()
- timesteps = _op.astype(x._expr, dtype)
+
+ # Arithmetic should be done in float for precision.
+ timesteps = _op.astype(x._expr, "float32")
half_dim = embedding_dim // 2
-    exponent = rx.const(-math.log(max_period), dtype) * _op.arange(
-        start=0, end=half_dim, dtype=dtype
+    exponent = rx.const(-math.log(max_period), "float32") * _op.arange(
+        start=0, end=half_dim, dtype="float32"
)
- exponent = exponent / (rx.const(half_dim - downscale_freq_shift, dtype))
+    exponent = exponent / (rx.const(half_dim - downscale_freq_shift, "float32"))
emb = _op.exp(exponent)
emb = _op.expand_dims(timesteps, 1) * _op.expand_dims(emb, 0)
# Scale embeddings
if scale != 1:
- emb = rx.const(scale, dtype) * emb
+ emb = rx.const(scale, "float32") * emb
# Concat sine and cosine embeddings.
if flip_sin_to_cos:
@@ -1243,6 +1245,9 @@ def get_timestep_embedding(
# Zero pad
if embedding_dim % 2 == 1:
emb = _op.nn.pad(emb, (0, 1, 0, 0))
+
+ # Cast to proper output type
+ emb = _op.astype(emb, dtype)
return _wrap_nested(emb, name)
diff --git a/tests/python/relax/test_frontend_nn_modules.py
b/tests/python/relax/test_frontend_nn_modules.py
index f3d248eab4..cb207954e8 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -402,8 +402,9 @@ def test_timesteps():
lv8: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv7)
lv9: R.Tensor((3, 5), dtype="float32") = R.sin(lv8)
lv10: R.Tensor((3, 5), dtype="float32") = R.cos(lv8)
-            get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.concat(
-                (lv9, lv10), axis=-1
+            lv11: R.Tensor((3, 10), dtype="float32") = R.concat((lv9, lv10), axis=-1)
+            get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.astype(
+                lv11, dtype="float32"
)
gv1: R.Tuple(
R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)
diff --git a/tests/python/relax/test_frontend_nn_op.py
b/tests/python/relax/test_frontend_nn_op.py
index f6cb29a87b..c7bef23124 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -386,8 +386,9 @@ def test_timestep_embedding():
lv8: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv7)
lv9: R.Tensor((3, 5), dtype="float32") = R.sin(lv8)
lv10: R.Tensor((3, 5), dtype="float32") = R.cos(lv8)
-            get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.concat(
-                (lv9, lv10), axis=-1
+            lv11: R.Tensor((3, 10), dtype="float32") = R.concat((lv9, lv10), axis=-1)
+            get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.astype(
+                lv11, dtype="float32"
)
gv1: R.Tuple(
R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)