This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 29e958d  [TIR][TVMScript] specialize (#8354)
29e958d is described below

commit 29e958d179465ad04ea1e33c333e1b1f043f8683
Author: Siyuan Feng <[email protected]>
AuthorDate: Fri Jul 2 02:46:49 2021 +0800

    [TIR][TVMScript] specialize (#8354)
---
 include/tvm/tir/analysis.h                   |   2 +-
 include/tvm/tir/buffer.h                     |   1 +
 include/tvm/tir/function.h                   |  38 +++
 python/tvm/tir/function.py                   |  55 ++++-
 src/tir/ir/specialize.cc                     | 337 +++++++++++++++++++++++++++
 tests/python/unittest/test_tir_specialize.py | 199 ++++++++++++++++
 6 files changed, 630 insertions(+), 2 deletions(-)

diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 262ac68..63d6fa3 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -96,7 +96,7 @@ TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
 TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
 
 /*!
- * \brief Whether e expression used any var in variable set..
+ * \brief Whether e expression used any var in variable set.
  * \param expr The expression to be checked.
  * \param vset_contains The check function to see if var is in the vset.
  * \return Whether e uses vset.
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index a01d69b..017f4f7 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -183,6 +183,7 @@ class Buffer : public ObjectRef {
   TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;
 
   TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode);
 };
 
 /*!
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 97ee7f7..25ed2f9 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -188,6 +188,44 @@ class LinkedParam : public ObjectRef {
 };
 
 /*!
+ * \brief Specialize parameters of PrimFunc.
+ * \param func The PrimFunc to be specialized.
+ * \param param_map The mapping from function params to the instance.
+ * \return The new function with parameter specialized.
+ * \note We can define a Meta TIR function with symbolic shape:
+ *
+ * \code
+ *  @tvm.script.tir
+ *  def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
+ *      A = tir.match_buffer(a, (m, n), "float32")
+ *      B = tir.match_buffer(b, (m, n), "float32")
+ *
+ *      with tir.block([m, n], "") as [vi, vj]:
+ *          B[vi, vj] = A[vi, vj]
+ * \endcode
+ *
+ * Then we can make it specialized with given shapes or buffers.
+ *
+ * \code
+ *  a, _, m, n = mem_copy.params
+ *  func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
+ *  # or
+ *  func = mem_copy.specialize({n: 16, m: 16})
+ * \endcode
+ *
+ * The specialized function:
+ *
+ * \code
+ *  @tvm.script.tir
+ *  def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None:
+ *      A = tir.match_buffer(a, (16, 16), "float32")
+ *      B = tir.match_buffer(b, (16, 16), "float32")
+ *
+ *      with tir.block([16, 16], "") as [vi, vj]:
+ *          B[vi, vj] = A[vi, vj]
+ * \endcode
+ */
+PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);
+
+/*!
  * \brief PrimFunc specific attribute names.
  *
  * \sa tvm::attr
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index 79d18d8..b1081d4 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -16,12 +16,14 @@
 # under the License.
 """Function data types."""
 
+from typing import Mapping, Union
+
 import tvm._ffi
 import tvm.runtime
 from tvm.runtime import Object
 from tvm.ir import BaseFunc
 from .buffer import Buffer
-from .expr import Var
+from .expr import Var, PrimExpr
 from . import _ffi_api
 
 
@@ -85,3 +87,54 @@ class PrimFunc(BaseFunc):
             The created new function.
         """
         return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, 
self.attrs, span)
+
    def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]) -> "PrimFunc":
        """Specialize parameters of PrimFunc.

        Each entry binds either a scalar param to a PrimExpr value, or a
        buffer param (via its handle var) to a concrete Buffer whose shape,
        strides and elem_offset are matched against the declared buffer.
        Specialized params are removed from the resulting function signature.

        Parameters
        ----------
        param_map : Mapping[Var, Union[PrimExpr, Buffer]]
            The mapping from function params to the instance.

        Returns
        -------
        func : PrimFunc
            The new function with parameter specialized.

        Examples
        --------
        We can define a Meta TIR function with symbolic shape:

        .. code-block:: python

            @tvm.script.tir
            def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
                A = tir.match_buffer(a, (m, n), "float32")
                B = tir.match_buffer(b, (m, n), "float32")

                with tir.block([m, n], "") as [vi, vj]:
                    B[vi, vj] = A[vi, vj]

        Then we can make it specialized with given shapes or buffers.

        .. code-block:: python

            a, _, m, n = mem_copy.params
            func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
            # or
            func = mem_copy.specialize({n: 16, m: 16})

        The specialized function:

        .. code-block:: python

            @tvm.script.tir
            def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None:
                A = tir.match_buffer(a, (16, 16), "float32")
                B = tir.match_buffer(b, (16, 16), "float32")

                with tir.block([16, 16], "") as [vi, vj]:
                    B[vi, vj] = A[vi, vj]
        """
        return _ffi_api.Specialize(self, param_map)
diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc
new file mode 100644
index 0000000..aa5f271
--- /dev/null
+++ b/src/tir/ir/specialize.cc
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, 
ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function 
parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/**************** Specializer ****************/
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(const VarMap& var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    bool buffer_map_updated = false;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        buffer_map_updated = true;
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    bool param_updated = false;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      } else {
+        param_updated = true;
+      }
+    }
+
+    // Updating function body
+    Stmt body = specializer(f->body);
+
+    if (param_updated || buffer_map_updated || !f->body.same_as(body)) {
+      PrimFuncNode* f_ptr = f.CopyOnWrite();
+      f_ptr->params = std::move(params);
+      f_ptr->buffer_map = std::move(buffer_map);
+      f_ptr->body = std::move(body);
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    // Step.0. Define buffer mappings which is allocated inside the block
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, 
std::placeholders::_1));
+
+    // Step.1. Recursively visit block body
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BlockNode>();
+    ICHECK(op != nullptr);
+
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, 
std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, 
std::placeholders::_1));
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op != nullptr);
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    } else {
+      auto n = CopyOnWrite(op);
+      n->buffer = it->second;
+      return Stmt(n);
+    }
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op != nullptr);
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    } else {
+      auto n = make_object<BufferLoadNode>(*op);
+      n->buffer = it->second;
+      return PrimExpr(n);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(const Buffer& buffer) const {
+    Array<PrimExpr> shape =
+        MutateArray(buffer->shape, [this](const PrimExpr& e) { return 
Substitute(e, var_map_); });
+    Array<PrimExpr> strides =
+        MutateArray(buffer->strides, [this](const PrimExpr& e) { return 
Substitute(e, var_map_); });
+
+    PrimExpr elem_offset = Substitute(buffer->elem_offset, var_map_);
+
+    if (buffer->elem_offset.same_as(elem_offset) && 
buffer->shape.same_as(shape) &&
+        buffer->strides.same_as(strides)) {
+      return buffer;
+    } else {
+      auto n = make_object<BufferNode>(*buffer.get());
+      n->elem_offset = std::move(elem_offset);
+      n->shape = std::move(shape);
+      n->strides = std::move(strides);
+      return Buffer(n);
+    }
+  }
+
+  Range MutateRange(const Range& range) {
+    PrimExpr min = this->VisitExpr(range->min);
+    PrimExpr extent = this->VisitExpr(range->extent);
+    if (min.same_as(range->min) && extent.same_as(range->extent)) {
+      return range;
+    } else {
+      return Range::FromMinExtent(std::move(min), std::move(extent));
+    }
+  }
+
+  Buffer MutateAllocBuffer(const Buffer& alloc_buf) {
+    Buffer buf = MutateBuffer(alloc_buf);
+    if (buf.same_as(alloc_buf)) {
+      return alloc_buf;
+    } else {
+      ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end());
+      buffer_map_[alloc_buf] = buf;
+      return buf;
+    }
+  }
+
+  BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) {
+    auto it = buffer_map_.find(buffer_region->buffer);
+    Array<Range> region =
+        MutateArray(buffer_region->region,
+                    std::bind(&PrimFuncSpecializer::MutateRange, this, 
std::placeholders::_1));
+    if (it == buffer_map_.end() && region.same_as(buffer_region->region)) {
+      return buffer_region;
+    } else {
+      return BufferRegion(it->second, std::move(region));
+    }
+  }
+
+ private:
+  /*! \brief The vars to be substitute and their values */
+  const VarMap& var_map_;
+  /*! \brief map from old buffer to mutated buffer */
+  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> 
buffer_map_;
+};
+
+/*!
+ * \brief Update Specialize var map with buffer matching.
+ * \param func The function to be specialized.
+ * \param param The given function parameter
+ * \param specific_buf The matching buffer.
+ * \param var_map The var mapping to be updated.
+ * \note This function will match target buffer's shape, strides and 
element_offset
+ *   For example, we define a buffer in PrimFunc:
+ *   A = tir.match_buffer(a, [m, n])
+ *
+ *   Then we match it with a buffer B =  tir.decl_buffer((8, 16))
+ *
+ *   It means we have two var mappings here: m = 8 and n = 16
+ *
+ *   If the buffer signature is not a Var, the mapping will fail.
+ *   e.g. A = tir.match_buffer(a, [m * 2, n + 1])
+ */
+void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const 
Buffer& specific_buf,
+                            VarMap* var_map) {
+  // preliminaries
+  tir::ExprDeepEqual equal;
+
+  auto it = func->buffer_map.find(param);
+  CHECK(it != func->buffer_map.end())
+      << "ValueError: specialize expects param to be in PrimFunc's buffer_map";
+  const Buffer& buf_to_specialize = (*it).second;
+
+  // build var mapping using specific_buf's parameters
+  auto build_var_mapping = [&](const PrimExpr& new_expr, const PrimExpr& 
old_expr) {
+    if (!equal(new_expr, old_expr)) {
+      CHECK(old_expr->IsInstance<VarNode>())
+          << "TypeError: The signature of target buffer exprected an 
independent Var, but got "
+          << old_expr << ".";
+      const Var& var = Downcast<Var>(old_expr);
+      auto it = var_map->find(var);
+      if (it != var_map->end()) {
+        CHECK(equal(it->second, new_expr))
+            << "ValueError: The assigned value of var " << var << " 
mismatched. " << it->second
+            << " vs. " << new_expr << ".";
+      } else {
+        (*var_map)[var] = new_expr;
+      }
+    }
+  };
+
+  // Check buffer dimensions
+  CHECK(specific_buf->shape.size() == buf_to_specialize->shape.size())
+      << "ValueError: The buffer dimensions mismatched" << 
buf_to_specialize->shape.size()
+      << " vs. " << specific_buf->shape.size() << ".";
+
+  CHECK(specific_buf->strides.size() == buf_to_specialize->strides.size())
+      << "ValueError: The buffer strides dimensions mismatched" << 
buf_to_specialize->strides.size()
+      << " vs. " << specific_buf->strides.size() << ".";
+
+  // Updating var mapping using specific_expr
+  for (size_t i = 0; i < specific_buf->shape.size(); ++i) {
+    build_var_mapping(specific_buf->shape[i], buf_to_specialize->shape[i]);
+  }
+  for (size_t i = 0; i < specific_buf->strides.size(); ++i) {
+    build_var_mapping(specific_buf->strides[i], buf_to_specialize->strides[i]);
+  }
+  build_var_mapping(specific_buf->elem_offset, buf_to_specialize->elem_offset);
+
+  // Check data_alignment and offset_factor.
+  // These two signatures are int, so we do not need map them.
+  CHECK_EQ(specific_buf->data_alignment, buf_to_specialize->data_alignment)
+      << "ValueError: The buffer data_alignment mismatched" << 
buf_to_specialize->data_alignment
+      << " vs. " << specific_buf->data_alignment << ".";
+
+  CHECK_EQ(specific_buf->offset_factor, buf_to_specialize->offset_factor)
+      << "ValueError: The buffer offset_factor mismatched" << 
buf_to_specialize->offset_factor
+      << " vs. " << specific_buf->offset_factor << ".";
+}
+
+/*!
+ * \brief Update Specialize var map with parameter value.
+ * \param func The function to be specialized.
+ * \param param The given function parameter
+ * \param specific_expr The parameter value.
+ * \param var_map The var mapping to be updated.
+ */
+void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const 
PrimExpr& specific_expr,
+                            VarMap* var_map) {
+  // check param is in PrimFunc's parameters
+  CHECK(IsParam(func, param)) << "ValueError: Specialize expects param to be 
in PrimFunc's params";
+  // specialize a param not in buffer_map
+  CHECK_EQ(func->buffer_map.count(param), 0)
+      << "ValueError: Specialize expects param to not be in PrimFunc's 
buffer_map";
+  // build var mapping using specific_expr
+  (*var_map)[param] = specific_expr;
+}
+
+/**************** Implementation ****************/
+
+PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map) {
+  VarMap var_map;
+  for (const auto& kv : param_map) {
+    const Var& param = kv.first;
+    const ObjectRef& instance = kv.second;
+    if (instance->IsInstance<BufferNode>()) {
+      UpdateSpecializeVarMap(func, param, Downcast<Buffer>(instance), 
&var_map);
+    } else if (instance->IsInstance<PrimExprNode>()) {
+      UpdateSpecializeVarMap(func, param, Downcast<PrimExpr>(instance), 
&var_map);
+    } else {
+      LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or 
PrimExpr, but got "
+                 << instance->GetTypeKey();
+    }
+  }
+  return PrimFuncSpecializer::Specialize(func, std::move(var_map));
+}
+
+/**************** FFI ****************/
+
+TVM_REGISTER_GLOBAL("tir.Specialize").set_body_typed(Specialize);
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_tir_specialize.py 
b/tests/python/unittest/test_tir_specialize.py
new file mode 100644
index 0000000..2e9f111
--- /dev/null
+++ b/tests/python/unittest/test_tir_specialize.py
@@ -0,0 +1,199 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring, missing-module-docstring
+
+import tvm
+from tvm import tir
+from tvm.script import ty
+
+
@tvm.script.tir
def matmul(a: ty.handle, b: ty.handle, c: ty.handle, n: ty.int32) -> None:
    # Symbolic matmul fixture: m is a free var bound via match_buffer,
    # n is an explicit function param — both are candidates for specialize().
    m = tir.var("int32")
    A = tir.match_buffer(a, [m, n])
    B = tir.match_buffer(b, [m, n])
    C = tir.match_buffer(c, [m, m])

    with tir.block([m, m, tir.reduce_axis(0, n)], "update") as [vi, vj, vk]:
        with tir.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
