This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 676def7c9d [Unity][Frontend][NN] Switch to always using fp32 for timestep embedding calculation (#15811)
676def7c9d is described below

commit 676def7c9d47fe54d5416ff1e2a787294a799b01
Author: Josh Fromm <[email protected]>
AuthorDate: Sat Sep 23 00:09:41 2023 -0700

    [Unity][Frontend][NN] Switch to always using fp32 for timestep embedding calculation (#15811)
    
    It turns out that using datatypes besides `float32` when calculating
    timestep embeddings leads to substantial errors. The diffusers library
    just hard casts to float32 to avoid this and we should too.
---
 python/tvm/relax/frontend/nn/op.py             | 15 ++++++++++-----
 tests/python/relax/test_frontend_nn_modules.py |  5 +++--
 tests/python/relax/test_frontend_nn_op.py      |  5 +++--
 3 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 8afeb2c118..0f4c5f6483 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1220,19 +1220,21 @@ def get_timestep_embedding(
         [N x dim] Tensor of positional embeddings.
     """
     dtype = get_default_dtype()
-    timesteps = _op.astype(x._expr, dtype)
+
+    # Arithmetic should be done in float for precision.
+    timesteps = _op.astype(x._expr, "float32")
 
     half_dim = embedding_dim // 2
-    exponent = rx.const(-math.log(max_period), dtype) * _op.arange(
-        start=0, end=half_dim, dtype=dtype
+    exponent = rx.const(-math.log(max_period), "float32") * _op.arange(
+        start=0, end=half_dim, dtype="float32"
     )
-    exponent = exponent / (rx.const(half_dim - downscale_freq_shift, dtype))
+    exponent = exponent / (rx.const(half_dim - downscale_freq_shift, "float32"))
 
     emb = _op.exp(exponent)
     emb = _op.expand_dims(timesteps, 1) * _op.expand_dims(emb, 0)
     # Scale embeddings
     if scale != 1:
-        emb = rx.const(scale, dtype) * emb
+        emb = rx.const(scale, "float32") * emb
 
     # Concat sine and cosine embeddings.
     if flip_sin_to_cos:
@@ -1243,6 +1245,9 @@ def get_timestep_embedding(
     # Zero pad
     if embedding_dim % 2 == 1:
         emb = _op.nn.pad(emb, (0, 1, 0, 0))
+
+    # Cast to proper output type
+    emb = _op.astype(emb, dtype)
     return _wrap_nested(emb, name)
 
 
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index f3d248eab4..cb207954e8 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -402,8 +402,9 @@ def test_timesteps():
             lv8: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv7)
             lv9: R.Tensor((3, 5), dtype="float32") = R.sin(lv8)
             lv10: R.Tensor((3, 5), dtype="float32") = R.cos(lv8)
-            get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.concat(
-                (lv9, lv10), axis=-1
+            lv11: R.Tensor((3, 10), dtype="float32") = R.concat((lv9, lv10), axis=-1)
+            get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.astype(
+                lv11, dtype="float32"
             )
             gv1: R.Tuple(
                 R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index f6cb29a87b..c7bef23124 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -386,8 +386,9 @@ def test_timestep_embedding():
             lv8: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv7)
             lv9: R.Tensor((3, 5), dtype="float32") = R.sin(lv8)
             lv10: R.Tensor((3, 5), dtype="float32") = R.cos(lv8)
-            get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.concat(
-                (lv9, lv10), axis=-1
+            lv11: R.Tensor((3, 10), dtype="float32") = R.concat((lv9, lv10), axis=-1)
+            get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.astype(
+                lv11, dtype="float32"
             )
             gv1: R.Tuple(
                 R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)

Reply via email to