This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 29e958d [TIR][TVMScript] specialize (#8354)
29e958d is described below
commit 29e958d179465ad04ea1e33c333e1b1f043f8683
Author: Siyuan Feng <[email protected]>
AuthorDate: Fri Jul 2 02:46:49 2021 +0800
[TIR][TVMScript] specialize (#8354)
---
include/tvm/tir/analysis.h | 2 +-
include/tvm/tir/buffer.h | 1 +
include/tvm/tir/function.h | 38 +++
python/tvm/tir/function.py | 55 ++++-
src/tir/ir/specialize.cc | 337 +++++++++++++++++++++++++++
tests/python/unittest/test_tir_specialize.py | 199 ++++++++++++++++
6 files changed, 630 insertions(+), 2 deletions(-)
diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 262ac68..63d6fa3 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -96,7 +96,7 @@ TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
/*!
- * \brief Whether e expression used any var in variable set..
+ * \brief Whether e expression used any var in variable set.
* \param expr The expression to be checked.
* \param vset_contains The check function to see if var is in the vset.
* \return Whether e uses vset.
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index a01d69b..017f4f7 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -183,6 +183,7 @@ class Buffer : public ObjectRef {
TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;
TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode);
};
/*!
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 97ee7f7..25ed2f9 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -188,6 +188,44 @@ class LinkedParam : public ObjectRef {
};
/*!
+ * \brief Specialize parameters of PrimFunc.
+ * \param func The PrimFunc to be specialized.
+ * \param param_map The mapping from function params to the instance.
+ * \return The new function with parameter specialized.
+ * \note We can define a Meta TIR function with symbolic shape:
+ *
+ * \code
+ * @tvm.script.tir
+ * def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
+ * A = tir.match_buffer(a, (m, n), "float32")
+ * B = tir.match_buffer(b, (m, n), "float32")
+ *
+ * with tir.block([m, n], "") as [vi, vj]:
+ * B[vi, vj] = A[vi, vj]
+ * \endcode
+ *
+ * Then we can make it specialized with given shapes or buffers.
+ *
+ * \code
+ * a, _, m, n = mem_copy.params
+ * func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
+ * # or
+ * func = mem_copy.specialize({n: 16, m: 16})
+ * \endcode
+ *
+ * \code
+ * @tvm.script.tir
+ * def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None:
+ * A = tir.match_buffer(a, (16, 16), "float32")
+ * B = tir.match_buffer(b, (16, 16), "float32")
+ *
+ * with tir.block([16, 16], "") as [vi, vj]:
+ * B[vi, vj] = A[vi, vj]
+ * \endcode
+ */
+PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);
+
+/*!
* \brief PrimFunc specific attribute names.
*
* \sa tvm::attr
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index 79d18d8..b1081d4 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -16,12 +16,14 @@
# under the License.
"""Function data types."""
+from typing import Mapping, Union
+
import tvm._ffi
import tvm.runtime
from tvm.runtime import Object
from tvm.ir import BaseFunc
from .buffer import Buffer
-from .expr import Var
+from .expr import Var, PrimExpr
from . import _ffi_api
@@ -85,3 +87,54 @@ class PrimFunc(BaseFunc):
The created new function.
"""
return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map,
self.attrs, span)
+
+ def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]):
+ """Specialize parameters of PrimFunc
+
+ Parameters
+ ----------
+
+ param_map : Mapping[Var, Union[PrimExpr, Buffer]]
+ The mapping from function params to the instance
+
+ Examples
+ --------
+ We can define a Meta TIR function with symbolic shape:
+
+ .. code-block:: python
+
+ @tvm.script.tir
+ def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32)
-> None:
+ A = tir.match_buffer(a, (m, n), "float32")
+ B = tir.match_buffer(b, (m, n), "float32")
+
+ with tir.block([m, n], "") as [vi, vj]:
+ B[vi, vj] = A[vi, vj]
+
+ Then we can make it specialized with given shapes or buffers.
+
+ .. code-block:: python
+
+ a, _, m, n = mem_copy.params
+ func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
+ # or
+ func = mem_copy.specialize({n: 16, m: 16})
+
+ The specialized function:
+
+ .. code-block:: python
+
+ @tvm.script.tir
+ def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None:
+ A = tir.match_buffer(a, (16, 16), "float32")
+ B = tir.match_buffer(b, (16, 16), "float32")
+
+ with tir.block([16, 16], "") as [vi, vj]:
+ B[vi, vj] = A[vi, vj]
+
+ Returns
+ -------
+ func : PrimFunc
+ The new function with parameter specialized
+ """
+ return _ffi_api.Specialize(self, param_map)
diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc
new file mode 100644
index 0000000..aa5f271
--- /dev/null
+++ b/src/tir/ir/specialize.cc
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash,
ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function
parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+ return std::any_of(func->params.begin(), func->params.end(),
+ [&](const Var& var) { return var.same_as(param); });
+}
+
+/**************** Specializer ****************/
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+ explicit PrimFuncSpecializer(const VarMap& var_map) : var_map_(var_map) {}
+
+ static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+ PrimFuncSpecializer specializer(var_map);
+ // Updating Buffer map
+ Map<Var, Buffer> buffer_map;
+ bool buffer_map_updated = false;
+ for (const auto& it : f->buffer_map) {
+ const Var& var = it.first;
+ const Buffer& buffer = it.second;
+ Buffer new_buffer = specializer.MutateBuffer(buffer);
+ buffer_map.Set(var, new_buffer);
+ if (!new_buffer.same_as(buffer)) {
+ buffer_map_updated = true;
+ specializer.buffer_map_[buffer] = new_buffer;
+ }
+ }
+
+    // Updating parameters
+ Array<Var> params;
+ bool param_updated = false;
+ for (const auto& var : f->params) {
+      // Remove parameters which have been specialized.
+ if (var_map.find(var) == var_map.end()) {
+ params.push_back(var);
+ } else {
+ param_updated = true;
+ }
+ }
+
+ // Updating function body
+ Stmt body = specializer(f->body);
+
+ if (param_updated || buffer_map_updated || !f->body.same_as(body)) {
+ PrimFuncNode* f_ptr = f.CopyOnWrite();
+ f_ptr->params = std::move(params);
+ f_ptr->buffer_map = std::move(buffer_map);
+ f_ptr->body = std::move(body);
+ }
+ return f;
+ }
+
+ private:
+ Stmt VisitStmt_(const BlockNode* op) final {
+    // Step.0. Define buffer mappings for buffers which are allocated inside the block
+ Array<Buffer> alloc_buffers = MutateArray(
+ op->alloc_buffers,
+ std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this,
std::placeholders::_1));
+
+ // Step.1. Recursively visit block body
+ Stmt stmt = StmtExprMutator::VisitStmt_(op);
+ op = stmt.as<BlockNode>();
+ ICHECK(op != nullptr);
+
+ Array<BufferRegion> reads = MutateArray(
+ op->reads,
+ std::bind(&PrimFuncSpecializer::MutateBufferRegion, this,
std::placeholders::_1));
+ Array<BufferRegion> writes = MutateArray(
+ op->writes,
+ std::bind(&PrimFuncSpecializer::MutateBufferRegion, this,
std::placeholders::_1));
+
+ if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads)) {
+ return GetRef<Block>(op);
+ } else {
+ ObjectPtr<BlockNode> n = CopyOnWrite(op);
+ n->alloc_buffers = std::move(alloc_buffers);
+ n->reads = std::move(reads);
+ n->writes = std::move(writes);
+ return Stmt(n);
+ }
+ }
+
+ Stmt VisitStmt_(const BufferStoreNode* op) final {
+ Stmt stmt = StmtExprMutator::VisitStmt_(op);
+ op = stmt.as<BufferStoreNode>();
+ ICHECK(op != nullptr);
+ auto it = buffer_map_.find(op->buffer);
+ if (it == buffer_map_.end()) {
+ return GetRef<BufferStore>(op);
+ } else {
+ auto n = CopyOnWrite(op);
+ n->buffer = it->second;
+ return Stmt(n);
+ }
+ }
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+ op = expr.as<BufferLoadNode>();
+ ICHECK(op != nullptr);
+ auto it = buffer_map_.find(op->buffer);
+ if (it == buffer_map_.end()) {
+ return GetRef<BufferLoad>(op);
+ } else {
+ auto n = make_object<BufferLoadNode>(*op);
+ n->buffer = it->second;
+ return PrimExpr(n);
+ }
+ }
+
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ auto it = var_map_.find(GetRef<Var>(op));
+ if (it == var_map_.end()) {
+ return GetRef<PrimExpr>(op);
+ } else {
+ return it->second;
+ }
+ }
+
+ private:
+ Buffer MutateBuffer(const Buffer& buffer) const {
+ Array<PrimExpr> shape =
+ MutateArray(buffer->shape, [this](const PrimExpr& e) { return
Substitute(e, var_map_); });
+ Array<PrimExpr> strides =
+ MutateArray(buffer->strides, [this](const PrimExpr& e) { return
Substitute(e, var_map_); });
+
+ PrimExpr elem_offset = Substitute(buffer->elem_offset, var_map_);
+
+ if (buffer->elem_offset.same_as(elem_offset) &&
buffer->shape.same_as(shape) &&
+ buffer->strides.same_as(strides)) {
+ return buffer;
+ } else {
+ auto n = make_object<BufferNode>(*buffer.get());
+ n->elem_offset = std::move(elem_offset);
+ n->shape = std::move(shape);
+ n->strides = std::move(strides);
+ return Buffer(n);
+ }
+ }
+
+ Range MutateRange(const Range& range) {
+ PrimExpr min = this->VisitExpr(range->min);
+ PrimExpr extent = this->VisitExpr(range->extent);
+ if (min.same_as(range->min) && extent.same_as(range->extent)) {
+ return range;
+ } else {
+ return Range::FromMinExtent(std::move(min), std::move(extent));
+ }
+ }
+
+ Buffer MutateAllocBuffer(const Buffer& alloc_buf) {
+ Buffer buf = MutateBuffer(alloc_buf);
+ if (buf.same_as(alloc_buf)) {
+ return alloc_buf;
+ } else {
+ ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end());
+ buffer_map_[alloc_buf] = buf;
+ return buf;
+ }
+ }
+
+ BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) {
+ auto it = buffer_map_.find(buffer_region->buffer);
+ Array<Range> region =
+ MutateArray(buffer_region->region,
+ std::bind(&PrimFuncSpecializer::MutateRange, this,
std::placeholders::_1));
+ if (it == buffer_map_.end() && region.same_as(buffer_region->region)) {
+ return buffer_region;
+ } else {
+ return BufferRegion(it->second, std::move(region));
+ }
+ }
+
+ private:
+ /*! \brief The vars to be substitute and their values */
+ const VarMap& var_map_;
+ /*! \brief map from old buffer to mutated buffer */
+ std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>
buffer_map_;
+};
+
+/*!
+ * \brief Update Specialize var map with buffer matching.
+ * \param func The function to be specialized.
+ * \param param The given function parameter
+ * \param specific_buf The matching buffer.
+ * \param var_map The var mapping to be updated.
+ * \note This function will match target buffer's shape, strides and
elem_offset
+ * For example, we define a buffer in PrimFunc:
+ * A = tir.match_buffer(a, [m, n])
+ *
+ * Then we match it with a buffer B = tir.decl_buffer((8, 16))
+ *
+ * It means we have two var mappings here: m = 8 and n = 16
+ *
+ * If the buffer signature is not a Var, the mapping will fail.
+ * e.g. A = tir.match_buffer(a, [m * 2, n + 1])
+ */
+void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const
Buffer& specific_buf,
+ VarMap* var_map) {
+ // preliminaries
+ tir::ExprDeepEqual equal;
+
+ auto it = func->buffer_map.find(param);
+ CHECK(it != func->buffer_map.end())
+ << "ValueError: specialize expects param to be in PrimFunc's buffer_map";
+ const Buffer& buf_to_specialize = (*it).second;
+
+ // build var mapping using specific_buf's parameters
+ auto build_var_mapping = [&](const PrimExpr& new_expr, const PrimExpr&
old_expr) {
+ if (!equal(new_expr, old_expr)) {
+ CHECK(old_expr->IsInstance<VarNode>())
+          << "TypeError: The signature of target buffer expected an
independent Var, but got "
+ << old_expr << ".";
+ const Var& var = Downcast<Var>(old_expr);
+ auto it = var_map->find(var);
+ if (it != var_map->end()) {
+ CHECK(equal(it->second, new_expr))
+ << "ValueError: The assigned value of var " << var << "
mismatched. " << it->second
+ << " vs. " << new_expr << ".";
+ } else {
+ (*var_map)[var] = new_expr;
+ }
+ }
+ };
+
+ // Check buffer dimensions
+ CHECK(specific_buf->shape.size() == buf_to_specialize->shape.size())
+ << "ValueError: The buffer dimensions mismatched" <<
buf_to_specialize->shape.size()
+ << " vs. " << specific_buf->shape.size() << ".";
+
+ CHECK(specific_buf->strides.size() == buf_to_specialize->strides.size())
+ << "ValueError: The buffer strides dimensions mismatched" <<
buf_to_specialize->strides.size()
+ << " vs. " << specific_buf->strides.size() << ".";
+
+ // Updating var mapping using specific_expr
+ for (size_t i = 0; i < specific_buf->shape.size(); ++i) {
+ build_var_mapping(specific_buf->shape[i], buf_to_specialize->shape[i]);
+ }
+ for (size_t i = 0; i < specific_buf->strides.size(); ++i) {
+ build_var_mapping(specific_buf->strides[i], buf_to_specialize->strides[i]);
+ }
+ build_var_mapping(specific_buf->elem_offset, buf_to_specialize->elem_offset);
+
+ // Check data_alignment and offset_factor.
+ // These two signatures are int, so we do not need map them.
+ CHECK_EQ(specific_buf->data_alignment, buf_to_specialize->data_alignment)
+ << "ValueError: The buffer data_alignment mismatched" <<
buf_to_specialize->data_alignment
+ << " vs. " << specific_buf->data_alignment << ".";
+
+ CHECK_EQ(specific_buf->offset_factor, buf_to_specialize->offset_factor)
+ << "ValueError: The buffer offset_factor mismatched" <<
buf_to_specialize->offset_factor
+ << " vs. " << specific_buf->offset_factor << ".";
+}
+
+/*!
+ * \brief Update Specialize var map with parameter value.
+ * \param func The function to be specialized.
+ * \param param The given function parameter
+ * \param specific_expr The parameter value.
+ * \param var_map The var mapping to be updated.
+ */
+void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const
PrimExpr& specific_expr,
+ VarMap* var_map) {
+ // check param is in PrimFunc's parameters
+ CHECK(IsParam(func, param)) << "ValueError: Specialize expects param to be
in PrimFunc's params";
+ // specialize a param not in buffer_map
+ CHECK_EQ(func->buffer_map.count(param), 0)
+ << "ValueError: Specialize expects param to not be in PrimFunc's
buffer_map";
+ // build var mapping using specific_expr
+ (*var_map)[param] = specific_expr;
+}
+
+/**************** Implementation ****************/
+
+PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map) {
+ VarMap var_map;
+ for (const auto& kv : param_map) {
+ const Var& param = kv.first;
+ const ObjectRef& instance = kv.second;
+ if (instance->IsInstance<BufferNode>()) {
+ UpdateSpecializeVarMap(func, param, Downcast<Buffer>(instance),
&var_map);
+ } else if (instance->IsInstance<PrimExprNode>()) {
+ UpdateSpecializeVarMap(func, param, Downcast<PrimExpr>(instance),
&var_map);
+ } else {
+ LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or
PrimExpr, but got "
+ << instance->GetTypeKey();
+ }
+ }
+ return PrimFuncSpecializer::Specialize(func, std::move(var_map));
+}
+
+/**************** FFI ****************/
+
+TVM_REGISTER_GLOBAL("tir.Specialize").set_body_typed(Specialize);
+
+} // namespace tir
+} // namespace tvm
diff --git a/tests/python/unittest/test_tir_specialize.py
b/tests/python/unittest/test_tir_specialize.py
new file mode 100644
index 0000000..2e9f111
--- /dev/null
+++ b/tests/python/unittest/test_tir_specialize.py
@@ -0,0 +1,199 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring, missing-module-docstring
+
+import tvm
+from tvm import tir
+from tvm.script import ty
+
+
[email protected]
+def matmul(a: ty.handle, b: ty.handle, c: ty.handle, n: ty.int32) -> None:
+ m = tir.var("int32")
+ A = tir.match_buffer(a, [m, n])
+ B = tir.match_buffer(b, [m, n])
+ C = tir.match_buffer(c, [m, m])
+
+ with tir.block([m, m, tir.reduce_axis(0, n)], "update") as [vi, vj, vk]:
+ with tir.init():
+ C[vi, vj] = 0.0
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
[email protected]
+def matmul_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, [128, 128])
+ B = tir.match_buffer(b, [128, 128])
+ C = tir.match_buffer(c, [128, 128])
+
+ with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj,
vk]:
+ with tir.init():
+ C[vi, vj] = 0.0
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
[email protected]
+def matmul_m_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+ m = tir.var("int32")
+ A = tir.match_buffer(a, [m, 128])
+ B = tir.match_buffer(b, [m, 128])
+ C = tir.match_buffer(c, [m, m])
+
+ with tir.block([m, m, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+ with tir.init():
+ C[vi, vj] = 0.0
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
[email protected]
+def matmul_m_8x(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+ x = tir.var("int32")
+ m = tir.var("int32")
+ A = tir.match_buffer(a, [m, x * 8])
+ B = tir.match_buffer(b, [m, x * 8])
+ C = tir.match_buffer(c, [m, m])
+
+ with tir.block([m, m, tir.reduce_axis(0, x * 8)], "update") as [vi, vj,
vk]:
+ with tir.init():
+ C[vi, vj] = 0.0
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
[email protected]
+def element_wise(a: ty.handle, c: ty.handle) -> None:
+ m = tir.var("int32")
+ n = tir.var("int32")
+ A = tir.match_buffer(a, (m, n), "float32")
+ C = tir.match_buffer(c, (m, n), "float32")
+
+ B = tir.alloc_buffer((m, n), "float32")
+
+ with tir.block([m, n], "B") as [vi, vj]:
+ B[vi, vj] = A[vi, vj] * 2.0
+
+ with tir.block([m, n], "C") as [vi, vj]:
+ C[vi, vj] = B[vi, vj] + 1.0
+
+
[email protected]
+def element_wise_128_64(a: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 64), "float32")
+ C = tir.match_buffer(c, (128, 64), "float32")
+ B = tir.alloc_buffer((128, 64), "float32")
+
+ with tir.block([128, 64], "B") as [vi, vj]:
+ B[vi, vj] = A[vi, vj] * 2.0
+
+ with tir.block([128, 64], "C") as [vi, vj]:
+ C[vi, vj] = B[vi, vj] + 1.0
+
+
[email protected]
+def element_wise_128_n(a: ty.handle, c: ty.handle) -> None:
+ n = tir.var("int32")
+ A = tir.match_buffer(a, (128, n), "float32")
+ C = tir.match_buffer(c, (128, n), "float32")
+ B = tir.alloc_buffer((128, n), "float32")
+
+ with tir.block([128, n], "B") as [vi, vj]:
+ B[vi, vj] = A[vi, vj] * 2.0
+
+ with tir.block([128, n], "C") as [vi, vj]:
+ C[vi, vj] = B[vi, vj] + 1.0
+
+
[email protected]
+def mem_copy(
+ a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty.int32, q:
ty.int32
+) -> None:
+ A = tir.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=q)
+ B = tir.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=q)
+
+ with tir.block([m, n], "") as [vi, vj]:
+ B[vi, vj] = A[vi, vj]
+
+
[email protected]
+def mem_copy_16_16_8_4(a: ty.handle, b: ty.handle) -> None:
+ A = tir.match_buffer(a, (16, 16), "float32", strides=[8, 1], elem_offset=4)
+ B = tir.match_buffer(b, (16, 16), "float32", strides=[8, 1], elem_offset=4)
+
+ with tir.block([16, 16], "") as [vi, vj]:
+ B[vi, vj] = A[vi, vj]
+
+
[email protected]
+def mem_copy_m_n_p_n(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p:
ty.int32) -> None:
+ A = tir.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=n)
+ B = tir.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=n)
+
+ with tir.block([m, n], "") as [vi, vj]:
+ B[vi, vj] = A[vi, vj]
+
+
+def test_specialize_nothing():
+ func = matmul.specialize({})
+ assert func.same_as(matmul) # Pointer the same
+
+
+def test_specialize_matmul():
+ a, _, _, n = matmul.params
+ # fully specialized
+ func = matmul.specialize({a: tir.decl_buffer((128, 128))})
+ tvm.ir.assert_structural_equal(func, matmul_128)
+ # partially specialized
+ func = matmul.specialize({n: 128})
+ tvm.ir.assert_structural_equal(func, matmul_m_128)
+ # symbolic specialized
+ func = matmul.specialize({n: tir.Var("x", "int32") * 8})
+ tvm.ir.assert_structural_equal(func, matmul_m_8x)
+
+
+def test_specialize_elemwise():
+ a, c = element_wise.params
+ C = element_wise.buffer_map[c]
+ # fully specialized
+ func = element_wise.specialize({a: tir.decl_buffer((128, 64))})
+ tvm.ir.assert_structural_equal(func, element_wise_128_64)
+ # partially specialized
+ func = element_wise.specialize({c: tir.decl_buffer((128, C.shape[1]))})
+ tvm.ir.assert_structural_equal(func, element_wise_128_n)
+
+
+def test_specialize_mem_copy():
+ a, _, m, n, p, q = mem_copy.params
+ # fully specialized
+ func = mem_copy.specialize({a: tir.decl_buffer((16, 16), strides=[8, 1],
elem_offset=4)})
+ tvm.ir.assert_structural_equal(func, mem_copy_16_16_8_4)
+ func = mem_copy.specialize({n: 16, m: 16, p: 8, q: 4})
+ tvm.ir.assert_structural_equal(func, mem_copy_16_16_8_4)
+ # partially specialized
+ func = mem_copy.specialize({q: n})
+ tvm.ir.assert_structural_equal(func, mem_copy_m_n_p_n)
+
+
+def test_specialize_recursive_load():
+ # TODO(Siyuan): add recursive Load testcase, e.g. A[C[i]]
+ pass
+
+
+if __name__ == "__main__":
+ test_specialize_nothing()
+ test_specialize_matmul()
+ test_specialize_elemwise()
+ test_specialize_mem_copy()
+ test_specialize_recursive_load()