This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 711a603  [TensorIR][M1c] Lower and build TensorIR (#8044)
711a603 is described below

commit 711a603db83465aa53fa63c7ee9b564690544d16
Author: Siyuan Feng <[email protected]>
AuthorDate: Sun May 16 04:35:17 2021 +0800

    [TensorIR][M1c] Lower and build TensorIR (#8044)
---
 include/tvm/tir/analysis.h                         |   5 +-
 python/tvm/contrib/nvcc.py                         |   2 +-
 python/tvm/driver/build_module.py                  | 101 +++++++++++++-----
 src/te/schedule/schedule_postproc_to_primfunc.cc   |   4 +-
 src/tir/analysis/buffer_access_lca_detector.cc     |  14 ++-
 src/tir/transforms/flatten_buffer.cc               |   6 +-
 .../plan_update_buffer_allocation_location.cc      |   4 +-
 tests/python/unittest/test_lower_build.py          | 117 +++++++++++++++++++++
 .../test_tir_analysis_detect_buffer_access_lca.py  |  14 +++
 9 files changed, 226 insertions(+), 41 deletions(-)

diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 3f2fdce..262ac68 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -192,9 +192,10 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& 
func,
  *        access(BufferLoad, BufferStore) and low-level access(Load, Store and 
opaque access).
  *        The LCA may be a For loop or a Block.
  * \param func The PrimFunc to be detected.
- * \return The Map from buffer to the LCA of all access to it.
+ * \return The Map from buffer to the LCA of all access to it. The LCA is the
+ *         function root if the returned stmt is NullOpt.
  */
-TVM_DLL Map<Buffer, Stmt> DetectBufferAccessLCA(const PrimFunc& func);
+TVM_DLL Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& 
func);
 
 // Pass variants of verification analysis
 // directly throws RuntimeError when verification fails.
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index 30b5e6d..612be29 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -194,7 +194,7 @@ def find_libdevice_path(arch):
     selected_ver = 0
     selected_path = None
     cuda_ver = get_cuda_version(cuda_path)
-    if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1, 11.2):
+    if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1, 11.2, 11.3):
         path = os.path.join(lib_path, "libdevice.10.bc")
     else:
         for fn in os.listdir(lib_path):
diff --git a/python/tvm/driver/build_module.py 
b/python/tvm/driver/build_module.py
index 4682e34..a3d0bb6 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -18,6 +18,8 @@
 # pylint: disable=invalid-name
 """The build utils in python.
 """
+
+from typing import Union, Optional, List, Mapping
 import warnings
 
 import tvm.tir
@@ -25,11 +27,15 @@ import tvm.tir
 from tvm.runtime import ndarray
 from tvm.ir import container
 from tvm.ir import CallingConv
+from tvm.tir import PrimFunc
+from tvm.ir.module import IRModule
 from tvm.ir.transform import PassContext
 from tvm.target import codegen
 from tvm.te import tensor
 from tvm.te import schedule
 from tvm.target import Target
+from tvm.tir.buffer import Buffer
+from tvm.tir.expr import Var
 
 
 def get_binds(args, compact=False, binds=None):
@@ -119,34 +125,40 @@ def form_irmodule(sch, args, name, binds):
     return tvm.IRModule({name: func})
 
 