@tvm.script.tir
def matmul_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Expected structural result of fully specializing `matmul` so that
    # both m and n become 128.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])

    with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
        with tir.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
@tvm.script.tir
def matmul_m_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Expected structural result of partially specializing `matmul` with
    # n = 128; m stays symbolic.
    m = tir.var("int32")
    A = tir.match_buffer(a, [m, 128])
    B = tir.match_buffer(b, [m, 128])
    C = tir.match_buffer(c, [m, m])

    with tir.block([m, m, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
        with tir.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
@tvm.script.tir
def matmul_m_8x(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Expected structural result of symbolically specializing `matmul` with
    # n = x * 8 for a fresh var x; m stays symbolic as well.
    x = tir.var("int32")
    m = tir.var("int32")
    A = tir.match_buffer(a, [m, x * 8])
    B = tir.match_buffer(b, [m, x * 8])
    C = tir.match_buffer(c, [m, m])

    with tir.block([m, m, tir.reduce_axis(0, x * 8)], "update") as [vi, vj, vk]:
        with tir.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
@tvm.script.tir
def element_wise(a: ty.handle, c: ty.handle) -> None:
    # Element-wise fixture with symbolic dims (m, n) and an intermediate
    # alloc_buffer B, so specialization must also rewrite allocated buffers.
    m = tir.var("int32")
    n = tir.var("int32")
    A = tir.match_buffer(a, (m, n), "float32")
    C = tir.match_buffer(c, (m, n), "float32")

    B = tir.alloc_buffer((m, n), "float32")

    with tir.block([m, n], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0

    with tir.block([m, n], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
+
+
@tvm.script.tir
def element_wise_128_64(a: ty.handle, c: ty.handle) -> None:
    # Expected structural result of fully specializing `element_wise`
    # with m = 128, n = 64 (via binding buffer `a`).
    A = tir.match_buffer(a, (128, 64), "float32")
    C = tir.match_buffer(c, (128, 64), "float32")
    B = tir.alloc_buffer((128, 64), "float32")

    with tir.block([128, 64], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0

    with tir.block([128, 64], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
+
+
@tvm.script.tir
def element_wise_128_n(a: ty.handle, c: ty.handle) -> None:
    # Expected structural result of partially specializing `element_wise`:
    # m = 128 while n stays symbolic.
    n = tir.var("int32")
    A = tir.match_buffer(a, (128, n), "float32")
    C = tir.match_buffer(c, (128, n), "float32")
    B = tir.alloc_buffer((128, n), "float32")

    with tir.block([128, n], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0

    with tir.block([128, n], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
+
+
@tvm.script.tir
def mem_copy(
    a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty.int32, q: ty.int32
) -> None:
    # Mem-copy fixture where not only the shape (m, n) but also the stride p
    # and elem_offset q are symbolic params, exercising full buffer matching.
    A = tir.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=q)
    B = tir.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=q)

    with tir.block([m, n], "") as [vi, vj]:
        B[vi, vj] = A[vi, vj]
+
+
@tvm.script.tir
def mem_copy_16_16_8_4(a: ty.handle, b: ty.handle) -> None:
    # Expected structural result of fully specializing `mem_copy` with
    # m = n = 16, p = 8, q = 4.
    A = tir.match_buffer(a, (16, 16), "float32", strides=[8, 1], elem_offset=4)
    B = tir.match_buffer(b, (16, 16), "float32", strides=[8, 1], elem_offset=4)

    with tir.block([16, 16], "") as [vi, vj]:
        B[vi, vj] = A[vi, vj]
+
+
@tvm.script.tir
def mem_copy_m_n_p_n(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty.int32) -> None:
    # Expected structural result of partially specializing `mem_copy` with
    # q bound to the (still symbolic) var n.
    A = tir.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=n)
    B = tir.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=n)

    with tir.block([m, n], "") as [vi, vj]:
        B[vi, vj] = A[vi, vj]
+
+
def test_specialize_nothing():
    # An empty mapping must return the very same function object (no copy).
    assert matmul.specialize({}).same_as(matmul)
+
+
def test_specialize_matmul():
    param_a, _, _, var_n = matmul.params
    # Fully specialized via buffer binding: m and n both become 128.
    tvm.ir.assert_structural_equal(
        matmul.specialize({param_a: tir.decl_buffer((128, 128))}), matmul_128
    )
    # Partially specialized: only n is fixed.
    tvm.ir.assert_structural_equal(matmul.specialize({var_n: 128}), matmul_m_128)
    # Symbolically specialized: n is replaced by x * 8 with a fresh var x.
    tvm.ir.assert_structural_equal(
        matmul.specialize({var_n: tir.Var("x", "int32") * 8}), matmul_m_8x
    )
+
+
def test_specialize_elemwise():
    handle_a, handle_c = element_wise.params
    buf_c = element_wise.buffer_map[handle_c]
    # Binding buffer `a` fixes both m and n.
    tvm.ir.assert_structural_equal(
        element_wise.specialize({handle_a: tir.decl_buffer((128, 64))}),
        element_wise_128_64,
    )
    # Binding only the first dim of `c` leaves n symbolic.
    tvm.ir.assert_structural_equal(
        element_wise.specialize({handle_c: tir.decl_buffer((128, buf_c.shape[1]))}),
        element_wise_128_n,
    )
+
+
def test_specialize_mem_copy():
    param_a, _, var_m, var_n, var_p, var_q = mem_copy.params
    # Fully specialized via buffer binding: shape, strides and elem_offset
    # are all matched from the concrete buffer.
    by_buffer = mem_copy.specialize(
        {param_a: tir.decl_buffer((16, 16), strides=[8, 1], elem_offset=4)}
    )
    tvm.ir.assert_structural_equal(by_buffer, mem_copy_16_16_8_4)
    # The same result via direct scalar bindings.
    by_vars = mem_copy.specialize({var_n: 16, var_m: 16, var_p: 8, var_q: 4})
    tvm.ir.assert_structural_equal(by_vars, mem_copy_16_16_8_4)
    # Partially specialized: q is bound to the still-symbolic var n.
    tvm.ir.assert_structural_equal(mem_copy.specialize({var_q: var_n}), mem_copy_m_n_p_n)
+
+
def test_specialize_recursive_load():
    # TODO(Siyuan): add recursive Load testcase, e.g. A[C[i]]
    pass


if __name__ == "__main__":
    # Run every test case in order when invoked as a script.
    for _case in (
        test_specialize_nothing,
        test_specialize_matmul,
        test_specialize_elemwise,
        test_specialize_mem_copy,
        test_specialize_recursive_load,
    ):
        _case()

Reply via email to