This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dee3c2ab28 [TIR] Handle DeclBuffer in LowerCustomDatatypes (#15041)
dee3c2ab28 is described below
commit dee3c2ab2856c70c41c42c1cd3644fdd2f008d18
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Jun 10 01:04:35 2023 -0400
[TIR] Handle DeclBuffer in LowerCustomDatatypes (#15041)
* [TIR] Handle DeclBuffer in LowerCustomDatatypes
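Preserve the DeclBuffer node when transforming with `LowerCustomDatatypes`, updating the declared buffer through the same path used for buffer accesses such as `BufferLoad`.
This is a subset of the changes from
https://github.com/apache/tvm/pull/14778, split out into an independent portion.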
* Fix lint error
* Fix parsing error introduced by lint fix
---
src/tir/transforms/lower_custom_datatypes.cc | 5 +
tests/python/unittest/test_custom_datatypes.py | 260 +++++++++++++++++--------
2 files changed, 180 insertions(+), 85 deletions(-)
diff --git a/src/tir/transforms/lower_custom_datatypes.cc
b/src/tir/transforms/lower_custom_datatypes.cc
index b2f95ad2d5..c5bcda2eff 100644
--- a/src/tir/transforms/lower_custom_datatypes.cc
+++ b/src/tir/transforms/lower_custom_datatypes.cc
@@ -103,6 +103,11 @@ class CustomDatatypesLowerer : public StmtExprMutator {
}
}
+ Stmt VisitStmt_(const DeclBufferNode* op) final {
+ auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+ return VisitBufferAccess(std::move(node));
+ }
+
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto modified = VisitBufferAccess(node);
diff --git a/tests/python/unittest/test_custom_datatypes.py
b/tests/python/unittest/test_custom_datatypes.py
index 5c7a429317..41ccec5ad2 100644
--- a/tests/python/unittest/test_custom_datatypes.py
+++ b/tests/python/unittest/test_custom_datatypes.py
@@ -34,6 +34,7 @@ from tvm.target.datatype import (
register_op,
)
from tvm.tir.op import call_pure_extern
+from tvm.script import tir as T
# note: we can't use relay.testing models because params are randomly
initialized,
@@ -116,88 +117,105 @@ def setup_myfloat():
Own Datatypes framework.
"""
- # To use datatype operations in an external library, you should first load
- # the library containing the datatype implementation:
- # CDLL("libposit.so", RTLD_GLOBAL)
- # In this case, the datatype library we are using is built right into TVM,
- # so we do not need to explicitly load any library.
+ def _setup_myfloat_inner():
+ # To use datatype operations in an external library, you should first
load
+ # the library containing the datatype implementation:
+ # CDLL("libposit.so", RTLD_GLOBAL)
+ # In this case, the datatype library we are using is built right into
TVM,
+ # so we do not need to explicitly load any library.
- # You can pick a code for your datatype arbitrarily, as long as it is
- # greater than 128 and has not already been chosen.
- register("myfloat", 131)
+ # You can pick a code for your datatype arbitrarily, as long as it is
+ # greater than 128 and has not already been chosen.
+ register("myfloat", 131)
- register_op(
- create_lower_func({(32, 32): "FloatToCustom32"}), "Cast", "llvm",
"float", "myfloat"
- )
- register_op(
- create_lower_func({(32, 32): "Custom32ToFloat"}), "Cast", "llvm",
"myfloat", "float"
- )
- register_op(create_lower_func({32: "Custom32Add"}), "Add", "llvm",
"myfloat")
- register_op(
- create_lower_func(
- {
- 32: "Custom32Sub",
- }
- ),
- "Sub",
- "llvm",
- "myfloat",
- )
- register_op(create_lower_func({32: "Custom32Mul"}), "Mul", "llvm",
"myfloat")
- register_op(
- create_lower_func(
- {
- 32: "FloatToCustom32",
- }
- ),
- "FloatImm",
- "llvm",
- "myfloat",
- )
- register_op(
- create_lower_func(
- {
- 32: "Custom32Div",
- }
- ),
- "Div",
- "llvm",
- "myfloat",
- )
- register_op(create_lower_func({32: "Custom32Max"}), "Max", "llvm",
"myfloat")
- register_op(
- create_lower_func({32: "Custom32Sqrt"}),
- "Call",
- "llvm",
- "myfloat",
- intrinsic_name="tir.sqrt",
- )
- register_op(
- create_lower_func({32: "Custom32Exp"}), "Call", "llvm", "myfloat",
intrinsic_name="tir.exp"
- )
- register_op(
- create_lower_func({32: "Custom32Log"}), "Call", "llvm", "myfloat",
intrinsic_name="tir.log"
- )
- register_op(
- create_lower_func({32: "Custom32Sigmoid"}),
- "Call",
- "llvm",
- "myfloat",
- intrinsic_name="tir.sigmoid",
- )
- register_op(
- create_lower_func({32: "Custom32Tanh"}),
- "Call",
- "llvm",
- "myfloat",
- intrinsic_name="tir.tanh",
- )
- register_op(lower_ite, "Call", "llvm", "myfloat",
intrinsic_name="tir.if_then_else")
- register_op(
- lower_call_pure_extern, "Call", "llvm", "myfloat",
intrinsic_name="tir.call_pure_extern"
- )
+ register_op(
+ create_lower_func({(32, 32): "FloatToCustom32"}), "Cast", "llvm",
"float", "myfloat"
+ )
+ register_op(
+ create_lower_func({(32, 32): "Custom32ToFloat"}), "Cast", "llvm",
"myfloat", "float"
+ )
+ register_op(create_lower_func({32: "Custom32Add"}), "Add", "llvm",
"myfloat")
+ register_op(
+ create_lower_func(
+ {
+ 32: "Custom32Sub",
+ }
+ ),
+ "Sub",
+ "llvm",
+ "myfloat",
+ )
+ register_op(create_lower_func({32: "Custom32Mul"}), "Mul", "llvm",
"myfloat")
+ register_op(
+ create_lower_func(
+ {
+ 32: "FloatToCustom32",
+ }
+ ),
+ "FloatImm",
+ "llvm",
+ "myfloat",
+ )
+ register_op(
+ create_lower_func(
+ {
+ 32: "Custom32Div",
+ }
+ ),
+ "Div",
+ "llvm",
+ "myfloat",
+ )
+ register_op(create_lower_func({32: "Custom32Max"}), "Max", "llvm",
"myfloat")
+ register_op(
+ create_lower_func({32: "Custom32Sqrt"}),
+ "Call",
+ "llvm",
+ "myfloat",
+ intrinsic_name="tir.sqrt",
+ )
+ register_op(
+ create_lower_func({32: "Custom32Exp"}),
+ "Call",
+ "llvm",
+ "myfloat",
+ intrinsic_name="tir.exp",
+ )
+ register_op(
+ create_lower_func({32: "Custom32Log"}),
+ "Call",
+ "llvm",
+ "myfloat",
+ intrinsic_name="tir.log",
+ )
+ register_op(
+ create_lower_func({32: "Custom32Sigmoid"}),
+ "Call",
+ "llvm",
+ "myfloat",
+ intrinsic_name="tir.sigmoid",
+ )
+ register_op(
+ create_lower_func({32: "Custom32Tanh"}),
+ "Call",
+ "llvm",
+ "myfloat",
+ intrinsic_name="tir.tanh",
+ )
+ register_op(lower_ite, "Call", "llvm", "myfloat",
intrinsic_name="tir.if_then_else")
+ register_op(
+ lower_call_pure_extern, "Call", "llvm", "myfloat",
intrinsic_name="tir.call_pure_extern"
+ )
+
+ register_min_func(create_min_lower_func({32: "MinCustom32"},
"myfloat"), "myfloat")
- register_min_func(create_min_lower_func({32: "MinCustom32"}, "myfloat"),
"myfloat")
+ try:
+ _setup_myfloat_inner()
+ except tvm._ffi.base.TVMError as e:
+ # Ignore this specific error which can happen if another test
+ # that uses "myfloat" has already run.
+ if "float is already registered" not in str(e):
+ raise e
def setup_posites2():
@@ -513,12 +531,8 @@ def run_batchnorm(src_dtype, dst_dtype, rtol=1e-6,
atol=1e-6):
def test_myfloat():
- try:
- setup_myfloat()
- except tvm._ffi.base.TVMError as e:
- if "float is already registered" not in str(e):
- # Ignore this specific error which can happen if this test runs
twice within the same process
- raise e
+ setup_myfloat()
+
run_ops("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
run_conv2d("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
run_batchnorm("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
@@ -529,6 +543,82 @@ def test_myfloat():
# 'custom[myfloat]32')
+class TestMyfloatLowering(tvm.testing.CompareBeforeAfter):
+ setup_myfloat()
+
+ transform = tvm.tir.transform.LowerCustomDatatypes()
+
+ def before(self):
+ dtype = "custom[myfloat]32"
+
+ @T.prim_func
+ def func(A_data: T.handle(dtype)):
+ T.func_attr({"target": T.target("llvm")})
+ A = T.Buffer(16, dtype=dtype, data=A_data)
+ B_data = T.allocate([16], dtype=dtype)
+ B = T.Buffer(16, dtype=dtype, data=B_data)
+ for i in range(16):
+ B[i] = A[i] + 1.0
+
+ return func
+
+ def expected(self):
+ dtype = "custom[myfloat]32"
+
+ @T.prim_func
+ def func(A_data: T.handle(dtype)):
+ T.func_attr({"target": T.target("llvm")})
+ A_uint32 = T.Buffer(16, "uint32", data=A_data)
+ B_data = T.allocate([16], dtype="uint32")
+ B_uint32 = T.Buffer(16, "uint32", data=B_data)
+ for i in range(16):
+ B_uint32[i] = T.call_pure_extern(
+ "uint32",
+ "FloatToCustom32",
+ T.call_pure_extern("float32", "Custom32ToFloat",
A_uint32[i]) + T.float32(1),
+ )
+
+ return func
+
+
+class TestMyfloatLoweringDeclBuffer(tvm.testing.CompareBeforeAfter):
+ """Like TestMyfloatLowering, but using DeclBuffer"""
+
+ setup_myfloat()
+
+ transform = tvm.tir.transform.LowerCustomDatatypes()
+
+ def before(self):
+ dtype = "custom[myfloat]32"
+
+ @T.prim_func
+ def func(A_data: T.handle(dtype)):
+ T.func_attr({"target": T.target("llvm")})
+ A = T.decl_buffer(16, dtype=dtype, data=A_data)
+ B = T.decl_buffer(16, dtype=dtype)
+ for i in range(16):
+ B[i] = A[i] + 1.0
+
+ return func
+
+ def expected(self):
+ dtype = "custom[myfloat]32"
+
+ @T.prim_func
+ def func(A_data: T.handle(dtype)):
+ T.func_attr({"target": T.target("llvm")})
+ A_uint32 = T.decl_buffer(16, "uint32", data=A_data)
+ B_uint32 = T.decl_buffer(16, dtype="uint32")
+ for i in range(16):
+ B_uint32[i] = T.call_pure_extern(
+ "uint32",
+ "FloatToCustom32",
+ T.call_pure_extern("float32", "Custom32ToFloat",
A_uint32[i]) + T.float32(1),
+ )
+
+ return func
+
+
def _has_posit():
return tvm.support.libinfo()["USE_BYODT_POSIT"] == "ON"