This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dee3c2ab28 [TIR] Handle DeclBuffer in LowerCustomDatatypes (#15041)
dee3c2ab28 is described below

commit dee3c2ab2856c70c41c42c1cd3644fdd2f008d18
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Jun 10 01:04:35 2023 -0400

    [TIR] Handle DeclBuffer in LowerCustomDatatypes (#15041)
    
    * [TIR] Handle DeclBuffer in LowerCustomDatatypes
    
    Preserve DeclBuffer node when transforming with `LowerCustomDatatypes`.
    This is a subset of the changes split out from
    https://github.com/apache/tvm/pull/14778 into independent portions.
    
    * Fix lint error
    
    * Fix parsing error introduced by lint fix
---
 src/tir/transforms/lower_custom_datatypes.cc   |   5 +
 tests/python/unittest/test_custom_datatypes.py | 260 +++++++++++++++++--------
 2 files changed, 180 insertions(+), 85 deletions(-)

diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc
index b2f95ad2d5..c5bcda2eff 100644
--- a/src/tir/transforms/lower_custom_datatypes.cc
+++ b/src/tir/transforms/lower_custom_datatypes.cc
@@ -103,6 +103,11 @@ class CustomDatatypesLowerer : public StmtExprMutator {
     }
   }
 
+  Stmt VisitStmt_(const DeclBufferNode* op) final {
+    auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+    return VisitBufferAccess(std::move(node));
+  }
+
   PrimExpr VisitExpr_(const BufferLoadNode* op) final {
     auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
     auto modified = VisitBufferAccess(node);
diff --git a/tests/python/unittest/test_custom_datatypes.py b/tests/python/unittest/test_custom_datatypes.py
index 5c7a429317..41ccec5ad2 100644
--- a/tests/python/unittest/test_custom_datatypes.py
+++ b/tests/python/unittest/test_custom_datatypes.py
@@ -34,6 +34,7 @@ from tvm.target.datatype import (
     register_op,
 )
 from tvm.tir.op import call_pure_extern
+from tvm.script import tir as T
 
 
 # note: we can't use relay.testing models because params are randomly 
initialized,
@@ -116,88 +117,105 @@ def setup_myfloat():
     Own Datatypes framework.
     """
 
-    # To use datatype operations in an external library, you should first load
-    # the library containing the datatype implementation:
-    # CDLL("libposit.so", RTLD_GLOBAL)
-    # In this case, the datatype library we are using is built right into TVM,
-    # so we do not need to explicitly load any library.
+    def _setup_myfloat_inner():
+        # To use datatype operations in an external library, you should first 
load
+        # the library containing the datatype implementation:
+        # CDLL("libposit.so", RTLD_GLOBAL)
+        # In this case, the datatype library we are using is built right into 
TVM,
+        # so we do not need to explicitly load any library.
 
-    # You can pick a code for your datatype arbitrarily, as long as it is
-    # greater than 128 and has not already been chosen.
-    register("myfloat", 131)
+        # You can pick a code for your datatype arbitrarily, as long as it is
+        # greater than 128 and has not already been chosen.
+        register("myfloat", 131)
 
-    register_op(
-        create_lower_func({(32, 32): "FloatToCustom32"}), "Cast", "llvm", 
"float", "myfloat"
-    )
-    register_op(
-        create_lower_func({(32, 32): "Custom32ToFloat"}), "Cast", "llvm", 
"myfloat", "float"
-    )
-    register_op(create_lower_func({32: "Custom32Add"}), "Add", "llvm", 
"myfloat")
-    register_op(
-        create_lower_func(
-            {
-                32: "Custom32Sub",
-            }
-        ),
-        "Sub",
-        "llvm",
-        "myfloat",
-    )
-    register_op(create_lower_func({32: "Custom32Mul"}), "Mul", "llvm", 
"myfloat")
-    register_op(
-        create_lower_func(
-            {
-                32: "FloatToCustom32",
-            }
-        ),
-        "FloatImm",
-        "llvm",
-        "myfloat",
-    )
-    register_op(
-        create_lower_func(
-            {
-                32: "Custom32Div",
-            }
-        ),
-        "Div",
-        "llvm",
-        "myfloat",
-    )
-    register_op(create_lower_func({32: "Custom32Max"}), "Max", "llvm", 
"myfloat")
-    register_op(
-        create_lower_func({32: "Custom32Sqrt"}),
-        "Call",
-        "llvm",
-        "myfloat",
-        intrinsic_name="tir.sqrt",
-    )
-    register_op(
-        create_lower_func({32: "Custom32Exp"}), "Call", "llvm", "myfloat", 
intrinsic_name="tir.exp"
-    )
-    register_op(
-        create_lower_func({32: "Custom32Log"}), "Call", "llvm", "myfloat", 
intrinsic_name="tir.log"
-    )
-    register_op(
-        create_lower_func({32: "Custom32Sigmoid"}),
-        "Call",
-        "llvm",
-        "myfloat",
-        intrinsic_name="tir.sigmoid",
-    )
-    register_op(
-        create_lower_func({32: "Custom32Tanh"}),
-        "Call",
-        "llvm",
-        "myfloat",
-        intrinsic_name="tir.tanh",
-    )
-    register_op(lower_ite, "Call", "llvm", "myfloat", 
intrinsic_name="tir.if_then_else")
-    register_op(
-        lower_call_pure_extern, "Call", "llvm", "myfloat", 
intrinsic_name="tir.call_pure_extern"
-    )
+        register_op(
+            create_lower_func({(32, 32): "FloatToCustom32"}), "Cast", "llvm", 
"float", "myfloat"
+        )
+        register_op(
+            create_lower_func({(32, 32): "Custom32ToFloat"}), "Cast", "llvm", 
"myfloat", "float"
+        )
+        register_op(create_lower_func({32: "Custom32Add"}), "Add", "llvm", 
"myfloat")
+        register_op(
+            create_lower_func(
+                {
+                    32: "Custom32Sub",
+                }
+            ),
+            "Sub",
+            "llvm",
+            "myfloat",
+        )
+        register_op(create_lower_func({32: "Custom32Mul"}), "Mul", "llvm", 
"myfloat")
+        register_op(
+            create_lower_func(
+                {
+                    32: "FloatToCustom32",
+                }
+            ),
+            "FloatImm",
+            "llvm",
+            "myfloat",
+        )
+        register_op(
+            create_lower_func(
+                {
+                    32: "Custom32Div",
+                }
+            ),
+            "Div",
+            "llvm",
+            "myfloat",
+        )
+        register_op(create_lower_func({32: "Custom32Max"}), "Max", "llvm", 
"myfloat")
+        register_op(
+            create_lower_func({32: "Custom32Sqrt"}),
+            "Call",
+            "llvm",
+            "myfloat",
+            intrinsic_name="tir.sqrt",
+        )
+        register_op(
+            create_lower_func({32: "Custom32Exp"}),
+            "Call",
+            "llvm",
+            "myfloat",
+            intrinsic_name="tir.exp",
+        )
+        register_op(
+            create_lower_func({32: "Custom32Log"}),
+            "Call",
+            "llvm",
+            "myfloat",
+            intrinsic_name="tir.log",
+        )
+        register_op(
+            create_lower_func({32: "Custom32Sigmoid"}),
+            "Call",
+            "llvm",
+            "myfloat",
+            intrinsic_name="tir.sigmoid",
+        )
+        register_op(
+            create_lower_func({32: "Custom32Tanh"}),
+            "Call",
+            "llvm",
+            "myfloat",
+            intrinsic_name="tir.tanh",
+        )
+        register_op(lower_ite, "Call", "llvm", "myfloat", 
intrinsic_name="tir.if_then_else")
+        register_op(
+            lower_call_pure_extern, "Call", "llvm", "myfloat", 
intrinsic_name="tir.call_pure_extern"
+        )
+
+        register_min_func(create_min_lower_func({32: "MinCustom32"}, 
"myfloat"), "myfloat")
 
-    register_min_func(create_min_lower_func({32: "MinCustom32"}, "myfloat"), 
"myfloat")
+    try:
+        _setup_myfloat_inner()
+    except tvm._ffi.base.TVMError as e:
+        # Ignore this specific error which can happen if another test
+        # that uses "myfloat" has already run.
+        if "float is already registered" not in str(e):
+            raise e
 
 
 def setup_posites2():
@@ -513,12 +531,8 @@ def run_batchnorm(src_dtype, dst_dtype, rtol=1e-6, atol=1e-6):
 
 
 def test_myfloat():
-    try:
-        setup_myfloat()
-    except tvm._ffi.base.TVMError as e:
-        if "float is already registered" not in str(e):
-            # Ignore this specific error which can happen if this test runs 
twice within the same process
-            raise e
+    setup_myfloat()
+
     run_ops("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
     run_conv2d("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
     run_batchnorm("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
@@ -529,6 +543,82 @@ def test_myfloat():
     #           'custom[myfloat]32')
 
 
+class TestMyfloatLowering(tvm.testing.CompareBeforeAfter):
+    setup_myfloat()
+
+    transform = tvm.tir.transform.LowerCustomDatatypes()
+
+    def before(self):
+        dtype = "custom[myfloat]32"
+
+        @T.prim_func
+        def func(A_data: T.handle(dtype)):
+            T.func_attr({"target": T.target("llvm")})
+            A = T.Buffer(16, dtype=dtype, data=A_data)
+            B_data = T.allocate([16], dtype=dtype)
+            B = T.Buffer(16, dtype=dtype, data=B_data)
+            for i in range(16):
+                B[i] = A[i] + 1.0
+
+        return func
+
+    def expected(self):
+        dtype = "custom[myfloat]32"
+
+        @T.prim_func
+        def func(A_data: T.handle(dtype)):
+            T.func_attr({"target": T.target("llvm")})
+            A_uint32 = T.Buffer(16, "uint32", data=A_data)
+            B_data = T.allocate([16], dtype="uint32")
+            B_uint32 = T.Buffer(16, "uint32", data=B_data)
+            for i in range(16):
+                B_uint32[i] = T.call_pure_extern(
+                    "uint32",
+                    "FloatToCustom32",
+                    T.call_pure_extern("float32", "Custom32ToFloat", 
A_uint32[i]) + T.float32(1),
+                )
+
+        return func
+
+
+class TestMyfloatLoweringDeclBuffer(tvm.testing.CompareBeforeAfter):
+    """Like TestMyfloatLowering, but using DeclBuffer"""
+
+    setup_myfloat()
+
+    transform = tvm.tir.transform.LowerCustomDatatypes()
+
+    def before(self):
+        dtype = "custom[myfloat]32"
+
+        @T.prim_func
+        def func(A_data: T.handle(dtype)):
+            T.func_attr({"target": T.target("llvm")})
+            A = T.decl_buffer(16, dtype=dtype, data=A_data)
+            B = T.decl_buffer(16, dtype=dtype)
+            for i in range(16):
+                B[i] = A[i] + 1.0
+
+        return func
+
+    def expected(self):
+        dtype = "custom[myfloat]32"
+
+        @T.prim_func
+        def func(A_data: T.handle(dtype)):
+            T.func_attr({"target": T.target("llvm")})
+            A_uint32 = T.decl_buffer(16, "uint32", data=A_data)
+            B_uint32 = T.decl_buffer(16, dtype="uint32")
+            for i in range(16):
+                B_uint32[i] = T.call_pure_extern(
+                    "uint32",
+                    "FloatToCustom32",
+                    T.call_pure_extern("float32", "Custom32ToFloat", 
A_uint32[i]) + T.float32(1),
+                )
+
+        return func
+
+
 def _has_posit():
     return tvm.support.libinfo()["USE_BYODT_POSIT"] == "ON"
 

Reply via email to