-def lower(sch, args, name="main", binds=None, simple_mode=False):
+def lower(
+    inputs: Union[schedule.Schedule, PrimFunc, IRModule],
+    args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
+    name: str = "main",
+    binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
+    simple_mode: bool = False,
+) -> IRModule:
     """Lowering step before build into target.
 
     Parameters
     ----------
-    sch : tvm.te.schedule.Schedule
-        The schedule to be built
+    inputs : Union[schedule.Schedule, PrimFunc, IRModule]
+        The TE schedule or TensorIR PrimFunc/IRModule to be built
 
-    args : list of Buffer or Tensor or Var
-        The argument lists to the function.
+    args : Optional[List[Union[Buffer, tensor.Tensor, Var]]]
+        The argument lists to the function for TE schedule.
+        It should be None if we want to lower TensorIR.
 
-    name : str, optional
+    name : str
         The name of result function.
 
-    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
+    binds : Optional[Mapping[tensor.Tensor, Buffer]]
         Dictionary that maps the Tensor to Buffer which specifies the data layout
         requirement of the function. By default, a new compact buffer is created
         for each tensor in the argument.
 
-    simple_mode : bool, optional
+    simple_mode : bool
         Whether only output simple and compact statement, this will skip
         LoopPartition, api wrapper generation and Unrolling.
 
     Returns
     -------
-    m : IRModule or Stmt
-       The result IRModule, if simple_mode=False
-       Then the Stmt before make api is returned.
+    m : IRModule
+       The result IRModule
     """
     # config setup
     pass_ctx = PassContext.current()
@@ -160,20 +172,46 @@ def lower(sch, args, name="main", binds=None, 
simple_mode=False):
     lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
 
     # Phase 0
