This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 711a603 [TensorIR][M1c] Lower and build TensorIR (#8044)
711a603 is described below
commit 711a603db83465aa53fa63c7ee9b564690544d16
Author: Siyuan Feng <[email protected]>
AuthorDate: Sun May 16 04:35:17 2021 +0800
[TensorIR][M1c] Lower and build TensorIR (#8044)
---
include/tvm/tir/analysis.h | 5 +-
python/tvm/contrib/nvcc.py | 2 +-
python/tvm/driver/build_module.py | 101 +++++++++++++-----
src/te/schedule/schedule_postproc_to_primfunc.cc | 4 +-
src/tir/analysis/buffer_access_lca_detector.cc | 14 ++-
src/tir/transforms/flatten_buffer.cc | 6 +-
.../plan_update_buffer_allocation_location.cc | 4 +-
tests/python/unittest/test_lower_build.py | 117 +++++++++++++++++++++
.../test_tir_analysis_detect_buffer_access_lca.py | 14 +++
9 files changed, 226 insertions(+), 41 deletions(-)
diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 3f2fdce..262ac68 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -192,9 +192,10 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
* access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access).
* The LCA may be a For loop or a Block.
* \param func The PrimFunc to be detected.
- * \return The Map from buffer to the LCA of all access to it.
+ * \return The Map from buffer to the LCA of all access to it. The LCA is the function root if the
+ * returned stmt is NullOpt.
*/
-TVM_DLL Map<Buffer, Stmt> DetectBufferAccessLCA(const PrimFunc& func);
+TVM_DLL Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func);
// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index 30b5e6d..612be29 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -194,7 +194,7 @@ def find_libdevice_path(arch):
selected_ver = 0
selected_path = None
cuda_ver = get_cuda_version(cuda_path)
- if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1, 11.2):
+ if cuda_ver in (9.0, 9.1, 10.0, 10.1, 10.2, 11.0, 11.1, 11.2, 11.3):
path = os.path.join(lib_path, "libdevice.10.bc")
else:
for fn in os.listdir(lib_path):
diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index 4682e34..a3d0bb6 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -18,6 +18,8 @@
# pylint: disable=invalid-name
"""The build utils in python.
"""
+
+from typing import Union, Optional, List, Mapping
import warnings
import tvm.tir
@@ -25,11 +27,15 @@ import tvm.tir
from tvm.runtime import ndarray
from tvm.ir import container
from tvm.ir import CallingConv
+from tvm.tir import PrimFunc
+from tvm.ir.module import IRModule
from tvm.ir.transform import PassContext
from tvm.target import codegen
from tvm.te import tensor
from tvm.te import schedule
from tvm.target import Target
+from tvm.tir.buffer import Buffer
+from tvm.tir.expr import Var
def get_binds(args, compact=False, binds=None):
@@ -119,34 +125,40 @@ def form_irmodule(sch, args, name, binds):
return tvm.IRModule({name: func})
-def lower(sch, args, name="main", binds=None, simple_mode=False):
+def lower(
+ inputs: Union[schedule.Schedule, PrimFunc, IRModule],
+ args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
+ name: str = "main",
+ binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
+ simple_mode: bool = False,
+) -> IRModule:
"""Lowering step before build into target.
Parameters
----------
- sch : tvm.te.schedule.Schedule
- The schedule to be built
+ inputs : Union[schedule.Schedule, PrimFunc, IRModule]
+ The TE schedule or TensorIR PrimFunc/IRModule to be built
- args : list of Buffer or Tensor or Var
- The argument lists to the function.
+ args : Optional[List[Union[Buffer, tensor.Tensor, Var]]]
+ The argument lists to the function for TE schedule.
+ It should be None if we want to lower TensorIR.
- name : str, optional
+ name : str
The name of result function.
- binds : dict of :any:`Tensor` to :any:`Buffer`, optional
+ binds : Optional[Mapping[tensor.Tensor, Buffer]]
Dictionary that maps the Tensor to Buffer which specified the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
- simple_mode : bool, optional
+ simple_mode : bool
Whether only output simple and compact statement, this will skip
LoopPartition, api wrapper generation and Unrolling.
Returns
-------
- m : IRModule or Stmt
- The result IRModule, if simple_mode=False
- Then the Stmt before make api is returned.
+ m : IRModule
+ The result IRModule
"""
# config setup
pass_ctx = PassContext.current()
@@ -160,20 +172,46 @@ def lower(sch, args, name="main", binds=None, simple_mode=False):
lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
# Phase 0
- if isinstance(sch, schedule.Schedule):
- mod = form_irmodule(sch, args, name, binds)
+ pass_list = lower_phase0
+ is_legacy_te_schedule: bool = False
+
+ if isinstance(inputs, schedule.Schedule):
+ if args is None:
+ raise ValueError("args must be given for lowering from TE schedule")
+ mod = form_irmodule(inputs, args, name, binds)
+ is_legacy_te_schedule = True
+ elif isinstance(inputs, PrimFunc):
+ func = inputs.with_attr("global_symbol", name)
+ if pass_ctx.config.get("tir.noalias", True):
+ func = func.with_attr("tir.noalias", True)
+ mod = tvm.IRModule({name: func})
+ elif isinstance(inputs, IRModule):
+ mod = inputs
else:
- mod = sch
+ raise TypeError(
+ f"tvm.lower expected te.Schedule, PrimFunc or IRModule, but got {type(inputs)}"
+ )
- pass_list = lower_phase0
# Phase 1
+ if is_legacy_te_schedule:
+ pass_list += [
+ tvm.tir.transform.InjectPrefetch(),
+ tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
+ ]
+ else:
+ pass_list += [
+ tvm.tir.transform.LowerInitBlock(),
+ tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
+ tvm.tir.transform.ConvertBlocksToOpaque(),
+ tvm.tir.transform.CompactBufferAllocation(),
+ tvm.tir.transform.FlattenBuffer(),
+ ]
pass_list += [
- tvm.tir.transform.InjectPrefetch(),
- tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
tvm.tir.transform.BF16Legalize(),
tvm.tir.transform.NarrowDataType(32),
tvm.tir.transform.Simplify(),
]
+
pass_list += lower_phase1
# Phase 2
@@ -297,22 +335,29 @@ def _build_for_device(input_mod, target, target_host):
return mod_host, rt_mod_dev
-def build(inputs, args=None, target=None, target_host=None, name="default_function", binds=None):
+def build(
+ inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]],
+ args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
+ target: Optional[Union[str, Target]] = None,
+ target_host: Optional[Union[str, Target]] = None,
+ name: Optional[str] = "default_function",
+ binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
+):
"""Build a function with arguments as signature. Code will be generated
for devices coupled with target information.
Parameters
----------
- inputs : tvm.te.Schedule, IRModule, or dict of target to IRModule
- The schedule to be built
+ inputs : Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]]
+ The input to be built
- args : list of Buffer or Tensor or Var, optional
+ args : Optional[List[Union[Buffer, tensor.Tensor, Var]]]
The argument lists to the function.
- target : str or :any:`tvm.target.Target`, optional
+ target : Optional[Union[str, Target]]
The target and option of the compilation.
- target_host : str or :any:`tvm.target.Target` optional
+ target_host : Optional[Union[str, Target]]
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
@@ -321,10 +366,10 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
- name : str, optional
+ name : Optional[str]
The name of result function.
- binds : dict, optional
+ binds : Optional[Mapping[tensor.Tensor, Buffer]]
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
@@ -375,10 +420,10 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi
elif isinstance(inputs, (list, tuple, container.Array)):
merged_mod = tvm.IRModule({})
for x in inputs:
- merged_mod.update(x)
+ merged_mod.update(lower(x))
input_mod = merged_mod
- elif isinstance(inputs, tvm.IRModule):
- input_mod = inputs
+ elif isinstance(inputs, (tvm.IRModule, PrimFunc)):
+ input_mod = lower(inputs)
elif not isinstance(inputs, (dict, container.Map)):
raise ValueError(
f"Inputs must be Schedule, IRModule or dict of target to IRModule, "
diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc
index 1710a91..32cc510 100644
--- a/src/te/schedule/schedule_postproc_to_primfunc.cc
+++ b/src/te/schedule/schedule_postproc_to_primfunc.cc
@@ -159,13 +159,13 @@ PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list, Stmt body,
ICHECK(!extern_buffer.count(tensor));
tir::Buffer buffer = CreateBufferFor(tensor);
- tir::Var bptr(buffer->name, DataType::Handle());
+ tir::Var bptr(buffer->name, PrimType(DataType::Handle()));
params.push_back(bptr);
buffer_map.Set(bptr, buffer);
extern_buffer[tensor] = buffer;
} else {
tir::Buffer buffer = Downcast<tir::Buffer>(var);
- tir::Var bptr(buffer->name, DataType::Handle());
+ tir::Var bptr(buffer->name, PrimType(DataType::Handle()));
params.push_back(bptr);
buffer_map.Set(bptr, buffer);
}
diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc
index 23e60e1..6f2622f 100644
--- a/src/tir/analysis/buffer_access_lca_detector.cc
+++ b/src/tir/analysis/buffer_access_lca_detector.cc
@@ -36,17 +36,20 @@ namespace tir {
*/
class LCADetector : public StmtExprVisitor {
public:
- static Map<Buffer, Stmt> Detect(const PrimFunc& func) {
+ static Map<Buffer, Optional<Stmt>> Detect(const PrimFunc& func) {
LCADetector detector;
for (const auto& kv : func->buffer_map) {
const Buffer& buffer = kv.second;
detector.buffer_var_map_.emplace(buffer->data.get(), buffer.get());
}
+
detector(func->body);
// Prepare the return
- Map<Buffer, Stmt> buffer_lca;
+ Map<Buffer, Optional<Stmt>> buffer_lca;
for (const auto& kv : detector.buffer_lca_) {
- buffer_lca.Set(GetRef<Buffer>(kv.first), GetRef<Stmt>(kv.second->stmt));
+ const Buffer& buffer = GetRef<Buffer>(kv.first);
+ const Optional<Stmt> stmt = kv.second ? GetRef<Optional<Stmt>>(kv.second->stmt) : NullOpt;
+ buffer_lca.Set(buffer, stmt);
}
return buffer_lca;
}
@@ -131,7 +134,6 @@ class LCADetector : public StmtExprVisitor {
}
static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) {
- ICHECK(lhs || rhs);
if (lhs == nullptr) return rhs;
if (rhs == nullptr) return lhs;
while (lhs->parent_scope_info != nullptr && //
@@ -166,7 +168,9 @@ class LCADetector : public StmtExprVisitor {
support::Arena arena_;
};
-Map<Buffer, Stmt> DetectBufferAccessLCA(const PrimFunc& func) { return LCADetector::Detect(func); }
+Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func) {
+ return LCADetector::Detect(func);
+}
TVM_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca").set_body_typed(DetectBufferAccessLCA);
} // namespace tir
diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc
index 82035cb..07f7b42 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -113,7 +113,11 @@ class BufferFlattener : public StmtExprMutator {
if (it == unit_loop_vars_.end()) {
return std::move(var);
} else {
- return it->second;
+ PrimExpr expr = it->second;
+ if (expr.dtype() != var.dtype()) {
+ expr = Cast(var.dtype(), std::move(expr));
+ }
+ return expr;
}
}
diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc
index ecedaa6..949c955 100644
--- a/src/tir/transforms/plan_update_buffer_allocation_location.cc
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -32,7 +32,7 @@ namespace tir {
class BufferAllocationLocator : public StmtExprMutator {
public:
explicit BufferAllocationLocator(const PrimFunc& func) {
- Map<Buffer, Stmt> buffer_lca = DetectBufferAccessLCA(func);
+ Map<Buffer, Optional<Stmt>> buffer_lca = DetectBufferAccessLCA(func);
std::unordered_set<const BufferNode*> arg_buffers;
for (const auto& kv : func->buffer_map) {
const Buffer& buffer = kv.second;
@@ -42,7 +42,7 @@ class BufferAllocationLocator : public StmtExprMutator {
// create buffers to be allocated at each stmts
for (const auto& kv : buffer_lca) {
const Buffer& buffer = kv.first;
- const StmtNode* stmt = kv.second.get();
+ const StmtNode* stmt = kv.second.defined() ? kv.second.value().get() : nullptr;
if (arg_buffers.count(buffer.get())) {
continue;
}
diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py
new file mode 100644
index 0000000..21f4132
--- /dev/null
+++ b/tests/python/unittest/test_lower_build.py
@@ -0,0 +1,117 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+
+import tvm
+from tvm import te, tir
+from tvm.ir.module import IRModule
+from tvm.script import ty
+import tvm.testing
+
+
+def _check_module_with_numpy(mod, shape=(128, 128, 128)):
+ m, n, k = shape
+ a = tvm.nd.array(np.random.rand(m, k).astype("float32"))
+ b = tvm.nd.array(np.random.rand(n, k).astype("float32"))
+ c = tvm.nd.array(np.zeros((m, n), dtype="float32"))
+ c_np = np.dot(a.asnumpy(), b.asnumpy().transpose())
+ mod(a, b, c)
+ tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+
+# pylint: disable=no-self-argument, missing-class-docstring, missing-function-docstring
+@tvm.script.tir
+def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, [128, 128])
+ B = tir.match_buffer(b, [128, 128])
+ C = tir.match_buffer(c, [128, 128])
+ for i, j in tir.grid(128, 128):
+ with tir.block([128, 128], "init") as [vi, vj]:
+ C[vi, vj] = tir.float32(0)
+ for k in range(0, 128):
+ with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+class LoweredModule:
+ def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+ # function attr dict
+ tir.func_attr({"global_symbol": "main", "tir.noalias": True})
+ A = tir.match_buffer(a, [128, 128])
+ B = tir.match_buffer(b, [128, 128])
+ C = tir.match_buffer(c, [128, 128])
+ # body
+ for x, y in tir.grid(128, 128):
+ C.data[x * 128 + y] = 0.0
+ for k in tir.serial(0, 128):
+ C.data[x * 128 + y] = tir.load("float32", C.data, x * 128 + y)
+ tir.load(
+ "float32", A.data, x * 128 + k
+ ) * tir.load("float32", B.data, y * 128 + k)
+
+
+def test_lower_build_te_schedule():
+ m, n, k = 128, 128, 128
+ axis_k = te.reduce_axis((0, k), "k")
+ A = te.placeholder((m, k), name="A")
+ B = te.placeholder((k, n), name="B")
+ C = te.compute((m, n), lambda x, y: te.sum(A[x, axis_k] * B[y, axis_k], axis=axis_k), name="C")
+ s = te.create_schedule(C.op)
+ # check lowering
+ ir_mod = tvm.lower(s, [A, B, C])
+ tvm.ir.assert_structural_equal(ir_mod, LoweredModule())
+ # check building
+ mod = tvm.build(s, [A, B, C], target="llvm")
+ _check_module_with_numpy(mod)
+
+
+def test_lower_build_tir_func():
+ # check lowering
+ ir_mod = tvm.lower(matmul)
+ tvm.ir.assert_structural_equal(ir_mod, LoweredModule())
+ # check building
+ mod = tvm.build(matmul, target="llvm")
+ _check_module_with_numpy(mod)
+
+
+def test_lower_build_tir_module():
+ func = matmul.with_attr("global_symbol", "main")
+ func = func.with_attr("tir.noalias", True)
+ ir_mod = IRModule({"main": func})
+ # check lowering
+ lowered_mod = tvm.lower(ir_mod)
+ tvm.ir.assert_structural_equal(lowered_mod, LoweredModule())
+ # check building
+ mod = tvm.build(ir_mod, target="llvm")
+ _check_module_with_numpy(mod)
+
+
+def test_lower_build_lowered_module():
+ # check lowering
+ ir_mod = tvm.lower(LoweredModule())
+ tvm.ir.assert_structural_equal(ir_mod, LoweredModule())
+ # check building
+ mod = tvm.build(ir_mod, target="llvm")
+ _check_module_with_numpy(mod)
+
+
+if __name__ == "__main__":
+ test_lower_build_te_schedule()
+ test_lower_build_tir_func()
+ test_lower_build_tir_module()
+ test_lower_build_lowered_module()
diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
index 7ac61a7..36fd80f 100644
--- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
+++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
@@ -64,6 +64,12 @@ def buffer_opaque_access(b: ty.handle, c: ty.handle) -> None:
C[vi, vj] = B[vi, vj]
+@tvm.script.tir
+def lca_is_func_root(a: ty.handle) -> None:
+ A = tir.match_buffer(a, [0, 0], "float32")
+ A.data[0] = 1.0
+
+
def test_buffer_load_store():
func = buffer_load_store_func
A, B = [func.buffer_map[x] for x in func.params]
@@ -102,6 +108,14 @@ def test_opaque_access():
assert lca[C] == root_block.body[1].body.body.block
+def test_lca_func_root():
+ func = lca_is_func_root
+ (A,) = [func.buffer_map[x] for x in func.params]
+ lca = tir.analysis.detect_buffer_access_lca(func)
+ assert lca[A] is None
+
+
if __name__ == "__main__":
test_buffer_load_store()
test_opaque_access()
+ test_lca_func_root()