-    if isinstance(sch, schedule.Schedule):
-        mod = form_irmodule(sch, args, name, binds)
+    pass_list = lower_phase0
+    is_legacy_te_schedule: bool = False
+
+    if isinstance(inputs, schedule.Schedule):
+        if args is None:
+            raise ValueError("args must be given for lowering from TE 
schedule")
+        mod = form_irmodule(inputs, args, name, binds)
+        is_legacy_te_schedule = True
+    elif isinstance(inputs, PrimFunc):
+        func = inputs.with_attr("global_symbol", name)
+        if pass_ctx.config.get("tir.noalias", True):
+            func = func.with_attr("tir.noalias", True)
+        mod = tvm.IRModule({name: func})
+    elif isinstance(inputs, IRModule):
+        mod = inputs
     else:
-        mod = sch
+        raise TypeError(
+            f"tvm.lower expected te.Schedule, PrimFunc or IRModule, but got 
{type(inputs)}"
+        )
 
-    pass_list = lower_phase0
     # Phase 1
+    if is_legacy_te_schedule:
+        pass_list += [
+            tvm.tir.transform.InjectPrefetch(),
+            tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
+        ]
+    else:
+        pass_list += [
+            tvm.tir.transform.LowerInitBlock(),
+            tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
+            tvm.tir.transform.ConvertBlocksToOpaque(),
+            tvm.tir.transform.CompactBufferAllocation(),
+            tvm.tir.transform.FlattenBuffer(),
+        ]
     pass_list += [
-        tvm.tir.transform.InjectPrefetch(),
-        tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
         tvm.tir.transform.BF16Legalize(),
         tvm.tir.transform.NarrowDataType(32),
         tvm.tir.transform.Simplify(),
     ]
+
     pass_list += lower_phase1
 
     # Phase 2
@@ -297,22 +335,29 @@ def _build_for_device(input_mod, target, target_host):
     return mod_host, rt_mod_dev
 
 
-def build(inputs, args=None, target=None, target_host=None, 
name="default_function", binds=None):
+def build(
+    inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, 
IRModule]],
+    args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
+    target: Optional[Union[str, Target]] = None,
+    target_host: Optional[Union[str, Target]] = None,
+    name: Optional[str] = "default_function",
+    binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
+):
     """Build a function with arguments as signature. Code will be generated
     for devices coupled with target information.
 
     Parameters
     ----------
-    inputs : tvm.te.Schedule, IRModule, or dict of target to IRModule
-        The schedule to be built
+    inputs : Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, 
IRModule]]
+        The input to be built
 
-    args : list of Buffer or Tensor or Var, optional
+    args : Optional[List[Union[Buffer, tensor.Tensor, Var]]]
         The argument lists to the function.
 
-    target : str or :any:`tvm.target.Target`, optional
+    target : Optional[Union[str, Target]]
         The target and option of the compilation.
 
-    target_host : str or :any:`tvm.target.Target` optional
+    target_host : Optional[Union[str, Target]]
         Host compilation target, if target is device.
         When TVM compiles device specific program such as CUDA,
         we also need host(CPU) side code to interact with the driver
@@ -321,10 +366,10 @@ def build(inputs, args=None, target=None, 
target_host=None, name="default_functi
         By default, llvm is used if it is enabled,
         otherwise a stackvm interpreter is used.
 
-    name : str, optional
+    name : Optional[str]
         The name of result function.
 
-    binds : dict, optional
+    binds : Optional[Mapping[tensor.Tensor, Buffer]]
         Dictionary that maps the binding of symbolic buffer to Tensor.
         By default, a new buffer is created for each tensor in the argument.
 
@@ -375,10 +420,10 @@ def build(inputs, args=None, target=None, 
target_host=None, name="default_functi
     elif isinstance(inputs, (list, tuple, container.Array)):
         merged_mod = tvm.IRModule({})
         for x in inputs:
-            merged_mod.update(x)
+            merged_mod.update(lower(x))
         input_mod = merged_mod
-    elif isinstance(inputs, tvm.IRModule):
-        input_mod = inputs
+    elif isinstance(inputs, (tvm.IRModule, PrimFunc)):
+        input_mod = lower(inputs)
     elif not isinstance(inputs, (dict, container.Map)):
         raise ValueError(
             f"Inputs must be Schedule, IRModule or dict of target to IRModule, 
"
diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc 
b/src/te/schedule/schedule_postproc_to_primfunc.cc
index 1710a91..32cc510 100644
--- a/src/te/schedule/schedule_postproc_to_primfunc.cc
+++ b/src/te/schedule/schedule_postproc_to_primfunc.cc
@@ -159,13 +159,13 @@ PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> 
arg_list, Stmt body,
       ICHECK(!extern_buffer.count(tensor));
 
       tir::Buffer buffer = CreateBufferFor(tensor);
-      tir::Var bptr(buffer->name, DataType::Handle());
+      tir::Var bptr(buffer->name, PrimType(DataType::Handle()));
       params.push_back(bptr);
       buffer_map.Set(bptr, buffer);
       extern_buffer[tensor] = buffer;
     } else {
       tir::Buffer buffer = Downcast<tir::Buffer>(var);
-      tir::Var bptr(buffer->name, DataType::Handle());
+      tir::Var bptr(buffer->name, PrimType(DataType::Handle()));
       params.push_back(bptr);
       buffer_map.Set(bptr, buffer);
     }
diff --git a/src/tir/analysis/buffer_access_lca_detector.cc 
b/src/tir/analysis/buffer_access_lca_detector.cc
index 23e60e1..6f2622f 100644
--- a/src/tir/analysis/buffer_access_lca_detector.cc
+++ b/src/tir/analysis/buffer_access_lca_detector.cc
@@ -36,17 +36,20 @@ namespace tir {
  */
 class LCADetector : public StmtExprVisitor {
  public:
-  static Map<Buffer, Stmt> Detect(const PrimFunc& func) {
+  static Map<Buffer, Optional<Stmt>> Detect(const PrimFunc& func) {
     LCADetector detector;
     for (const auto& kv : func->buffer_map) {
       const Buffer& buffer = kv.second;
       detector.buffer_var_map_.emplace(buffer->data.get(), buffer.get());
     }
+
     detector(func->body);
     // Prepare the return
-    Map<Buffer, Stmt> buffer_lca;
+    Map<Buffer, Optional<Stmt>> buffer_lca;
     for (const auto& kv : detector.buffer_lca_) {
-      buffer_lca.Set(GetRef<Buffer>(kv.first), GetRef<Stmt>(kv.second->stmt));
+      const Buffer& buffer = GetRef<Buffer>(kv.first);
+      const Optional<Stmt> stmt = kv.second ? 
GetRef<Optional<Stmt>>(kv.second->stmt) : NullOpt;
+      buffer_lca.Set(buffer, stmt);
     }
     return buffer_lca;
   }
@@ -131,7 +134,6 @@ class LCADetector : public StmtExprVisitor {
   }
 
   static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const 
ScopeInfo* rhs) {
-    ICHECK(lhs || rhs);
     if (lhs == nullptr) return rhs;
     if (rhs == nullptr) return lhs;
     while (lhs->parent_scope_info != nullptr &&  //
@@ -166,7 +168,9 @@ class LCADetector : public StmtExprVisitor {
   support::Arena arena_;
 };
 
-Map<Buffer, Stmt> DetectBufferAccessLCA(const PrimFunc& func) { return 
LCADetector::Detect(func); }
+Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func) {
+  return LCADetector::Detect(func);
+}
 
 
TVM_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca").set_body_typed(DetectBufferAccessLCA);
 }  // namespace tir
diff --git a/src/tir/transforms/flatten_buffer.cc 
b/src/tir/transforms/flatten_buffer.cc
index 82035cb..07f7b42 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -113,7 +113,11 @@ class BufferFlattener : public StmtExprMutator {
     if (it == unit_loop_vars_.end()) {
       return std::move(var);
     } else {
-      return it->second;
+      PrimExpr expr = it->second;
+      if (expr.dtype() != var.dtype()) {
+        expr = Cast(var.dtype(), std::move(expr));
+      }
+      return expr;
     }
   }
 
diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc 
b/src/tir/transforms/plan_update_buffer_allocation_location.cc
index ecedaa6..949c955 100644
--- a/src/tir/transforms/plan_update_buffer_allocation_location.cc
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -32,7 +32,7 @@ namespace tir {
 class BufferAllocationLocator : public StmtExprMutator {
  public:
   explicit BufferAllocationLocator(const PrimFunc& func) {
-    Map<Buffer, Stmt> buffer_lca = DetectBufferAccessLCA(func);
+    Map<Buffer, Optional<Stmt>> buffer_lca = DetectBufferAccessLCA(func);
     std::unordered_set<const BufferNode*> arg_buffers;
     for (const auto& kv : func->buffer_map) {
       const Buffer& buffer = kv.second;
@@ -42,7 +42,7 @@ class BufferAllocationLocator : public StmtExprMutator {
     // create buffers to be allocated at each stmts
     for (const auto& kv : buffer_lca) {
       const Buffer& buffer = kv.first;
-      const StmtNode* stmt = kv.second.get();
+      const StmtNode* stmt = kv.second.defined() ? kv.second.value().get() : 
nullptr;
       if (arg_buffers.count(buffer.get())) {
         continue;
       }
diff --git a/tests/python/unittest/test_lower_build.py 
b/tests/python/unittest/test_lower_build.py
new file mode 100644
index 0000000..21f4132
--- /dev/null
+++ b/tests/python/unittest/test_lower_build.py
@@ -0,0 +1,117 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+
+import tvm
+from tvm import te, tir
+from tvm.ir.module import IRModule
+from tvm.script import ty
+import tvm.testing
+
+
+def _check_module_with_numpy(mod, shape=(128, 128, 128)):
+    m, n, k = shape
+    a = tvm.nd.array(np.random.rand(m, k).astype("float32"))
+    b = tvm.nd.array(np.random.rand(n, k).astype("float32"))
+    c = tvm.nd.array(np.zeros((m, n), dtype="float32"))
+    c_np = np.dot(a.asnumpy(), b.asnumpy().transpose())
+    mod(a, b, c)
+    tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+
+# pylint: disable=no-self-argument, missing-class-docstring, 
missing-function-docstring
[email protected]
+def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, [128, 128])
+    B = tir.match_buffer(b, [128, 128])
+    C = tir.match_buffer(c, [128, 128])
+    for i, j in tir.grid(128, 128):
+        with tir.block([128, 128], "init") as [vi, vj]:
+            C[vi, vj] = tir.float32(0)
+        for k in range(0, 128):
+            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as 
[vi, vj, vk]:
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
[email protected]
+class LoweredModule:
+    def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+        # function attr dict
+        tir.func_attr({"global_symbol": "main", "tir.noalias": True})
+        A = tir.match_buffer(a, [128, 128])
+        B = tir.match_buffer(b, [128, 128])
+        C = tir.match_buffer(c, [128, 128])
+        # body
+        for x, y in tir.grid(128, 128):
+            C.data[x * 128 + y] = 0.0
+            for k in tir.serial(0, 128):
+                C.data[x * 128 + y] = tir.load("float32", C.data, x * 128 + y) 
+ tir.load(
+                    "float32", A.data, x * 128 + k
+                ) * tir.load("float32", B.data, y * 128 + k)
+
+
+def test_lower_build_te_schedule():
+    m, n, k = 128, 128, 128
+    axis_k = te.reduce_axis((0, k), "k")
+    A = te.placeholder((m, k), name="A")
+    B = te.placeholder((k, n), name="B")
+    C = te.compute((m, n), lambda x, y: te.sum(A[x, axis_k] * B[y, axis_k], 
axis=axis_k), name="C")
+    s = te.create_schedule(C.op)
+    # check lowering
+    ir_mod = tvm.lower(s, [A, B, C])
+    tvm.ir.assert_structural_equal(ir_mod, LoweredModule())
+    # check building
+    mod = tvm.build(s, [A, B, C], target="llvm")
+    _check_module_with_numpy(mod)
+
+
+def test_lower_build_tir_func():
+    # check lowering
+    ir_mod = tvm.lower(matmul)
+    tvm.ir.assert_structural_equal(ir_mod, LoweredModule())
+    # check building
+    mod = tvm.build(matmul, target="llvm")
+    _check_module_with_numpy(mod)
+
+
+def test_lower_build_tir_module():
+    func = matmul.with_attr("global_symbol", "main")
+    func = func.with_attr("tir.noalias", True)
+    ir_mod = IRModule({"main": func})
+    # check lowering
+    lowered_mod = tvm.lower(ir_mod)
+    tvm.ir.assert_structural_equal(lowered_mod, LoweredModule())
+    # check building
+    mod = tvm.build(ir_mod, target="llvm")
+    _check_module_with_numpy(mod)
+
+
+def test_lower_build_lowered_module():
+    # check lowering
+    ir_mod = tvm.lower(LoweredModule())
+    tvm.ir.assert_structural_equal(ir_mod, LoweredModule())
+    # check building
+    mod = tvm.build(ir_mod, target="llvm")
+    _check_module_with_numpy(mod)
+
+
+if __name__ == "__main__":
+    test_lower_build_te_schedule()
+    test_lower_build_tir_func()
+    test_lower_build_tir_module()
+    test_lower_build_lowered_module()
diff --git 
a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py 
b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
index 7ac61a7..36fd80f 100644
--- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
+++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
@@ -64,6 +64,12 @@ def buffer_opaque_access(b: ty.handle, c: ty.handle) -> None:
             C[vi, vj] = B[vi, vj]
 
 
[email protected]
+def lca_is_func_root(a: ty.handle) -> None:
+    A = tir.match_buffer(a, [0, 0], "float32")
+    A.data[0] = 1.0
+
+
 def test_buffer_load_store():
     func = buffer_load_store_func
     A, B = [func.buffer_map[x] for x in func.params]
@@ -102,6 +108,14 @@ def test_opaque_access():
     assert lca[C] == root_block.body[1].body.body.block
 
 
+def test_lca_func_root():
+    func = lca_is_func_root
+    (A,) = [func.buffer_map[x] for x in func.params]
+    lca = tir.analysis.detect_buffer_access_lca(func)
+    assert lca[A] is None
+
+
 if __name__ == "__main__":
     test_buffer_load_store()
     test_opaque_access()
+    test_lca_func_root()

Reply via email to