This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ab8a106 [MetaSchedule] Add Per-Store-Feature (#9860)
ab8a106 is described below
commit ab8a106f6a62c582c8910fffa8b06245c16c9b70
Author: Junru Shao <[email protected]>
AuthorDate: Fri Jan 7 23:09:46 2022 -0800
[MetaSchedule] Add Per-Store-Feature (#9860)
* [MetaSchedule] Add Per-Store-Feature
Co-authored-by: Xiyou Zhou <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
* fix lint
* fix lint
* Update per_store_feature.py
* address comments
* fix lint
Co-authored-by: Xiyou Zhou <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
---
include/tvm/tir/stmt.h | 6 +-
python/tvm/meta_schedule/__init__.py | 2 +
.../meta_schedule/feature_extractor/__init__.py | 1 +
.../feature_extractor/per_store_feature.py | 60 +
.../feature_extractor/per_store_feature.cc | 1337 +++++++++++++++++
src/meta_schedule/utils.h | 4 +
src/tir/schedule/primitive/sampling.cc | 8 +-
src/tir/schedule/utils.h | 64 +-
src/tir/transforms/unify_thread_binding.cc | 13 +-
...schedule_feature_extractor_per_store_feature.py | 1555 ++++++++++++++++++++
10 files changed, 3034 insertions(+), 16 deletions(-)
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 0664967..d3fbf0f 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1251,7 +1251,7 @@ constexpr const char* extern_scope = "extern_scope";
* This can hint some code generator to create a new function for compute.
*/
constexpr const char* compute_scope = "compute_scope";
-/*! \brief Mark storage alignement requirement of buffers */
+/*! \brief Mark storage alignment requirement of buffers */
constexpr const char* storage_alignment = "storage_alignment";
/*! \brief Mark storage scope of realization */
constexpr const char* realize_scope = "realize_scope";
@@ -1263,6 +1263,10 @@ constexpr const char* device_type = "device_type";
constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of reduce scope */
constexpr const char* reduce_scope = "reduce_scope";
+/*! \brief Pragma: auto-unroll, max_step (an integer loop annotation) */
+constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";
+/*! \brief Pragma: unroll explicit */
+constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
/*! \brief Mark region is guarded by the pragma extension */
constexpr const char* pragma_scope_prefix = "pragma_";
/*! \brief Import C source or file into the final code gen module */
diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py
index 8b6672c..e41e5b3 100644
--- a/python/tvm/meta_schedule/__init__.py
+++ b/python/tvm/meta_schedule/__init__.py
@@ -23,4 +23,6 @@ from . import space_generator
from . import search_strategy
from . import schedule_rule
from . import integration
+from . import feature_extractor
from .tune_context import TuneContext
+from .search_strategy import MeasureCandidate
diff --git a/python/tvm/meta_schedule/feature_extractor/__init__.py b/python/tvm/meta_schedule/feature_extractor/__init__.py
index f29c44b..83ac742 100644
--- a/python/tvm/meta_schedule/feature_extractor/__init__.py
+++ b/python/tvm/meta_schedule/feature_extractor/__init__.py
@@ -20,4 +20,5 @@ Meta Schedule feature extractors that extracts features from
measure candidates for use in cost model.
"""
from .feature_extractor import FeatureExtractor, PyFeatureExtractor
+from .per_store_feature import PerStoreFeature
from .random_feature_extractor import RandomFeatureExtractor
diff --git a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py
new file mode 100644
index 0000000..306934d
--- /dev/null
+++ b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""We extract one feature vector per BufferStoreNode statement in a TIR Stmt,
+so we call this feature the "per-store" feature.
+"""
+from tvm._ffi import register_object
+
+from .. import _ffi_api
+from .feature_extractor import FeatureExtractor
+
+
+@register_object("meta_schedule.PerStoreFeature")
+class PerStoreFeature(FeatureExtractor):
+ """PerStoreFeature extracts one feature vector per BufferStoreNode
+
+ Parameters
+ ----------
+ buffers_per_store : int
+ The number of buffers in each BufferStore; Pad or truncate if
necessary.
+ arith_intensity_curve_num_samples : int
+ The number of samples used in the arithmetic intensity curve.
+ cache_line_bytes : int
+ The number of bytes in a cache line.
+ """
+
+ buffers_per_store: int
+ """The number of buffers in each BufferStore; Pad or truncate if
necessary."""
+ arith_intensity_curve_num_samples: int # pylint: disable=invalid-name
+ """The number of samples used in the arithmetic intensity curve."""
+ cache_line_bytes: int
+ """The number of bytes in a cache line."""
+ feature_vector_length: int
+ """Length of the feature vector."""
+
+ def __init__(
+ self,
+ buffers_per_store: int = 5,
+ arith_intensity_curve_num_samples: int = 10,
+ cache_line_bytes: int = 64,
+ ):
+ self.__init_handle_by_constructor__(
+ _ffi_api.FeatureExtractorPerStoreFeature, # type: ignore #
pylint: disable=no-member
+ buffers_per_store,
+ arith_intensity_curve_num_samples,
+ cache_line_bytes,
+ )
diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc
new file mode 100644
index 0000000..42cd5d5
--- /dev/null
+++ b/src/meta_schedule/feature_extractor/per_store_feature.cc
@@ -0,0 +1,1337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/tir/transform.h>
+
+#include <cmath>
+#include <memory>
+#include <numeric>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+using support::NDIntSet;
+
+/*! \brief Type for multi-dimensional index */
+using MultiIndex = std::vector<PrimExpr>;
+/*! \brief Vector of int64_t */
+using IntVec = std::vector<int64_t>;
+/*! \brief Vector of for loops */
+using ForVec = std::vector<const ForNode*>;
+
+/*!
+ * \brief An unordered_map for (for, buffer) => V
+ * \tparam V The value type
+ */
+template <class V>
+using ForBufferMap = std::unordered_map<const ForNode*, std::unordered_map<const BufferNode*, V>>;
+
+/*!
+ * \brief Given x, compute log2(|x| + 1)
+ * \param x The input value
+ * \return log2(|x| + 1), which is 0 for x == 0 and symmetric in the sign of x
+ */
+inline double slog(double x) { return std::log2(std::abs(x) + 1); }
+
+namespace utils {
+
+/*!
+ * \brief Get the shape of the buffer as constant integers
+ * \param buffer The buffer
+ * \param analyzer The analyzer used to bound dimensions that are not integer constants
+ * \return The shape of the buffer. A dimension that is not an integer constant is replaced by
+ * its constant upper bound if that bound is non-negative and finite, or by 1 otherwise.
+ */
+std::vector<int64_t> GetBufferShape(const Buffer& buffer, arith::Analyzer* analyzer) {
+  std::vector<int64_t> result;
+  result.reserve(buffer->shape.size());
+  for (const PrimExpr& dim : buffer->shape) {
+    if (const IntImmNode* int_imm = dim.as<IntImmNode>()) {
+      result.push_back(int_imm->value);
+      continue;
+    }
+    arith::ConstIntBound bound = analyzer->const_int_bound(dim);
+    if (0 <= bound->max_value && bound->max_value < arith::ConstIntBound::kPosInf) {
+      result.push_back(bound->max_value);
+    } else {
+      // No usable upper bound: fall back to a unit dimension
+      result.push_back(1);
+    }
+  }
+  return result;
+}
+
+/*!
+ * \brief Given a loop, return its `pragma_auto_unroll_max_step` annotation if it exists
+ * \param loop The loop to be checked
+ * \return The value of `pragma_auto_unroll_max_step` if it exists, or -1 if it does not exist
+ */
+int64_t GetPragmaAutoUnroll(const ForNode* loop) {
+  if (Optional<IntImm> auto_unroll = GetAnn<IntImm>(loop, tir::attr::pragma_auto_unroll_max_step)) {
+    return auto_unroll.value()->value;
+  }
+  return -1;
+}
+
+/*!
+ * \brief Given a list of loops, return the extent of the first loop if the list is not empty
+ * and the first loop has a constant extent. Otherwise return the given default value.
+ * \param loops The list of loops to be checked
+ * \param default_value The value returned if the list is empty or the first loop does not have
+ * a constant extent
+ * \return The extent of the first loop if the list is not empty and the first loop has a
+ * constant extent; otherwise `default_value`
+ */
+int64_t FirstLoopExtent(const ForVec& loops, int64_t default_value) {
+  if (loops.empty()) {
+    return default_value;
+  }
+  const int64_t* extent = GetLoopIntExtent(loops.front());
+  return extent != nullptr ? *extent : default_value;
+}
+
+/*!
+ * \brief Relax each multi-index pattern according to the domains bound in the analyzer,
+ * and then union them into a single rectangular region
+ * \param multi_indices A list of multi-index patterns to be relaxed; the number of dimensions is
+ * taken from the first pattern
+ * \param numel Output: the number of elements of the unioned region (product of its extents)
+ * \param analyzer The analyzer that contains the domain information
+ * \return The extent of the unioned region along each dimension
+ */
+IntVec RelaxAndUnion(const std::vector<MultiIndex>& multi_indices, int64_t* numel,
+                     arith::Analyzer* analyzer) {
+  *numel = 1;
+  if (multi_indices.empty()) {
+    return {};
+  }
+  int n_indices = multi_indices.size();
+  int ndim = multi_indices[0].size();
+  for (const MultiIndex& multi_index : multi_indices) {
+    // Every pattern is indexed up to `ndim` below; a shorter one would be read out of bounds
+    ICHECK_GE(static_cast<int>(multi_index.size()), ndim);
+  }
+  IntVec access_shape(ndim, 0);
+  for (int i = 0; i < ndim; ++i) {
+    int64_t minimum = arith::ConstIntBound::kPosInf;
+    int64_t maximum = arith::ConstIntBound::kNegInf;
+    for (int j = 0; j < n_indices; ++j) {
+      arith::ConstIntBound bound = analyzer->const_int_bound(multi_indices[j][i]);
+      minimum = std::min(minimum, bound->min_value);
+      maximum = std::max(maximum, bound->max_value);
+    }
+    // NOTE(review): if a bound is infinite this subtraction overflows; presumably all loop
+    // variables are bound in `analyzer` at this point -- confirm at the call sites.
+    int64_t extent = maximum - minimum + 1;
+    access_shape[i] = extent;
+    *numel *= extent;
+  }
+  return access_shape;
+}
+
+/*!
+ * \brief Given a list of multi-index patterns, return the minimal stride of a variable on them
+ * \param multi_indices The list of multi-index patterns; each must have one index per entry of
+ * `buffer_stride`
+ * \param buffer_stride The stride of each buffer dimension
+ * \param var The variable to be checked
+ * \return The minimal stride of the variable over all patterns, or 0 if the variable does not
+ * appear in any of them
+ */
+int64_t GetVarStride(const std::vector<MultiIndex>& multi_indices, const IntVec& buffer_stride,
+                     const Var& var) {
+  /*!
+   * \brief Heuristically estimates the coefficient of `var` in an index expression.
+   *
+   * The expression is traversed in post-order. Each visit to `var` sets `stride` to 2 (a fallback
+   * value). A Mul with a constant operand visited after `var` records that constant, unless an
+   * Add was already recorded; an Add visited after `var` records 1, unless a Mul was already
+   * recorded. `Extract` returns 0 if `var` was never seen, 1 if neither a Mul nor an Add was
+   * recorded, and `stride` otherwise.
+   *
+   * NOTE(review): the Mul/Add checks do not verify that `var` lies inside that node's subtree,
+   * e.g. `x + y * 3` yields 3 for `x`. Confirm whether this approximation is intended.
+   */
+  class CoefficientExtractor : private ExprVisitor {
+   public:
+    static int64_t Extract(const PrimExpr& expr, const Var& var) {
+      CoefficientExtractor extractor(var);
+      extractor.VisitExpr(expr);
+      if (!extractor.visited_var) {
+        return 0;
+      }
+      if (!extractor.visited_mul && !extractor.visited_add) {
+        return 1;
+      }
+      return extractor.stride;
+    }
+
+   private:
+    explicit CoefficientExtractor(const Var& var)
+        : var(var), stride(0), visited_var(false), visited_add(false), visited_mul(false) {}
+
+    void VisitExpr_(const MulNode* node) override {
+      ExprVisitor::VisitExpr_(node);
+      if (visited_var && !visited_add) {
+        if (const auto* a = node->a.as<IntImmNode>()) {
+          visited_mul = true;
+          stride = a->value;
+        } else if (const auto* b = node->b.as<IntImmNode>()) {
+          visited_mul = true;
+          stride = b->value;
+        }
+      }
+    }
+
+    void VisitExpr_(const AddNode* node) override {
+      ExprVisitor::VisitExpr_(node);
+      if (visited_var && !visited_mul) {
+        visited_add = true;
+        stride = 1;
+      }
+    }
+
+    void VisitExpr_(const VarNode* node) override {
+      if (node == var.get()) {
+        visited_var = true;
+        // Fallback stride; only returned if `var` reappears after a Mul/Add was recorded
+        stride = 2;
+      }
+    }
+
+    const Var& var;
+    int64_t stride;
+    bool visited_var;
+    bool visited_add;
+    bool visited_mul;
+  };
+
+  constexpr int64_t kNotFound = std::numeric_limits<int64_t>::max();
+  int ndim = buffer_stride.size();
+  // Calculate the min stride possible
+  int64_t result = kNotFound;
+  for (const MultiIndex& multi_index : multi_indices) {
+    ICHECK_EQ(multi_index.size(), buffer_stride.size());
+    // Find the rightmost dimension that contains the given variable
+    for (int i = ndim - 1; i >= 0; --i) {
+      int64_t coef = CoefficientExtractor::Extract(multi_index[i], var);
+      if (coef != 0) {
+        result = std::min(result, std::abs(coef) * buffer_stride[i]);
+        break;
+      }
+    }
+  }
+  return (result == kNotFound) ? 0 : result;
+}
+
+/*!
+ * \brief Converts a 2-dimensional STL vector to a TVM NDArray
+ * \param src The source 2-dimensional STL vector; must be non-empty and rectangular (every row
+ * has the same length as the first one)
+ * \return The converted TVM NDArray of shape (n, m), dtype float64, on CPU
+ */
+runtime::NDArray AsNDArray(const std::vector<std::vector<double>>& src) {
+  ICHECK(!src.empty());
+  int64_t n = src.size();
+  int64_t m = src[0].size();
+  runtime::NDArray tgt = runtime::NDArray::Empty(
+      /*shape=*/{n, m},
+      /*dtype=*/DLDataType{kDLFloat, 64, 1},
+      /*device=*/DLDevice{kDLCPU, 0});
+  double* data = static_cast<double*>(tgt->data);
+  for (const std::vector<double>& row : src) {
+    // A row longer than the first one would write past the end of the n * m buffer
+    ICHECK_EQ(static_cast<int64_t>(row.size()), m);
+    for (double v : row) {
+      *data++ = v;
+    }
+  }
+  return tgt;
+}
+
+} // namespace utils
+
+namespace transform {
+
+/*!
+ * \brief Create a pass that simplifies the IR for feature extraction. The pass
+ * - replaces every Select expression with the constant 1 of the Select's dtype, and
+ * - removes serial, unannotated loops with min 0 and extent 1, replacing every use of their
+ *   loop variable with the constant 0.
+ * \return The pass created
+ */
+Pass SimplifyForFeatureExtraction() {
+  class Simplifier : private StmtExprMutator {
+   public:
+    static Stmt Run(Stmt stmt) { return Simplifier()(std::move(stmt)); }
+
+   private:
+    PrimExpr VisitExpr_(const SelectNode* node) final { return make_const(node->dtype, 1.0); }
+
+    PrimExpr VisitExpr_(const VarNode* var) final {
+      if (unit_vars_.count(GetRef<Var>(var))) {
+        return make_const(var->dtype, 0.0);
+      }
+      return GetRef<Var>(var);
+    }
+
+    Stmt VisitStmt_(const ForNode* loop) final {
+      if (is_zero(loop->min) && is_one(loop->extent) && loop->kind == ForKind::kSerial &&
+          loop->annotations.empty()) {
+        // Drop the unit loop and keep only its (mutated) body
+        unit_vars_.insert(loop->loop_var);
+        return VisitStmt(loop->body);
+      }
+      return StmtExprMutator::VisitStmt_(loop);
+    }
+
+    /*! \brief Loop variables of the unit loops removed so far */
+    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> unit_vars_;
+  };
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    PrimFuncNode* n = f.CopyOnWrite();
+    n->body = Simplifier::Run(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.SimplifyForFeatureExtraction", {});
+}
+
+/*!
+ * \brief Create a list of passes that preprocesses the IR for feature extraction
+ * \return The list of passes created, run in the order listed
+ */
+Sequential PassListForPerStoreFeature() {
+  return Sequential({
+      tir::transform::SimplifyForFeatureExtraction(),
+      tir::transform::LowerCrossThreadReduction(),
+      tir::transform::LowerInitBlock(),
+      tir::transform::PlanAndUpdateBufferAllocationLocation(),
+      tir::transform::ConvertBlocksToOpaque(),
+      tir::transform::UnifyThreadBinding(),
+      tir::transform::CompactBufferAllocation(),
+      tir::transform::LowerMatchBuffer(),
+      tir::transform::Simplify(),
+  });
+}
+
+} // namespace transform
+
+/*! \brief A data structure managing loop nests */
+struct LoopNest {
+  int64_t prod = 1;     // The product of the constant extents of all the loops
+  ForVec loops;         // All the loops
+  IntVec auto_unroll;   // The positive `pragma_auto_unroll_max_step` values of the loops
+  ForVec parallel;      // The loops whose ForKind are kParallel
+  ForVec vectorize;     // The loops whose ForKind are kVectorized
+  ForVec unroll;        // The loops whose ForKind are kUnrolled
+  ForVec blockIdx_x;    // The loops whose ForKind are kThreadBinding to blockIdx.x
+  ForVec blockIdx_y;    // The loops whose ForKind are kThreadBinding to blockIdx.y
+  ForVec blockIdx_z;    // The loops whose ForKind are kThreadBinding to blockIdx.z
+  ForVec threadIdx_x;   // The loops whose ForKind are kThreadBinding to threadIdx.x
+  ForVec threadIdx_y;   // The loops whose ForKind are kThreadBinding to threadIdx.y
+  ForVec threadIdx_z;   // The loops whose ForKind are kThreadBinding to threadIdx.z
+  ForVec vthread;       // The loops whose ForKind are kThreadBinding to vthread.*
+
+  /*!
+   * \brief Push a new loop into the loop nest
+   * \param loop The loop to be pushed
+   * \param auto_unroll_attr Output: the loop's `pragma_auto_unroll_max_step`, or -1 if absent
+   * \return The list of for loops that the loop is recorded in, or nullptr if none
+   */
+  ForVec* Push(const ForNode* loop, int64_t* auto_unroll_attr) {
+    if (const int64_t* extent = GetLoopIntExtent(loop)) {
+      this->prod *= *extent;
+    }
+    this->loops.push_back(loop);
+    if ((*auto_unroll_attr = utils::GetPragmaAutoUnroll(loop)) > 0) {
+      this->auto_unroll.push_back(*auto_unroll_attr);
+    }
+    ForVec* ref_loops = nullptr;
+    if (loop->kind == ForKind::kParallel) {
+      ref_loops = &parallel;
+    } else if (loop->kind == ForKind::kVectorized) {
+      ref_loops = &vectorize;
+    } else if (loop->kind == ForKind::kUnrolled) {
+      ref_loops = &unroll;
+    } else if (loop->kind == ForKind::kThreadBinding) {
+      std::string thread_tag = loop->thread_binding.value()->thread_tag;
+      if (thread_tag == "blockIdx.x") {
+        ref_loops = &blockIdx_x;
+      } else if (thread_tag == "blockIdx.y") {
+        ref_loops = &blockIdx_y;
+      } else if (thread_tag == "blockIdx.z") {
+        ref_loops = &blockIdx_z;
+      } else if (thread_tag == "threadIdx.x") {
+        ref_loops = &threadIdx_x;
+      } else if (thread_tag == "threadIdx.y") {
+        ref_loops = &threadIdx_y;
+      } else if (thread_tag == "threadIdx.z") {
+        ref_loops = &threadIdx_z;
+      } else if (support::StartsWith(thread_tag, "vthread")) {
+        ref_loops = &vthread;
+      } else {
+        LOG(FATAL) << "ValueError: Unable to recognize thread tag: " << thread_tag;
+      }
+    }
+    if (ref_loops != nullptr) {
+      ref_loops->push_back(loop);
+    }
+    return ref_loops;
+  }
+
+  /*!
+   * \brief Pop the last loop from the loop nest
+   * \param loop The loop to be popped
+   * \param ref_loops The list returned by the matching `Push`
+   * \param auto_unroll_attr The value written by the matching `Push` (int64_t, like `Push`, so a
+   * large annotation is not narrowed and the `auto_unroll` stack stays balanced)
+   */
+  void Pop(const ForNode* loop, ForVec* ref_loops, int64_t auto_unroll_attr) {
+    if (ref_loops) {
+      ref_loops->pop_back();
+    }
+    if (auto_unroll_attr > 0) {
+      this->auto_unroll.pop_back();
+    }
+    // NOTE(review): a loop with constant extent 0 would divide by zero here; presumably such
+    // loops are removed by the preprocessing passes -- confirm.
+    if (const int64_t* extent = GetLoopIntExtent(loop)) {
+      this->prod /= *extent;
+    }
+    this->loops.pop_back();
+  }
+};
+
+/****** Group 1: Computation related features ******/
+
+namespace group1 {
+
+/*! \brief Group 1 features */
+struct Feature {
+  /*! \brief Arithmetic features */
+  struct ArithOps {
+    // Float-point arithmetic features
+    int64_t float_mad = 0;         // The number of float MAD (Multiply-add) ops
+    int64_t float_add_sub = 0;     // The number of float add and sub ops
+    int64_t float_mul = 0;         // The number of float multiply ops
+    int64_t float_div_mod = 0;     // The number of float div and mod ops
+    int64_t float_cmp = 0;         // The number of float comparison ops
+    int64_t float_math_func = 0;   // The number of float math func calls
+    int64_t float_other_func = 0;  // The number of other float func calls
+    // Integer arithmetic features
+    int64_t int_mad = 0;         // The number of integer MAD (Multiply-add) ops
+    int64_t int_add_sub = 0;     // The number of integer add and sub ops
+    int64_t int_mul = 0;         // The number of integer multiply ops
+    int64_t int_div_mod = 0;     // The number of integer div and mod ops
+    int64_t int_cmp = 0;         // The number of integer comparison ops
+    int64_t int_math_func = 0;   // The number of integer math func calls
+    int64_t int_other_func = 0;  // The number of other integer func calls
+    // Other arithmetic features
+    int64_t bool_op = 0;    // The number of bool ops
+    int64_t select_op = 0;  // The number of select ops
+
+    /*! \brief The number of values appended by `Export` */
+    static constexpr int64_t kCount = 16;
+
+    ArithOps() = default;
+    ArithOps(const BufferStoreNode* store, int64_t prod_loop_extent);
+
+    /*! \brief Append the slog of every counter, in declaration order, to `v` */
+    void Export(std::vector<double>* v) const {
+      double vs[] = {
+          slog(float_mad), slog(float_add_sub),   slog(float_mul),        slog(float_div_mod),
+          slog(float_cmp), slog(float_math_func), slog(float_other_func),  //
+          slog(int_mad),   slog(int_add_sub),     slog(int_mul),          slog(int_div_mod),
+          slog(int_cmp),   slog(int_math_func),   slog(int_other_func),  //
+          slog(bool_op),   slog(select_op),
+      };
+      v->insert(v->end(), std::begin(vs), std::end(vs));
+    }
+  };
+
+  /*! \brief Loop binding features */
+  struct ForKindFeature {
+    enum class Pos : int {
+      kPosNone = 0,           // Does not have this kind of annotation
+      kPosInnerSpatial = 1,   // The annotated iterator is the innermost spatial iterator
+      kPosMiddleSpatial = 2,  // The annotated iterator is a middle spatial iterator
+      kPosOuterSpatial = 3,   // The annotated iterator is the outermost spatial iterator
+      kPosInnerReduce = 4,    // The annotated iterator is the innermost reduce iterator
+      kPosMiddleReduce = 5,   // The annotated iterator is a middle reduce iterator
+      kPosOuterReduce = 6,    // The annotated iterator is the outermost reduce iterator
+      kPosMixed = 7,          // The annotated iterator is a mixed space and reduce iterator
+    };
+    int64_t num = 0;           // The number of iterators with the annotation
+    int64_t prod = 0;          // The product of the lengths of iterators with the annotation
+    int64_t len = 0;           // The length of the innermost iterator with the annotation
+    Pos pos = Pos::kPosMixed;  // The position of the iterators with the annotation
+
+    /*! \brief The number of values appended by `Export`: 3 scalars + an 8-way one-hot of `pos` */
+    static constexpr int64_t kCount = 11;
+
+    explicit ForKindFeature(const ForVec& loops);
+
+    void Export(std::vector<double>* v) const {
+      double vs[] = {
+          slog(num),
+          slog(prod),
+          slog(len),
+          static_cast<double>(static_cast<int>(pos) == 0),
+          static_cast<double>(static_cast<int>(pos) == 1),
+          static_cast<double>(static_cast<int>(pos) == 2),
+          static_cast<double>(static_cast<int>(pos) == 3),
+          static_cast<double>(static_cast<int>(pos) == 4),
+          static_cast<double>(static_cast<int>(pos) == 5),
+          static_cast<double>(static_cast<int>(pos) == 6),
+          static_cast<double>(static_cast<int>(pos) == 7),
+      };
+      v->insert(v->end(), std::begin(vs), std::end(vs));
+    }
+  };
+
+  ArithOps arith_ops;           // Arithmetic features
+  ForKindFeature vectorize;     // Loop binding features: kVectorize
+  ForKindFeature unroll;        // Loop binding features: kUnroll
+  ForKindFeature parallel;      // Loop binding features: kParallel
+  bool is_gpu = false;          // If the program is running on GPU
+  int64_t blockIdx_x_len = 1;   // The length of blockIdx.x
+  int64_t blockIdx_y_len = 1;   // The length of blockIdx.y
+  int64_t blockIdx_z_len = 1;   // The length of blockIdx.z
+  int64_t threadIdx_x_len = 1;  // The length of threadIdx.x
+  int64_t threadIdx_y_len = 1;  // The length of threadIdx.y
+  int64_t threadIdx_z_len = 1;  // The length of threadIdx.z
+  int64_t vthread_len = 1;      // The length of virtual thread
+
+  /*! \brief The number of values appended by `Export` */
+  static constexpr int64_t kCount = ArithOps::kCount + ForKindFeature::kCount * 3 + 8;
+
+  explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest, bool is_gpu)
+      : arith_ops(store, loop_nest.prod),
+        vectorize(loop_nest.vectorize),
+        unroll(loop_nest.unroll),
+        parallel(loop_nest.parallel) {
+    if (is_gpu) {
+      // Each length is the constant extent of the first loop recorded for that thread axis,
+      // or 1 if there is no such loop or its extent is not constant
+      this->is_gpu = true;
+      this->blockIdx_x_len = utils::FirstLoopExtent(loop_nest.blockIdx_x, 1);
+      this->blockIdx_y_len = utils::FirstLoopExtent(loop_nest.blockIdx_y, 1);
+      this->blockIdx_z_len = utils::FirstLoopExtent(loop_nest.blockIdx_z, 1);
+      this->threadIdx_x_len = utils::FirstLoopExtent(loop_nest.threadIdx_x, 1);
+      this->threadIdx_y_len = utils::FirstLoopExtent(loop_nest.threadIdx_y, 1);
+      this->threadIdx_z_len = utils::FirstLoopExtent(loop_nest.threadIdx_z, 1);
+      this->vthread_len = utils::FirstLoopExtent(loop_nest.vthread, 1);
+    }
+  }
+
+  /*! \brief Append all group-1 features (`kCount` values) to `v` */
+  void Export(std::vector<double>* v) const {
+    this->arith_ops.Export(v);
+    this->vectorize.Export(v);
+    this->unroll.Export(v);
+    this->parallel.Export(v);
+    double vs[] = {
+        static_cast<double>(is_gpu),  //
+        slog(blockIdx_x_len),        slog(blockIdx_y_len),  slog(blockIdx_z_len),
+        slog(threadIdx_x_len),       slog(threadIdx_y_len), slog(threadIdx_z_len),
+        slog(vthread_len),
+    };
+    v->insert(v->end(), std::begin(vs), std::end(vs));
+  }
+};
+
+/*!
+ * \brief Count the arithmetic operations in the value expression of a BufferStore.
+ *
+ * Each operation found in `store->value` adds `prod_loop_extent` to its counter, so the counts are
+ * weighted by the trip count supplied by the caller (group1::Feature passes `loop_nest.prod`,
+ * presumably the product of the enclosing loop extents). The store indices are not visited.
+ * \param store The BufferStore whose value expression is analyzed
+ * \param prod_loop_extent The amount added to a counter per occurrence of an operation
+ */
+Feature::ArithOps::ArithOps(const BufferStoreNode* store, int64_t prod_loop_extent) {
+  class ArithOpCounter : public ExprVisitor {
+   public:
+    // Logical operators and select: counted regardless of dtype
+#define TVM_FEATURE_SIMPLE(Type, Counter)       \
+  void VisitExpr_(const Type* op) final {       \
+    result_.Counter += this->prod_loop_extent_; \
+    ExprVisitor::VisitExpr_(op);                \
+  }
+    // Binary arithmetic and comparisons: classified by the dtype of the left operand.
+    // The result dtype is not used because comparisons (EQ/NE/LT/LE/GT/GE) always yield a
+    // boolean, which would count every float comparison under `int_cmp`.
+#define TVM_FEATURE_BINARY(Type, FloatCounter, IntCounter) \
+  void VisitExpr_(const Type* op) final {                  \
+    if (op->a.dtype().is_float()) {                        \
+      result_.FloatCounter += this->prod_loop_extent_;     \
+    } else {                                               \
+      result_.IntCounter += this->prod_loop_extent_;       \
+    }                                                      \
+    ExprVisitor::VisitExpr_(op);                           \
+  }
+    TVM_FEATURE_SIMPLE(AndNode, bool_op);
+    TVM_FEATURE_SIMPLE(OrNode, bool_op);
+    TVM_FEATURE_SIMPLE(NotNode, bool_op);
+    TVM_FEATURE_SIMPLE(SelectNode, select_op);
+    TVM_FEATURE_BINARY(AddNode, float_add_sub, int_add_sub);
+    TVM_FEATURE_BINARY(SubNode, float_add_sub, int_add_sub);
+    TVM_FEATURE_BINARY(MulNode, float_mul, int_mul);
+    TVM_FEATURE_BINARY(DivNode, float_div_mod, int_div_mod);
+    TVM_FEATURE_BINARY(ModNode, float_div_mod, int_div_mod);
+    TVM_FEATURE_BINARY(FloorDivNode, float_div_mod, int_div_mod);
+    TVM_FEATURE_BINARY(FloorModNode, float_div_mod, int_div_mod);
+    TVM_FEATURE_BINARY(MaxNode, float_cmp, int_cmp);
+    TVM_FEATURE_BINARY(MinNode, float_cmp, int_cmp);
+    TVM_FEATURE_BINARY(EQNode, float_cmp, int_cmp);
+    TVM_FEATURE_BINARY(NENode, float_cmp, int_cmp);
+    TVM_FEATURE_BINARY(LTNode, float_cmp, int_cmp);
+    TVM_FEATURE_BINARY(LENode, float_cmp, int_cmp);
+    TVM_FEATURE_BINARY(GTNode, float_cmp, int_cmp);
+    TVM_FEATURE_BINARY(GENode, float_cmp, int_cmp);
+#undef TVM_FEATURE_BINARY
+#undef TVM_FEATURE_SIMPLE
+
+    // Calls: pure calls (kPure / kExprAnnotation) count as math functions, everything else
+    // as "other" functions; each is split by the dtype of the call result.
+    void VisitExpr_(const CallNode* op) final {
+      static auto op_call_effect_ = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
+      // Only registered Ops carry an effect kind. Any other callee (e.g. a GlobalVar) is
+      // conservatively counted as a non-pure call instead of aborting in `Downcast<Op>`.
+      bool is_pure = false;
+      if (const auto* callee = op->op.as<OpNode>()) {
+        TCallEffectKind effect_kind = op_call_effect_[GetRef<Op>(callee)];
+        is_pure = effect_kind == CallEffectKind::kPure ||
+                  effect_kind == CallEffectKind::kExprAnnotation;
+      }
+      if (is_pure) {
+        if (op->dtype.is_float()) {
+          result_.float_math_func += prod_loop_extent_;
+        } else {
+          result_.int_math_func += prod_loop_extent_;
+        }
+      } else {
+        if (op->dtype.is_float()) {
+          result_.float_other_func += prod_loop_extent_;
+        } else {
+          result_.int_other_func += prod_loop_extent_;
+        }
+      }
+      ExprVisitor::VisitExpr_(op);
+    }
+
+    /*! \brief The weight added per operation occurrence */
+    int64_t prod_loop_extent_;
+    /*! \brief The accumulated counts */
+    ArithOps result_;
+  };
+  ArithOpCounter counter;
+  counter.prod_loop_extent_ = prod_loop_extent;
+  counter(store->value);
+  *this = counter.result_;
+}
+
+/*!
+ * \brief Summarize the loops that carry one particular loop kind (vectorize/unroll/parallel).
+ * \param loops The loops with that kind, outermost first; the innermost one is `loops.back()`
+ */
+Feature::ForKindFeature::ForKindFeature(const ForVec& loops) {
+  // No loop carries this kind: all statistics are zero and the position is "none"
+  if (loops.empty()) {
+    num = 0;
+    prod = 0;
+    len = 0;
+    pos = Pos::kPosNone;
+    return;
+  }
+  // Length of the innermost such loop; a non-constant extent is reported as 1
+  const int64_t* innermost_extent = GetLoopIntExtent(loops.back());
+  num = loops.size();
+  len = (innermost_extent != nullptr) ? *innermost_extent : 1;
+  // The position is not analyzed here; it is always reported as mixed
+  pos = Pos::kPosMixed;
+  // Product of the constant extents; loops with a non-constant extent are skipped
+  int64_t extent_product = 1;
+  for (const ForNode* loop : loops) {
+    const int64_t* extent = GetLoopIntExtent(loop);
+    if (extent != nullptr) {
+      extent_product *= *extent;
+    }
+  }
+  prod = extent_product;
+}
+
+} // namespace group1
+
+namespace group2 {
+
+/*! \brief Group 2 features: per-buffer access, stride and reuse statistics of a BufferStore */
+struct Feature {
+  enum class AccessType : int {
+    /*! The buffer is read but not written */
+    kRead = 0,
+    /*! The buffer is written but not read */
+    kWrite = 1,
+    /*! The buffer is both read and written */
+    kReadWrite = 2,
+    /*! Unknown type */
+    kUnknownRW = 3,
+  };
+  enum class ReuseType : int {
+    /*! Buffer reuse because accessed on each iteration of a loop */
+    kLoopMultipleRead = 0,
+    /*! Buffer reuse because it is serially accessed */
+    kSerialMultipleReadWrite = 1,
+    /*! No buffer reuse */
+    kNoReuse = 2,
+  };
+
+  /*! \brief The features of one buffer accessed by the store */
+  struct SubFeature {
+    /*! \brief The buffer this feature is for */
+    const BufferNode* buffer = nullptr;
+    /*! \brief The access type of the buffer */
+    AccessType access_type = AccessType::kUnknownRW;
+    /*! \brief A list of multi-dimensional indices used to access the buffer */
+    std::vector<MultiIndex> multi_indices = {};
+    // Access information
+    /*! \brief loop_accessed_numel[i][buffer] is the number of elements of `buffer` accessed
+     * under loops[i] */
+    std::vector<std::unordered_map<const BufferNode*, int64_t>> loop_accessed_numel = {};
+    /*! \brief The shape of the data access */
+    IntVec access_shape;
+    /*! \brief The bytes that are continuously accessed */
+    int64_t num_continuous_bytes = 1;
+    // Stride information
+    /*! \brief The min stride of the access */
+    int64_t min_stride = 0;
+    /*! \brief The innermost stride */
+    int64_t innermost_stride = 0;
+    /*! \brief The product of the non-strided loops */
+    int64_t prod_non_strided_loop_extent = 0;
+    // Reuse information
+    /*! The type of data reuse */
+    ReuseType reuse_type = ReuseType::kNoReuse;
+    /*! The reuse distance in terms of number of iterations */
+    double reuse_dis_iter = 0.0;
+    /*! The reuse distance in terms of bytes */
+    double reuse_dis_bytes = 0.0;
+    /*! The reuse count */
+    int64_t reuse_ct = 0;
+    // Features, filled in by `SetFeature`; zero-initialized so they are never read uninitialized
+    /*! The touched memory in bytes */
+    double bytes = 0.0;
+    /*! The touched unique memory in bytes */
+    double unique_bytes = 0.0;
+    /*! The number of touched cache lines */
+    double lines = 0.0;
+    /*! The number touched unique cache lines */
+    double unique_lines = 0.0;
+    /*! bytes / reuse_ct */
+    double bytes_d_reuse_ct = 0.0;
+    /*! unique_bytes / reuse_ct */
+    double unique_bytes_d_reuse_ct = 0.0;
+    /*! lines / reuse_ct */
+    double lines_d_reuse_ct = 0.0;
+    /*! unique_lines / reuse_ct */
+    double unique_lines_d_reuse_ct = 0.0;
+    /*! The stride in access */
+    double stride = 0.0;
+
+    /*! \brief The number of values appended by `Export` (and by `Pad`) */
+    static constexpr int64_t kCount = 18;
+
+    /*! \brief Append the `kCount` values of this sub-feature to `v` */
+    void Export(std::vector<double>* v) const {
+      double vs[] = {
+          static_cast<double>(static_cast<int>(access_type) == 0),
+          static_cast<double>(static_cast<int>(access_type) == 1),
+          static_cast<double>(static_cast<int>(access_type) == 2),
+          // FeatureSet::BufferAccess::AccessType::kUnknownRW is ignored
+          slog(bytes),
+          slog(unique_bytes),
+          slog(lines),
+          slog(unique_lines),
+          static_cast<double>(static_cast<int>(reuse_type) == 0),
+          static_cast<double>(static_cast<int>(reuse_type) == 1),
+          static_cast<double>(static_cast<int>(reuse_type) == 2),
+          slog(reuse_dis_iter),
+          slog(reuse_dis_bytes),
+          slog(reuse_ct),
+          slog(bytes_d_reuse_ct),
+          slog(unique_bytes_d_reuse_ct),
+          slog(lines_d_reuse_ct),
+          slog(unique_lines_d_reuse_ct),
+          slog(stride),
+      };
+      v->insert(v->end(), std::begin(vs), std::end(vs));
+    }
+
+    /*! \brief Append `kCount` zeros, standing in for a missing sub-feature */
+    static void Pad(std::vector<double>* v) { v->insert(v->end(), kCount, 0.0); }
+
+    void SetStride(const LoopNest& loop_nest, arith::Analyzer* analyzer);
+
+    void SetReuse(const LoopNest& loop_nest,     //
+                  int64_t top_loop_touch_bytes,  //
+                  const ForBufferMap<IntVec>& buffer_touched_under_loop);
+
+    void SetFeature(const LoopNest& loop_nest, int64_t cache_line_bytes);
+
+    explicit SubFeature(const BufferNode* buffer, AccessType access_type,
+                        std::vector<MultiIndex> multi_indices, int n_loops)
+        : buffer(buffer),
+          access_type(access_type),
+          multi_indices(std::move(multi_indices)),
+          loop_accessed_numel(n_loops) {}
+  };
+
+  /*!
+   * \brief Append exactly `buffers_per_store` sub-features to `v`.
+   * Missing ones are zero-padded; extra ones (after the sorting done in the constructor)
+   * are dropped.
+   */
+  void Export(std::vector<double>* v, int buffers_per_store) const {
+    int n = sub_features.size();
+    for (int i = 0; i < buffers_per_store; ++i) {
+      if (i < n) {
+        sub_features[i].Export(v);
+      } else {
+        SubFeature::Pad(v);
+      }
+    }
+  }
+
+  explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest,
+                   int64_t cache_line_bytes, IntVec* for_touched_bytes,
+                   ForBufferMap<IntVec>* buffer_touched_under_loop, arith::Analyzer* analyzer);
+
+  void Init(const BufferStoreNode* store, int n_loops);
+
+  void SetRegion(const LoopNest& loop_nest,                        //
+                 IntVec* for_touched_bytes,                        //
+                 ForBufferMap<IntVec>* buffer_touched_under_loop,  //
+                 arith::Analyzer* analyzer);
+
+  /*! \brief One entry per buffer accessed by the store */
+  std::vector<SubFeature> sub_features;
+};
+
+/*!
+ * \brief Create one SubFeature per buffer accessed by `store`.
+ *
+ * The stored buffer starts as kWrite with the store indices. Each BufferLoad in the stored
+ * value then updates its buffer's access type: kWrite becomes kReadWrite, and an unseen buffer
+ * becomes kRead. Load indices are recorded unless the buffer has become kReadWrite.
+ * \param store The BufferStore being analyzed
+ * \param n_loops The number of loops enclosing the store
+ */
+void Feature::Init(const BufferStoreNode* store, int n_loops) {
+  struct Info {
+    AccessType access_type = AccessType::kUnknownRW;
+    std::vector<MultiIndex> multi_indices;
+  };
+  std::unordered_map<const BufferNode*, Info> buffer_info;
+  {
+    Info& info = buffer_info[store->buffer.get()];
+    info.access_type = AccessType::kWrite;
+    info.multi_indices.push_back({store->indices.begin(), store->indices.end()});
+  }
+  PostOrderVisit(store->value, [&buffer_info](const ObjectRef& obj) -> void {
+    if (const BufferLoadNode* load = obj.as<BufferLoadNode>()) {
+      const BufferNode* buffer = load->buffer.get();
+      Info& info = buffer_info[buffer];
+      switch (info.access_type) {
+        case AccessType::kRead:
+          break;
+        case AccessType::kWrite:
+          info.access_type = AccessType::kReadWrite;
+          break;
+        case AccessType::kReadWrite:
+          break;
+        case AccessType::kUnknownRW:
+        default:
+          info.access_type = AccessType::kRead;
+          break;
+      }
+      if (info.access_type != AccessType::kReadWrite) {
+        info.multi_indices.push_back({load->indices.begin(), load->indices.end()});
+      }
+    }
+  });
+  this->sub_features.reserve(buffer_info.size());
+  // Iterate by non-const reference so that `std::move` really transfers the index vectors;
+  // moving from a `const` element would silently fall back to a copy.
+  for (auto& kv : buffer_info) {
+    this->sub_features.emplace_back(kv.first, kv.second.access_type,
+                                    std::move(kv.second.multi_indices), n_loops);
+  }
+}
+
+/*!
+ * \brief Compute, for every enclosing loop, how many elements each buffer touches under it.
+ *
+ * Every loop variable is first bound to its `min`. The loops are then processed from the
+ * innermost to the outermost. Each step rebinds one more loop variable to its full range and
+ * relaxes the access indices of every buffer with `utils::RelaxAndUnion`. Because the bindings
+ * accumulate, the processing order is significant.
+ * \param loop_nest The loops enclosing the store
+ * \param for_touched_bytes Output: reset to `n_loops` zeros, then for_touched_bytes[i] holds the
+ *        total bytes, over all sub-feature buffers, touched under loops[i]
+ * \param buffer_touched_under_loop Output: the numel touched by each buffer under each loop is
+ *        appended (the map is not cleared here)
+ * \param analyzer The analyzer used for binding and relaxation; bindings of the loop variables
+ *        are overridden
+ */
+void Feature::SetRegion(const LoopNest& loop_nest, IntVec* for_touched_bytes,
+                        ForBufferMap<IntVec>* buffer_touched_under_loop,
+                        arith::Analyzer* analyzer) {
+  int n_loops = loop_nest.loops.size();
+  const std::vector<const ForNode*>& loops = loop_nest.loops;
+  // Step 1. Initialize, and bind every loop variable to its loop's `min`
+  *for_touched_bytes = IntVec(n_loops, 0);
+  for (int i = 0; i < n_loops; ++i) {
+    const ForNode* loop = loops[i];
+    analyzer->Bind(loop->loop_var, loop->min, /*allow_override=*/true);
+  }
+  // Step 2. Corner case: no loops
+  if (n_loops == 0) {
+    // In this case, the `access_shape` is not relaxed: every dimension is given extent 1
+    for (SubFeature& feature : sub_features) {
+      feature.access_shape = IntVec(feature.buffer->shape.size(), 1);
+    }
+    return;
+  }
+  // Step 3. Gradually bind the loops from inner to outer,
+  // calculate the area the loops touch on each buffer
+  for (int i = n_loops - 1; i >= 0; --i) {
+    const ForNode* loop = loops[i];
+    // Widen loop `i` to its full range; loops inside it were already widened
+    analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min,
loop->extent),
+                   /*allow_override=*/true);
+    int64_t& touched_bytes = (*for_touched_bytes)[i] = 0;
+    for (SubFeature& feature : sub_features) {
+      const BufferNode* buffer = feature.buffer;
+      // Note: `feature.access_shape` for `i == 0` is the only one preserved,
+      // while others are discarded
+      int64_t numel;
+      feature.access_shape = utils::RelaxAndUnion(feature.multi_indices,
&numel, analyzer);
+      feature.loop_accessed_numel[i][buffer] = numel;
+      touched_bytes += numel * buffer->dtype.bytes();
+      (*buffer_touched_under_loop)[loop][buffer].push_back(numel);
+    }
+  }
+}
+
+/*!
+ * \brief Compute the stride-related fields of this sub-feature.
+ *
+ * Sets four fields:
+ *  - `num_continuous_bytes`
+ *  - `min_stride`: the stride of the innermost loop variable with a non-zero stride on this buffer
+ *  - `innermost_stride`: that stride if it belongs to the innermost loop, otherwise 0
+ *  - `prod_non_strided_loop_extent`: the product of the constant extents of the loops inside it
+ * Requires `access_shape`, which `Feature::SetRegion` computes.
+ * \param loop_nest The loops enclosing the store
+ * \param analyzer The analyzer used to evaluate the buffer shape
+ */
+void Feature::SubFeature::SetStride(const LoopNest& loop_nest, arith::Analyzer* analyzer) {
+  int n_loops = loop_nest.loops.size();
+  const std::vector<const ForNode*>& loops = loop_nest.loops;
+  const BufferNode* buffer = this->buffer;
+  int ndim = this->buffer->shape.size();
+  IntVec buffer_shape = utils::GetBufferShape(GetRef<Buffer>(buffer), analyzer);
+  // Calculate the buffer's row-major stride from its shape
+  IntVec buffer_stride(ndim);
+  if (ndim >= 1) {
+    buffer_stride[ndim - 1] = 1;
+    for (int i = ndim - 2; i >= 0; --i) {
+      buffer_stride[i] = buffer_stride[i + 1] * buffer_shape[i + 1];
+    }
+  }
+  // Calculate `num_continuous_bytes`: the size in bytes of the innermost dimension that the
+  // access covers entirely, or 1 if no dimension is covered entirely
+  {
+    int64_t& num_continuous_bytes = this->num_continuous_bytes = 1;
+    const IntVec& access_shape = this->access_shape;
+    ICHECK_EQ(access_shape.size(), buffer_shape.size());
+    for (int i = ndim - 1; i >= 0; --i) {
+      if (access_shape[i] == buffer_shape[i]) {
+        num_continuous_bytes = buffer_shape[i] * buffer->dtype.bytes();
+        break;
+      }
+    }
+  }
+  // Calculate `min_stride`: enumerate loops from inner to outer and stop at the first loop
+  // variable with a non-zero stride. Afterwards `i` is that loop's index, or -1 if none exists.
+  int i = 0;
+  int64_t& stride = this->min_stride = 0;
+  for (i = n_loops - 1; i >= 0; --i) {
+    stride = utils::GetVarStride(this->multi_indices, buffer_stride, loops[i]->loop_var);
+    if (stride != 0) {
+      break;
+    }
+  }
+  // Calculate `innermost_stride`: non-zero only when the innermost loop itself is strided
+  this->innermost_stride = (i == n_loops - 1) ? stride : 0;
+  // Calculate `prod_non_strided_loop_extent`: the product of the constant extents of the loops
+  // strictly inside loop `i`. Each such loop contributes its own extent.
+  int64_t& prod = this->prod_non_strided_loop_extent = 1;
+  for (int j = n_loops - 1; j > i; --j) {
+    if (const int64_t* extent = GetLoopIntExtent(loops[j])) {
+      prod *= *extent;
+    }
+  }
+}
+
+/*!
+ * \brief Compute the reuse-related fields of this sub-feature.
+ *
+ * Scans the loops from the innermost to the outermost and stops at the first loop that shows
+ * reuse:
+ *  - kLoopMultipleRead: the loop variable appears in none of this buffer's access indices, so
+ *    every iteration touches the same data; `reuse_ct` is the loop extent.
+ *  - kSerialMultipleReadWrite: at least two touch records exist for this buffer under the loop
+ *    (e.g. left by several stores); `reuse_ct` is the number of records minus one.
+ * If neither case occurs, the sub-feature keeps kNoReuse and all reuse statistics are 0.
+ * \param loop_nest The loops enclosing the store
+ * \param top_loop_touch_bytes The reuse distance in bytes used when the invariant loop is the
+ *        innermost one
+ * \param buffer_touched_under_loop The numel touched by each buffer under each loop
+ */
+void Feature::SubFeature::SetReuse(const LoopNest& loop_nest, int64_t top_loop_touch_bytes,
+                                   const ForBufferMap<IntVec>& buffer_touched_under_loop) {
+  const BufferNode* buffer = this->buffer;
+  // Step 1. Collect all `Var`s that appear in the indices used to access this buffer
+  std::unordered_set<const VarNode*> region_vars;
+  for (const MultiIndex& multi_index : this->multi_indices) {
+    for (const PrimExpr& index : multi_index) {
+      PostOrderVisit(index, [&region_vars](const ObjectRef& obj) -> void {
+        if (const auto* var = obj.as<VarNode>()) {
+          region_vars.insert(var);
+        }
+      });
+    }
+  }
+  // Default case: no reuse
+  ReuseType& reuse_type = this->reuse_type = ReuseType::kNoReuse;
+  double& reuse_dis_iter = this->reuse_dis_iter = 0;
+  double& reuse_dis_bytes = this->reuse_dis_bytes = 0;
+  int64_t& reuse_ct = this->reuse_ct = 0;
+
+  // Step 2. Enumerate loops from inner to outer, find the first loop with reuse
+  int n_loops = loop_nest.loops.size();
+  const std::vector<const ForNode*>& loops = loop_nest.loops;
+  for (int i = n_loops - 1; i >= 0; --i) {
+    const ForNode* loop = loops[i];
+    // Case 1. Find an invariant loop, i.e. reuse with kLoopMultipleRead
+    if (!region_vars.count(loop->loop_var.get())) {
+      reuse_type = ReuseType::kLoopMultipleRead;
+      if (const int64_t* extent = GetLoopIntExtent(loop)) {
+        reuse_ct = *extent;
+      } else {
+        reuse_ct = 1;
+      }
+      // Reuse distance in iterations: the product of the constant extents of the inner loops
+      reuse_dis_iter = 1;
+      for (int j = n_loops - 1; j > i; --j) {
+        if (const int64_t* extent = GetLoopIntExtent(loops[j])) {
+          reuse_dis_iter *= *extent;
+        }
+      }
+      // Reuse distance in bytes: `top_loop_touch_bytes` for the innermost loop, otherwise the
+      // bytes of every buffer touched under the next inner loop
+      reuse_dis_bytes = 0.0;
+      if (i == n_loops - 1) {
+        reuse_dis_bytes = top_loop_touch_bytes;
+      } else {
+        for (const auto& iter : buffer_touched_under_loop.at(loops[i + 1])) {
+          const BufferNode* touched_buffer = iter.first;
+          const IntVec& numels = iter.second;
+          int64_t numel = std::accumulate(numels.begin(), numels.end(), int64_t(0));
+          reuse_dis_bytes += numel * touched_buffer->dtype.bytes();
+        }
+      }
+      break;
+    }
+    // Case 2. Find serial reuse, i.e. reuse with kSerialMultipleReadWrite
+    const IntVec& touched = buffer_touched_under_loop.at(loop).at(buffer);
+    if (touched.size() >= 2) {
+      int64_t extent = 1;
+      if (const int64_t* ext = GetLoopIntExtent(loop)) {
+        extent = *ext;
+      }
+      reuse_type = ReuseType::kSerialMultipleReadWrite;
+      reuse_ct = touched.size() - 1;
+      reuse_dis_iter = *std::min_element(touched.begin(), touched.end());
+      reuse_dis_bytes = 0.0;
+      for (const auto& iter : buffer_touched_under_loop.at(loop)) {
+        const BufferNode* touched_buffer = iter.first;
+        const IntVec& numels = iter.second;
+        int64_t numel = std::accumulate(numels.begin(), numels.end(), int64_t(0));
+        reuse_dis_bytes += numel * touched_buffer->dtype.bytes();
+      }
+      // Both distances are normalized per iteration of the reusing loop
+      reuse_dis_iter /= extent;
+      reuse_dis_bytes /= extent;
+      break;
+    }
+  }
+}
+
+/*!
+ * \brief Derive the byte and cache-line statistics of this sub-feature.
+ *
+ * Uses the region, stride and reuse information computed by `SetRegion`, `SetStride` and
+ * `SetReuse`.
+ * \param loop_nest The loops enclosing the store
+ * \param cache_line_bytes The cache line size in bytes
+ */
+void Feature::SubFeature::SetFeature(const LoopNest& loop_nest, int64_t cache_line_bytes) {
+  const int64_t elem_bytes = buffer->dtype.bytes();
+  stride = innermost_stride;
+  bytes = elem_bytes * loop_nest.prod;
+  if (!loop_nest.loops.empty()) {
+    unique_bytes = loop_accessed_numel.front().at(buffer) * elem_bytes;
+    // Fraction of a cache line covered by one strided step, capped at one full line
+    const double line_fraction = std::min(1.0, 1.0 * min_stride * elem_bytes / cache_line_bytes);
+    // Iterations outside the non-strided inner loops, each touching `line_fraction` lines
+    const double touched_lines =
+        static_cast<double>(loop_nest.prod) / prod_non_strided_loop_extent * line_fraction;
+    lines = std::max(1.0, touched_lines);
+    // Unique bytes split into chunks of the contiguous span, at most one cache line each
+    const int64_t chunk_bytes = std::min(cache_line_bytes, num_continuous_bytes);
+    unique_lines = std::max(1.0, unique_bytes / chunk_bytes);
+  } else {
+    // No enclosing loop: the footprint statistics are fixed to 1
+    unique_bytes = 1;
+    lines = 1;
+    unique_lines = 1;
+  }
+  // Without reuse, divide by 0.5 rather than by zero
+  const double reuse_divisor = (reuse_ct > 0) ? static_cast<double>(reuse_ct) : 0.5;
+  bytes_d_reuse_ct = bytes / reuse_divisor;
+  unique_bytes_d_reuse_ct = unique_bytes / reuse_divisor;
+  lines_d_reuse_ct = lines / reuse_divisor;
+  unique_lines_d_reuse_ct = unique_lines / reuse_divisor;
+}
+
+/*!
+ * \brief Extract the group 2 feature of a BufferStore.
+ * \param store The BufferStore being analyzed
+ * \param loop_nest The loops enclosing the store
+ * \param cache_line_bytes The cache line size in bytes
+ * \param for_touched_bytes Output: the bytes touched under each loop (see SetRegion)
+ * \param buffer_touched_under_loop In/out: the numel touched by each buffer under each loop
+ * \param analyzer The analyzer used for region relaxation
+ */
+Feature::Feature(const BufferStoreNode* store, const LoopNest& loop_nest, int64_t cache_line_bytes,
+                 IntVec* for_touched_bytes, ForBufferMap<IntVec>* buffer_touched_under_loop,
+                 arith::Analyzer* analyzer) {
+  const int n_loops = loop_nest.loops.size();
+  // Collect the accessed buffers, then the regions they touch and their strides
+  Init(store, n_loops);
+  SetRegion(loop_nest, for_touched_bytes, buffer_touched_under_loop, analyzer);
+  for (SubFeature& sub : sub_features) {
+    sub.SetStride(loop_nest, analyzer);
+  }
+  // Bytes attributed to the outermost loop; stays 0 when there are no loops.
+  // NOTE(review): `loop_accessed_numel[0]` only holds the sub-feature's own buffer (SetRegion
+  // inserts a single key), so this sums element sizes rather than touched bytes -- confirm intent.
+  int64_t top_loop_touch_bytes = 0;
+  if (n_loops > 0) {
+    for (const SubFeature& sub : sub_features) {
+      const int64_t elem_bytes = sub.buffer->dtype.bytes();
+      const int64_t n_entries = sub.loop_accessed_numel[0].size();
+      top_loop_touch_bytes += elem_bytes * n_entries;
+    }
+  }
+  for (SubFeature& sub : sub_features) {
+    sub.SetReuse(loop_nest, top_loop_touch_bytes, *buffer_touched_under_loop);
+  }
+  for (SubFeature& sub : sub_features) {
+    sub.SetFeature(loop_nest, cache_line_bytes);
+  }
+  // Most significant buffers first: more cache lines, then more bytes, then by buffer name
+  auto more_significant = [](const SubFeature& lhs, const SubFeature& rhs) {
+    if (lhs.lines != rhs.lines) {
+      return lhs.lines > rhs.lines;
+    }
+    if (lhs.bytes != rhs.bytes) {
+      return lhs.bytes > rhs.bytes;
+    }
+    return lhs.buffer->name < rhs.buffer->name;
+  };
+  std::sort(sub_features.begin(), sub_features.end(), more_significant);
+}
+
+} // namespace group2
+
+namespace group3 {
+
+/*! \brief Group 3 feature: the arithmetic intensity curve of the loop nest around a store */
+struct Feature {
+  /*!
+   * \brief See the wiki page [1] for details
+   *
+   * [1] https://en.wikipedia.org/wiki/Roofline_model
+   */
+  std::vector<double> arith_intensity_curve;
+
+  /*! \brief Append the `n_samples` points of the curve to `v` */
+  void Export(std::vector<double>* v) const {
+    v->insert(v->end(), arith_intensity_curve.begin(),
arith_intensity_curve.end());
+  }
+
+  /*!
+   * \brief Sample the arithmetic intensity curve.
+   *
+   * In the helper arrays, index k refers to the k-th loop counted from the innermost one.
+   * compute_ops[k] is log2 of the float operations executed under that loop, and
+   * memory_bytes[k] is log2 of the bytes touched under it. The curve compute_ops / memory_bytes
+   * is linearly interpolated at `n_samples` evenly spaced fractions of log2(total float ops).
+   * \param n_samples The number of sampled points
+   * \param loop_nest The loops enclosing the store
+   * \param for_touched_bytes for_touched_bytes[i] is the bytes touched under loops[i]
+   * \param arith_ops The operation counts of the store, weighted by `loop_nest.prod`
+   */
+  explicit Feature(int n_samples, const LoopNest& loop_nest, const IntVec&
for_touched_bytes,
+                   const group1::Feature::ArithOps& arith_ops)
+      : arith_intensity_curve(n_samples, 0.0) {
+    const std::vector<const ForNode*>& loops = loop_nest.loops;
+    ICHECK_EQ(loops.size(), for_touched_bytes.size());
+    int n_loops = loops.size();
+    // Calculate `memory_bytes`, reversed so that index 0 is the innermost loop.
+    // NOTE(review): a touched size of 0 or 1 byte gives log2 = -inf or 0, which makes the
+    // ratios below non-finite -- confirm callers never produce such values.
+    std::vector<double> memory_bytes;
+    memory_bytes.resize(n_loops);
+    for (int i = 0; i < n_loops; ++i) {
+      memory_bytes[n_loops - 1 - i] = std::log2(for_touched_bytes[i]);
+    }
+    // Calculate `compute_ops`: start from the float ops of a single innermost iteration, then
+    // multiply by each constant loop extent from inner to outer (non-constant ones are skipped)
+    std::vector<double> compute_ops;
+    double total_compute_ops = arith_ops.float_mad + arith_ops.float_add_sub +
arith_ops.float_mul +
+                               arith_ops.float_div_mod + arith_ops.float_cmp +
+                               arith_ops.float_math_func +
arith_ops.float_other_func;
+    total_compute_ops /= loop_nest.prod;
+    for (int i = n_loops - 1; i >= 0; --i) {
+      if (const int64_t* extent = GetLoopIntExtent(loops[i])) {
+        total_compute_ops *= *extent;
+      }
+      compute_ops.push_back(std::log2(total_compute_ops));
+    }
+    // Fill the feature set. With no float computation or no loops, the curve is all zeros.
+    if (total_compute_ops <= 0 || compute_ops.empty()) {
+      for (int i = 0; i < n_samples; ++i) {
+        arith_intensity_curve[i] = 0.0;
+      }
+      return;
+    }
+    total_compute_ops = compute_ops.back();  // i.e. total_compute_ops =
log2(total_compute_ops)
+    // `p` never moves backwards: the sample targets increase with `i`
+    int p = 0;
+    for (int i = 0; i < n_samples; ++i) {
+      double& result = arith_intensity_curve[i];
+      double cur_compute_ops = static_cast<double>(i + 1) / n_samples *
total_compute_ops;
+      // Find the first `p` that `compute[p] >= total * (i + 1) / N`
+      for (; p < n_loops; ++p) {
+        if (compute_ops[p] >= cur_compute_ops - 1e-4) {
+          break;
+        }
+      }
+      CHECK_LT(p, n_loops);
+      if (p == 0) {
+        result = compute_ops[p] / memory_bytes[p];
+      } else {
+        // Linear interpolation of the intensity between loops p - 1 and p
+        double base = compute_ops[p - 1] / memory_bytes[p - 1];
+        double slope =
+            (compute_ops[p] / memory_bytes[p] - compute_ops[p - 1] /
memory_bytes[p - 1]) /
+            (compute_ops[p] - compute_ops[p - 1]);
+        result = base + slope * (cur_compute_ops - compute_ops[p - 1]);
+      }
+    }
+  }
+};
+
+} // namespace group3
+
+namespace group4 {
+
+/*! \brief Group 4 feature: statistics of a buffer allocation */
+struct Feature {
+  /*! \brief The size of the allocated buffer in bytes */
+  int64_t alloc_size = 0;
+  /*! \brief The number of elements of the buffer times `alloc_outer_prod` */
+  int64_t alloc_prod = 0;
+  /*! \brief The product of the lengths of the loops outside the scope of the allocation */
+  int64_t alloc_outer_prod = 1;
+
+  /*! \brief The number of values appended by `Export` */
+  static constexpr int64_t kCount = 4;
+
+  /*!
+   * \brief Append the feature to `v`.
+   * \param outer_prod The loop product at the store; the last value is
+   *        slog(outer_prod / alloc_outer_prod)
+   */
+  void Export(std::vector<double>* v, int64_t outer_prod) const {
+    const double prod_ratio = static_cast<double>(outer_prod) / alloc_outer_prod;
+    v->push_back(slog(alloc_size));
+    v->push_back(slog(alloc_prod));
+    v->push_back(slog(alloc_outer_prod));
+    v->push_back(slog(prod_ratio));
+  }
+
+  Feature() = default;
+
+  /*!
+   * \brief Describe `buffer`, allocated inside `loop_nest`.
+   * \param loop_nest The loops enclosing the allocation
+   * \param buffer The allocated buffer
+   * \param analyzer The analyzer used to evaluate the buffer shape
+   */
+  explicit Feature(const LoopNest& loop_nest, const Buffer& buffer, arith::Analyzer* analyzer) {
+    int64_t n_elems = 1;
+    for (int64_t dim : utils::GetBufferShape(buffer, analyzer)) {
+      n_elems *= dim;
+    }
+    alloc_outer_prod = loop_nest.prod;
+    alloc_size = n_elems * buffer->dtype.bytes();
+    alloc_prod = n_elems * loop_nest.prod;
+  }
+};
+
+} // namespace group4
+
+namespace group5 {
+
+/*! \brief Group 5 feature: statistics of the loops enclosing a store */
+struct Feature {
+  /*! \brief The product of lengths of outer loops */
+  int64_t outer_prod;
+  /*! \brief The number of outer loops */
+  int num_loops;
+  /*! \brief The value of pragma "auto_unroll_max_step" (the last recorded one), or 0 */
+  int auto_unroll_max_step;
+
+  /*! \brief The number of values appended by `Export` */
+  static constexpr int64_t kCount = 3;
+
+  /*! \brief Append the feature to `v` */
+  void Export(std::vector<double>* v) const {
+    v->push_back(slog(outer_prod));
+    v->push_back(slog(num_loops));
+    v->push_back(slog(auto_unroll_max_step));
+  }
+
+  explicit Feature(const LoopNest& loop_nest)
+      : outer_prod(loop_nest.prod),
+        num_loops(loop_nest.loops.size()),
+        auto_unroll_max_step(loop_nest.auto_unroll.empty() ? 0 : loop_nest.auto_unroll.back()) {}
+};
+
+} // namespace group5
+
+/*!
+ * \brief The features extracted for one buffer that is stored to.
+ * Groups 1, 2, 3 and 5 are recomputed at every BufferStore into `buffer`, so they describe
+ * the last such store visited. Group 4 describes the buffer's allocation, if one was seen.
+ */
+struct Feature {
+  /*! \brief The stored-to buffer; stays nullptr if the buffer was only seen as an allocation */
+  const BufferNode* buffer = nullptr;
+  /*! \brief Assigned from the collector's map size at the first store; orders the results */
+  int buffer_order = -1;
+  std::unique_ptr<group1::Feature> group1 = nullptr;
+  std::unique_ptr<group2::Feature> group2 = nullptr;
+  std::unique_ptr<group3::Feature> group3 = nullptr;
+  std::unique_ptr<group4::Feature> group4 = nullptr;
+  std::unique_ptr<group5::Feature> group5 = nullptr;
+
+  /*! \brief Order by `buffer_order` */
+  bool operator<(const Feature& other) const { return buffer_order <
other.buffer_order; }
+};
+
+/*!
+ * \brief The main feature extractor.
+ *
+ * Walks the PrimFuncs of a module while tracking the enclosing loop nest. It records the
+ * features of every stored-to buffer (see `Feature`) and the allocation features (group 4).
+ */
+class PerStoreFeatureCollector : private StmtVisitor {
+ public:
+  /*!
+   * \brief Collect the features of every buffer stored to in `mod`.
+   * \param is_gpu Whether the GPU thread-binding features of group 1 are computed
+   * \param cache_line_bytes The cache line size used by group 2
+   * \param arith_intensity_curve_num_samples The number of samples of the group 3 curve
+   * \param mod The module to analyze; only PrimFuncs are visited
+   * \return One Feature per stored-to buffer, sorted by `buffer_order`
+   */
+  static std::vector<Feature> Collect(bool is_gpu, int64_t cache_line_bytes,
+                                     int64_t
arith_intensity_curve_num_samples,
+                                     const IRModule& mod) {
+    PerStoreFeatureCollector collector(is_gpu, cache_line_bytes,
arith_intensity_curve_num_samples);
+    for (const auto& kv : mod->functions) {
+      if (const PrimFuncNode* func = kv.second.as<PrimFuncNode>()) {
+        collector(func->body);
+        // Function parameters are recorded as allocations, presumably with an empty loop nest
+        // once the body has been fully visited
+        for (const auto& it : func->buffer_map) {
+          collector.HandleBufferAlloc(it.second);
+        }
+      }
+    }
+    std::vector<Feature> result;
+    result.reserve(collector.buffer_features_.size());
+    for (auto& it : collector.buffer_features_) {
+      Feature& feature = it.second;
+      // Entries created only by allocations have no `buffer` and are dropped
+      if (feature.buffer != nullptr) {
+        ICHECK(feature.group1);
+        ICHECK(feature.group2);
+        ICHECK(feature.group3);
+        ICHECK(feature.group5);
+        // A stored-to buffer whose allocation was never seen gets a default group 4
+        if (feature.group4 == nullptr) {
+          feature.group4 = std::make_unique<group4::Feature>();
+        }
+        result.push_back(std::move(feature));
+      }
+    }
+    // The unordered_map iteration order is unspecified; sort by `buffer_order` instead
+    std::sort(result.begin(), result.end());
+    return result;
+  }
+
+ private:
+  // Keep `loop_nest_` in sync with the loops enclosing the statement being visited
+  void VisitStmt_(const ForNode* loop) final {
+    int64_t auto_unroll;
+    ForVec* for_vec = loop_nest_.Push(loop, &auto_unroll);
+    StmtVisitor::VisitStmt_(loop);
+    loop_nest_.Pop(loop, for_vec, auto_unroll);
+  }
+
+  // Recompute groups 1, 2, 3 and 5 for the stored buffer; a later store into the same buffer
+  // overwrites those computed for an earlier one
+  void VisitStmt_(const BufferStoreNode* store) final {
+    // Stores whose value is an integer/float immediate are skipped
+    if (store->value->IsInstance<IntImmNode>() ||
store->value->IsInstance<FloatImmNode>()) {
+      return;
+    }
+    const BufferNode* buffer = store->buffer.get();
+    Feature& feature = buffer_features_[buffer];
+    if (feature.buffer == nullptr) {
+      feature.buffer = buffer;
+      feature.buffer_order = buffer_features_.size();
+    }
+    feature.group1 = std::make_unique<group1::Feature>(store, loop_nest_,
is_gpu_);
+    // group2 refills `for_touched_bytes_`, which group3 reads right after, so order matters
+    feature.group2 =
+        std::make_unique<group2::Feature>(store, loop_nest_,
cache_line_bytes_, &for_touched_bytes_,
+                                          &buffer_touched_under_loop_,
&analyzer_);
+    feature.group3 =
+        std::make_unique<group3::Feature>(arith_intensity_curve_num_samples_,
loop_nest_,
+                                          for_touched_bytes_,
feature.group1->arith_ops);
+    feature.group5 = std::make_unique<group5::Feature>(loop_nest_);
+  }
+
+  // Block allocations are handled after the block body, i.e. with the loops enclosing the block
+  void VisitStmt_(const BlockNode* block) final {
+    StmtVisitor::VisitStmt_(block);
+    for (const Buffer& buffer : block->alloc_buffers) {
+      HandleBufferAlloc(buffer);
+    }
+  }
+
+  // Record (or overwrite) the group 4 feature of `buffer` relative to the current loop nest
+  void HandleBufferAlloc(const Buffer& buffer) {
+    Feature& feature = buffer_features_[buffer.get()];
+    feature.group4 = std::make_unique<group4::Feature>(loop_nest_, buffer,
&analyzer_);
+  }
+
+  explicit PerStoreFeatureCollector(bool is_gpu, int64_t cache_line_bytes,
+                                    int64_t arith_intensity_curve_num_samples)
+      : is_gpu_(is_gpu),
+        cache_line_bytes_(cache_line_bytes),
+        arith_intensity_curve_num_samples_(arith_intensity_curve_num_samples)
{}
+
+  bool is_gpu_;
+  int64_t cache_line_bytes_;
+  int64_t arith_intensity_curve_num_samples_;
+  /*! \brief Shared analyzer; loop variables are rebound on every store (allow_override) */
+  arith::Analyzer analyzer_;
+  /*! \brief The loops enclosing the statement currently visited */
+  LoopNest loop_nest_ = {};
+  /*! \brief Bytes touched under each loop for the most recent store (filled by group 2) */
+  IntVec for_touched_bytes_ = {};
+  /*! \brief Numel touched by each buffer under each loop, accumulated over all stores */
+  ForBufferMap<IntVec> buffer_touched_under_loop_ = {};
+  /*! \brief Features keyed by buffer, for both stored-to and allocated buffers */
+  std::unordered_map<const BufferNode*, Feature> buffer_features_ = {};
+};
+
+} // namespace tir
+} // namespace tvm
+
+namespace tvm {
+namespace meta_schedule {
+
+/*! \brief Feature extractor that emits one feature vector per stored-to buffer */
+class PerStoreFeatureNode : public FeatureExtractorNode {
+ public:
+  /*! \brief The number of group 2 sub-features (buffers) kept per store */
+  int buffers_per_store;
+  /*! \brief The number of samples of the group 3 arithmetic intensity curve */
+  int arith_intensity_curve_num_samples;
+  /*! \brief The cache line size in bytes */
+  int cache_line_bytes;
+  /*! \brief The length of every feature vector, derived from the fields above */
+  int feature_vector_length;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("buffers_per_store", &buffers_per_store);
+    v->Visit("arith_intensity_curve_num_samples",
&arith_intensity_curve_num_samples);
+    v->Visit("cache_line_bytes", &cache_line_bytes);
+    v->Visit("feature_vector_length", &feature_vector_length);
+  }
+
+  /*!
+   * \brief Extract the feature vectors of a single module.
+   * \param mod The module; it is first transformed by `PassListForPerStoreFeature`
+   * \param is_gpu Whether the GPU thread-binding features are computed
+   * \param results Output: one vector of `feature_vector_length` values per stored-to buffer
+   */
+  void ExtractSingle(IRModule mod, bool is_gpu,
std::vector<std::vector<double>>* results) {
+    // NOTE(review): this function-local static is shared by every worker thread of
+    // `ExtractFrom`; assumes the pass sequence can be invoked concurrently -- confirm
+    static transform::Sequential passes =
tir::transform::PassListForPerStoreFeature();
+    mod = passes(std::move(mod));
+    std::vector<tir::Feature> features =
tir::PerStoreFeatureCollector::Collect(
+        is_gpu, this->cache_line_bytes,
this->arith_intensity_curve_num_samples, mod);
+    int n_features = features.size();
+    results->resize(n_features);
+    for (int i = 0; i < n_features; ++i) {
+      const tir::Feature& feature = features[i];
+      std::vector<double>& result = (*results)[i];
+      result.reserve(feature_vector_length);
+      // The export order below defines the layout of the feature vector
+      feature.group1->Export(&result);
+      feature.group2->Export(&result, this->buffers_per_store);
+      feature.group3->Export(&result);
+      feature.group4->Export(&result, feature.group5->outer_prod);
+      feature.group5->Export(&result);
+      ICHECK_EQ(static_cast<int>(result.size()), feature_vector_length);
+    }
+  }
+
+  /*!
+   * \brief Extract the features of every candidate, in parallel over `num_threads`.
+   * \return One NDArray per candidate, built from its per-store feature vectors
+   */
+  Array<runtime::NDArray> ExtractFrom(const TuneContext& tune_context,
+                                      const Array<MeasureCandidate>&
candidates) {
+    // Requires the tune context to have a target. Only the kind name "cuda" counts as GPU.
+    bool is_gpu = tune_context->target.value()->kind->name == "cuda";
+    std::vector<runtime::NDArray> results;
+    results.resize(candidates.size());
+    // The first (unnamed) argument is not used; each task writes only its own slot of `results`
+    auto f = [this, is_gpu, &candidates, &results](int, int task_id) -> void {
+      const auto& candidate = candidates[task_id];
+      std::vector<std::vector<double>> features;
+      ExtractSingle(candidate->sch->mod(), is_gpu, &features);
+      results[task_id] = tir::utils::AsNDArray(features);
+    };
+    support::parallel_for_dynamic(0, candidates.size(),
tune_context->num_threads, f);
+    return results;
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.PerStoreFeature";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PerStoreFeatureNode, FeatureExtractorNode);
+};
+
+/*!
+ * \brief Create a per-store feature extractor.
+ * \param buffers_per_store The number of group 2 sub-features kept per store
+ * \param arith_intensity_curve_num_samples The number of samples of the group 3 curve
+ * \param cache_line_bytes The cache line size in bytes
+ */
+FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store,
+                                                   int arith_intensity_curve_num_samples,
+                                                   int cache_line_bytes) {
+  // Vector length: group 1, one group 2 sub-feature per buffer, the group 3 curve samples,
+  // group 4 and group 5
+  const int64_t group2_length = tir::group2::Feature::SubFeature::kCount * buffers_per_store;
+  const int64_t total_length = tir::group1::Feature::kCount + group2_length +
+                               arith_intensity_curve_num_samples + tir::group4::Feature::kCount +
+                               tir::group5::Feature::kCount;
+  ObjectPtr<PerStoreFeatureNode> node = make_object<PerStoreFeatureNode>();
+  node->buffers_per_store = buffers_per_store;
+  node->arith_intensity_curve_num_samples = arith_intensity_curve_num_samples;
+  node->cache_line_bytes = cache_line_bytes;
+  node->feature_vector_length = total_length;
+  return FeatureExtractor(node);
+}
+
+// Register the node type, and expose the factory as the packed function
+// "meta_schedule.FeatureExtractorPerStoreFeature"
+TVM_REGISTER_NODE_TYPE(PerStoreFeatureNode);
+TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPerStoreFeature")
+    .set_body_typed(FeatureExtractor::PerStoreFeature);
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 3e989e4..ef15f49 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -20,6 +20,7 @@
#define TVM_META_SCHEDULE_UTILS_H_
#include <dmlc/memory_io.h>
+#include <tvm/arith/analyzer.h>
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/meta_schedule/builder.h>
#include <tvm/meta_schedule/cost_model.h>
@@ -43,7 +44,10 @@
#include "../printer/text_printer.h"
#include "../support/array.h"
#include "../support/base64.h"
+#include "../support/nd_int_set.h"
+#include "../support/utils.h"
#include "../tir/schedule/primitive.h"
+#include "../tir/schedule/utils.h"
namespace tvm {
namespace meta_schedule {
diff --git a/src/tir/schedule/primitive/sampling.cc
b/src/tir/schedule/primitive/sampling.cc
index 83ef1e2..6d944b3 100644
--- a/src/tir/schedule/primitive/sampling.cc
+++ b/src/tir/schedule/primitive/sampling.cc
@@ -322,9 +322,9 @@ std::vector<int64_t> SamplePerfectTile(
const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t
max_innermost_factor,
Optional<Array<Integer>>* decision) {
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
- int64_t extent = GetLoopIntExtent(loop);
+ const int64_t* extent = GetLoopIntExtent(loop);
std::vector<int64_t> result;
- if (extent == -1) {
+ if (extent == nullptr) {
// Case 1. Handle loops with non-constant length
result = std::vector<int64_t>(n_splits, 1);
result[0] = -1;
@@ -333,7 +333,7 @@ std::vector<int64_t> SamplePerfectTile(
result = support::AsVector<Integer, int64_t>(decision->value());
int n = result.size();
ICHECK_GE(n, 2);
- int64_t len = extent;
+ int64_t len = *extent;
for (int i = n - 1; i > 0; --i) {
int64_t& l = result[i];
// A previous decision could become invalid because of the change of
outer tiles
@@ -347,7 +347,7 @@ std::vector<int64_t> SamplePerfectTile(
result[0] = len;
} else {
// Case 3. Use fresh new sampling result
- result = SamplePerfectTile(rand_state, extent, n_splits,
max_innermost_factor);
+ result = SamplePerfectTile(rand_state, *extent, n_splits,
max_innermost_factor);
ICHECK_LE(result.back(), max_innermost_factor);
}
*decision = support::AsArray<int64_t, Integer>(result);
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index c66c2ca..860b3f6 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -215,21 +215,69 @@ inline Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom) {
/*!
- * \brief Get the extents of a loop
+ * \brief Get the extent of a loop if it is a constant integer
* \param loop The loop to be queried
- * \return The extents of the loop
+ * \return The extent of the loop, nullptr if the extent is not constant
+ * \note The pointer is produced by as_const_int; presumably it points into the IntImm node
+ *       held by `loop->extent`, so it should not outlive the loop (TODO confirm).
*/
-inline int64_t GetLoopIntExtent(const ForNode* loop) {
- const auto* int_extent = loop->extent.as<IntImmNode>();
- return int_extent ? int_extent->value : -1;
-}
+inline const int64_t* GetLoopIntExtent(const ForNode* loop) { return as_const_int(loop->extent); }
/*!
- * \brief Get the extents of a loop
+ * \brief Get the extent of the loop `loop_sref` points to, if it is a constant integer
* \param loop_sref The loop to be queried
- * \return The extents of the loop
+ * \return The extent of the loop, nullptr if the extent is not constant
*/
-inline int64_t GetLoopIntExtent(const StmtSRef& loop_sref) {
+inline const int64_t* GetLoopIntExtent(const StmtSRef& loop_sref) {
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
- return GetLoopIntExtent(loop);
+ return as_const_int(loop->extent);
+}
+
+/******** Annotation ********/
+
+/*!
+ * \brief Get the annotation on a Block/For statement node
+ * \tparam TObjectRef The type of the annotation value
+ * \tparam TStmtNode The type of the statement node, e.g. BlockNode or ForNode
+ * \param stmt The block or the for loop to be queried
+ * \param ann_key The annotation key to be looked up
+ * \return NullOpt if not found; otherwise the annotation value, converted with
+ *         Downcast<TObjectRef> (which is expected to fail if the stored value has another type)
+ */
+template <class TObjectRef, class TStmtNode>
+inline Optional<TObjectRef> GetAnn(const TStmtNode* stmt, const String& ann_key) {
+  // Linear scan over the annotation map; returns the first entry whose key matches.
+  for (const auto& ann : stmt->annotations) {
+    if (ann.first == ann_key) {
+      return Downcast<TObjectRef>(ann.second);
+    }
+  }
+  return NullOpt;
+}
+
+/*!
+ * \brief Look up an annotation on the Block or For statement that `sref` points to
+ * \tparam TObjectRef The type of the annotation value
+ * \param sref The sref to the block or the for loop
+ * \param ann_key The annotation key to be looked up
+ * \return NullOpt if not found; otherwise the annotation value
+ * \note Any statement kind other than Block or For is reported through LOG(FATAL).
+ */
+template <class TObjectRef>
+inline Optional<TObjectRef> GetAnn(const StmtSRef& sref, const String& ann_key) {
+  if (const auto* block = sref->StmtAs<BlockNode>()) {
+    return GetAnn<TObjectRef, BlockNode>(block, ann_key);
+  }
+  if (const auto* loop = sref->StmtAs<ForNode>()) {
+    return GetAnn<TObjectRef, ForNode>(loop, ann_key);
+  }
+  LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey();
+  // LOG(FATAL) presumably never returns normally; the bare rethrow only tells the
+  // compiler that control cannot fall off the end of a value-returning function.
+  throw;
+}
+
+/*!
+ * \brief Check whether a Block/For carries annotation `ann_key` with the string value `ann_val`
+ * \param sref The sref to the block or the for loop
+ * \param ann_key The annotation key to be checked
+ * \param ann_val The annotation value to be checked
+ * \return true iff the annotation is present and its value equals `ann_val`
+ * \note The value is fetched as a String via GetAnn; a non-String value presumably fails the
+ *       Downcast check there rather than returning false (TODO confirm).
+ */
+inline bool HasAnn(const StmtSRef& sref, const String& ann_key, const String& ann_val) {
+  const Optional<String> found = GetAnn<String>(sref, ann_key);
+  if (!found.defined()) {
+    return false;
+  }
+  return found.value() == ann_val;
}
} // namespace tir
diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc
index aa58684..9c1aab6 100644
--- a/src/tir/transforms/unify_thread_binding.cc
+++ b/src/tir/transforms/unify_thread_binding.cc
@@ -58,8 +58,15 @@ class ThreadBindingUnifier : public StmtExprMutator {
if (op->kind != ForKind::kThreadBinding) {
return StmtExprMutator::VisitStmt_(op);
}
- return UnifyThreadBindingImpl(op, op->loop_var, op->thread_binding.value(),
- Range::FromMinExtent(op->min, op->extent));
+ Map<String, ObjectRef> annotations = op->annotations;
+ Stmt stmt = UnifyThreadBindingImpl(op, op->loop_var,
op->thread_binding.value(),
+ Range::FromMinExtent(op->min,
op->extent));
+ if (annotations.empty()) {
+ return stmt;
+ }
+ For new_loop = Downcast<For>(stmt);
+ new_loop.CopyOnWrite()->annotations = std::move(annotations);
+ return new_loop;
}
template <typename Node>
@@ -70,7 +77,7 @@ class ThreadBindingUnifier : public StmtExprMutator {
const String& thread_tag = old_iter_var->thread_tag;
// Step 2: Increase `thread_block_depth_` if the thread tag starts with
"blockIdx". If the
- // thread block depth is 0 before the increasement, it means we are
entering a new kernel, and
+ // thread block depth is 0 before the increment, it means we are entering
a new kernel, and
// therefore we need to make `thread_tag2iter_var_map_` empty, as
different kernels can have
// thread axes with different extents.
bool is_kernel_launch_scope = false;
diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py
new file mode 100644
index 0000000..7b6ef52
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py
@@ -0,0 +1,1555 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+import sys
+from typing import Callable, List
+
+from numpy.testing import assert_allclose
+import pytest
+
+import tvm
+from tvm import meta_schedule as ms, te, tir
+from tvm.script import tir as T
+
+# Length of one per-store feature vector: 57 + 18 * 5 + 10 + 4 + 3 (see _feature_names).
+N_FEATURES = 164
+
+
+# Workload shared by the CPU and GPU matmul tests: a 512x512x512 float32 matrix
+# multiplication C[i, j] = sum_k A[i, k] * B[k, j], expressed as a single block "C".
+@T.prim_func
+def matmul(
+    A: T.Buffer[(512, 512), "float32"],
+    B: T.Buffer[(512, 512), "float32"],
+    C: T.Buffer[(512, 512), "float32"],
+) -> None:
+    # function attr dict
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    # body
+    # with T.block("root")
+    for i0, i1, i2 in T.grid(512, 512, 512):
+        with T.block("C"):
+            # i and j are spatial axes ("S"); k is the reduction axis ("R").
+            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+            T.reads(C[i, j], A[i, k], B[k, j])
+            T.writes(C[i, j])
+            # Initialisation of the reduction: C[i, j] starts from zero.
+            with T.init():
+                C[i, j] = T.float32(0)
+            C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+
+def _make_context(target) -> ms.TuneContext:
+    """Build a single-threaded TuneContext for `target`."""
+    context = ms.TuneContext(target=target, num_threads=1)
+    return context
+
+
+def _make_candidate(f_sch: Callable[[], tir.Schedule]) -> ms.MeasureCandidate:
+    """Create a MeasureCandidate from the schedule built by `f_sch`, with empty args_info."""
+    schedule = f_sch()
+    return ms.MeasureCandidate(sch=schedule, args_info=[])
+
+
+def _feature_names(  # pylint: disable=invalid-name
+    buffers_per_store: int = 5,
+    arith_intensity_curve_num_samples: int = 10,
+) -> List[str]:
+    """Return the name of every entry of a per-store feature vector, in order.
+
+    Layout (matching the "Group" comments in the tests of this file):
+      * group 1: 57 per-store features (arithmetic op counts, vectorize/unroll/parallel
+        loop statistics, GPU thread-binding extents);
+      * group 2: 18 features for each of `buffers_per_store` buffers, named "B<i>.<feature>";
+      * group 3: `arith_intensity_curve_num_samples` arithmetic-intensity samples;
+      * groups 4 and 5: 4 allocation features followed by 3 outer-loop features.
+
+    With the default arguments the result has N_FEATURES (164) entries. An AssertionError is
+    raised if the name lists below do not add up to the layout above.
+    """
+    loop_positions = [
+        "kPosNone",
+        "kPosInnerSpatial",
+        "kPosMiddleSpatial",
+        "kPosOuterSpatial",
+        "kPosInnerReduce",
+        "kPosMiddleReduce",
+        "kPosOuterReduce",
+        "kPosMixed",
+    ]
+    arith_ops = ["mad", "addsub", "mul", "divmod", "cmp", "mathfunc", "otherfunc"]
+    # Group 1.1: arithmetic op counts (16)
+    result = [f"{dtype}_{op}" for dtype in ["float", "int"] for op in arith_ops]
+    result += ["bool_op", "select_op"]
+    # Groups 1.2-1.4: vectorize / unroll / parallel loop statistics (3 x 11)
+    for annotation in ["vec", "unroll", "parallel"]:
+        result += [f"{annotation}_num", f"{annotation}_prod", f"{annotation}_len"]
+        result += [f"{annotation}_type.{pos}" for pos in loop_positions]
+    # Group 1.5: GPU flag and thread-binding extents (8)
+    result += [
+        "is_gpu",
+        "blockIdx_x_len",
+        "blockIdx_y_len",
+        "blockIdx_z_len",
+        "threadIdx_x_len",
+        "threadIdx_y_len",
+        "threadIdx_z_len",
+        "vthread_len",
+    ]
+    # Group 2: per-buffer features (18 each)
+    buffer_features = [
+        "acc_type.kRead",
+        "acc_type.kWrite",
+        "acc_type.kReadWrite",
+        "bytes",
+        "unique_bytes",
+        "lines",
+        "unique_lines",
+        "reuse_type.kLoopMultipleRead",
+        "reuse_type.kSerialMultipleReadWrite",
+        "reuse_type.kNoReuse",
+        "reuse_dis_iter",
+        "reuse_dis_bytes",
+        "reuse_ct",
+        "bytes_d_reuse_ct",
+        "unique_bytes_d_reuse_ct",
+        "lines_d_reuse_ct",
+        "unique_lines_d_reuse_ct",
+        "stride",
+    ]
+    for i in range(buffers_per_store):
+        result.extend(f"B{i}.{s}" for s in buffer_features)
+    # Group 3: arithmetic intensity curve
+    result.extend(
+        f"arith_intensity_curve_{i}" for i in range(arith_intensity_curve_num_samples)
+    )
+    # Groups 4 & 5: allocation (4) and outer-loop (3) features
+    result.extend(
+        [
+            "alloc_size",
+            "alloc_prod",
+            "alloc_outer_prod",
+            "alloc_inner_prod",
+            "outer_prod",
+            "num_loops",
+            "auto_unroll_max_step",
+        ]
+    )
+    # 57 + 18 * buffers_per_store + arith_intensity_curve_num_samples + 4 + 3; the previous
+    # check against the fixed N_FEATURES failed for any non-default argument.
+    expected = 57 + 18 * buffers_per_store + arith_intensity_curve_num_samples + 4 + 3
+    assert len(result) == expected, (len(result), expected)
+    return result
+
+
+def _zip_feature(feature, names):
+    """Pair each value of a 1-D, N_FEATURES-long feature vector with its name."""
+    assert feature.ndim == 1
+    assert feature.shape[0] == N_FEATURES
+    assert len(names) == N_FEATURES
+    return [(name, value) for name, value in zip(names, feature)]
+
+
+def _print_feature(feature, st, ed):  # pylint: disable=invalid-name
+    """Debug helper: print the named features with indices in [st, ed), one per line."""
+    for name, value in _zip_feature(feature, _feature_names())[st:ed]:
+        print("\t", name, value)
+
+
+def test_cpu_matmul():
+    """Check the per-store features of a split/reordered, vectorized, parallelized and
+    unrolled CPU schedule of `matmul` against recorded reference values."""
+
+    def _create_schedule():
+        func = matmul
+        sch = tir.Schedule(func, debug_mask="all")
+        block = sch.get_block("C")
+        i, j, k = sch.get_loops(block)
+        i_o, i_i = sch.split(i, factors=[None, 16])  # outer: 32
+        j_o, j_i = sch.split(j, factors=[None, 8])  # outer: 64
+        sch.reorder(i_o, j_o, k, j_i, i_i)
+        sch.vectorize(j_i)
+        sch.parallel(i_o)
+        sch.parallel(j_o)
+        sch.unroll(k)
+        return sch
+
+    extractor = ms.feature_extractor.PerStoreFeature()
+    (feature,) = extractor.extract_from(
+        _make_context(tvm.target.Target("llvm")),
+        candidates=[_make_candidate(_create_schedule)],
+    )
+    feature = feature.numpy()
+    # A single row is expected: the schedule contains one BufferStore (into C).
+    assert feature.shape == (1, N_FEATURES)
+    f = feature[0]
+    # Count-like features look log-scaled, consistent with log2(1 + x): e.g. 27 for the
+    # 512**3 = 2**27 float multiplies, 3.169924 ~= log2(1 + 8) for the length-8 vectorized
+    # loop and 9.002815 ~= log2(1 + 512) for the unrolled k loop.
+    # Group 1.1: arith
+    assert_allclose(
+        actual=f[0:16],
+        # fmt: off
+        desired=[
+            # float math ops
+            0, 27, 27, 0, 0, 0, 0,
+            # int math ops
+            0, 29, 29, 0, 0, 0, 0,
+            # bool/select ops
+            0, 0,
+        ],
+        # fmt: on
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.2: vectorize
+    assert_allclose(
+        actual=f[16:27],
+        desired=[1.0, 3.169924, 3.169924, 0, 0, 0, 0, 0, 0, 0, 1],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.3: unroll
+    assert_allclose(
+        actual=f[27:38],
+        desired=[1.0, 9.002815, 9.002815, 0, 0, 0, 0, 0, 0, 0, 1],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.4: parallel
+    # (1.58496 ~= log2(1 + 2) for the two parallel loops, 11.0007 ~= log2(1 + 32 * 64))
+    assert_allclose(
+        actual=f[38:49],
+        desired=[1.58496, 11.0007, 6.022368, 0, 0, 0, 0, 0, 0, 0, 1],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread
+    assert_allclose(
+        actual=f[49:57],
+        desired=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.1: Buffer A
+    assert_allclose(
+        actual=f[57:75],
+        desired=[
+            1,
+            0,
+            0,
+            29,
+            20,
+            27,
+            14,
+            1,
+            0,
+            0,
+            4.087463,
+            7.0552826,
+            3.169925,
+            26,
+            17,
+            24,
+            11.0007038,
+            9.002815,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.2: Buffer C
+    assert_allclose(
+        actual=f[75:93],
+        desired=[
+            0.0,
+            0.0,
+            1.0,
+            29.0,
+            20.000001907348633,
+            27.0,
+            14.00008773803711,
+            1.0,
+            0.0,
+            0.0,
+            7.011227130889893,
+            9.250298500061035,
+            9.002815246582031,
+            20.000001907348633,
+            11.000703811645508,
+            18.0000057220459,
+            5.044394016265869,
+            9.002815246582031,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.3: Buffer B
+    assert_allclose(
+        actual=f[93:111],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            29.0,
+            20.000001907348633,
+            19.000001907348633,
+            14.00008773803711,
+            1.0,
+            0.0,
+            0.0,
+            1.0,
+            3.700439691543579,
+            4.087462902069092,
+            25.0,
+            16.000022888183594,
+            15.000043869018555,
+            10.001408576965332,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.4: Dummy padding
+    assert_allclose(
+        actual=f[111:129],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.5: Dummy padding
+    assert_allclose(
+        actual=f[129:147],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 3: Arithmetic intensity
+    assert_allclose(
+        actual=f[147:157],
+        desired=[
+            0.7097842693328857,
+            0.7408391237258911,
+            0.8750449419021606,
+            0.9449487924575806,
+            1.0148526430130005,
+            1.0847564935684204,
+            1.113688349723816,
+            1.1394684314727783,
+            1.2119636535644531,
+            1.2971993684768677,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 4 & 5
+    assert_allclose(
+        actual=f[157:164],
+        desired=[
+            20.000001907348633,
+            18.0000057220459,
+            1.0,
+            27.0,
+            27.0,
+            2.5849626064300537,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+
+
+def test_cpu_fusion():
+    """Check the features of two chained copies (A -> B -> C) that share one 64x32 loop nest.
+
+    Each BufferStore yields one feature row, so two rows are expected.
+    """
+    # pylint: disable=all
+    @T.prim_func
+    def func(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, [64, 32], dtype="float32")
+        B = T.match_buffer(b, [64, 32], dtype="float32")
+        C = T.match_buffer(c, [64, 32], dtype="float32")
+        for i, j in T.grid(64, 32):  # type: ignore
+            with T.block():
+                T.reads([A[i, j], B[i, j]])  # type: ignore
+                T.writes([B[i, j], C[i, j]])  # type: ignore
+                with T.block("B"):
+                    T.reads([A[i, j]])  # type: ignore
+                    T.writes([B[i, j]])  # type: ignore
+                    B[i, j] = A[i, j]  # type: ignore
+                with T.block("C"):
+                    T.reads([B[i, j]])  # type: ignore
+                    T.writes([C[i, j]])  # type: ignore
+                    C[i, j] = B[i, j]  # type: ignore
+
+    # pylint: enable=all
+
+    def _create_schedule():
+        return tir.Schedule(func, debug_mask="all")
+
+    extractor = ms.feature_extractor.PerStoreFeature()
+    (feature,) = extractor.extract_from(
+        _make_context(tvm.target.Target("llvm")),
+        candidates=[_make_candidate(_create_schedule)],
+    )
+    feature = feature.numpy()
+    assert feature.shape == (2, N_FEATURES)
+    ## Features for BufferStore(B)
+    f = feature[0]
+    # Group 1.1: arith
+    assert_allclose(
+        actual=f[0:16],
+        # fmt: off
+        desired=[0.0] * 16,
+        # fmt: on
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.2: vectorize
+    assert_allclose(
+        actual=f[16:27],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.3: unroll
+    assert_allclose(
+        actual=f[27:38],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.4: parallel
+    assert_allclose(
+        actual=f[38:49],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread
+    assert_allclose(
+        actual=f[49:57],
+        desired=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.1: Buffer A
+    assert_allclose(
+        actual=f[57:75],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            13.000176429748535,
+            13.000176429748535,
+            7.011227130889893,
+            7.011227130889893,
+            0.0,
+            0.0,
+            1.0,
+            0.0,
+            0.0,
+            0.0,
+            14.00008773803711,
+            14.00008773803711,
+            8.005624771118164,
+            8.005624771118164,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.2: Buffer B
+    assert_allclose(
+        actual=f[75:93],
+        desired=[
+            0.0,
+            1.0,
+            0.0,
+            13.000176429748535,
+            13.000176429748535,
+            7.011227130889893,
+            7.011227130889893,
+            0.0,
+            0.0,
+            1.0,
+            0.0,
+            0.0,
+            0.0,
+            14.00008773803711,
+            14.00008773803711,
+            8.005624771118164,
+            8.005624771118164,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.3: Dummy padding
+    assert_allclose(
+        actual=f[93:111],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.4: Dummy padding
+    assert_allclose(
+        actual=f[111:129],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.5: Dummy padding
+    assert_allclose(
+        actual=f[129:147],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 3: Arithmetic intensity
+    assert_allclose(
+        actual=f[147:157],
+        desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 4 & 5
+    # (11.000704 ~= log2(1 + 64 * 32) and 13.000176 ~= log2(1 + 64 * 32 * 4 bytes))
+    assert_allclose(
+        actual=f[157:164],
+        desired=[
+            13.000176,
+            11.000703811645508,
+            1.0,
+            11.000703811645508,
+            11.000703811645508,
+            1.5849624872207642,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    ## Features for BufferStore(C)
+    f = feature[1]
+    # Group 1.1: arith
+    assert_allclose(
+        actual=f[0:16],
+        # fmt: off
+        desired=[0.0] * 16,
+        # fmt: on
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.2: vectorize
+    assert_allclose(
+        actual=f[16:27],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.3: unroll
+    assert_allclose(
+        actual=f[27:38],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.4: parallel
+    assert_allclose(
+        actual=f[38:49],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread
+    assert_allclose(
+        actual=f[49:57],
+        desired=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.1: Buffer B
+    # B is read here after block "B" wrote it in the same loop body: its reuse flag is
+    # reuse_type.kSerialMultipleReadWrite (offset 8) instead of kNoReuse.
+    assert_allclose(
+        actual=f[57:75],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            13.000176429748535,
+            13.000176429748535,
+            7.011227130889893,
+            7.011227130889893,
+            0.0,
+            1.0,
+            0.0,
+            1.0,
+            4.087462902069092,
+            1.0,
+            13.000176429748535,
+            13.000176429748535,
+            7.011227130889893,
+            7.011227130889893,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.2: Buffer C
+    assert_allclose(
+        actual=f[75:93],
+        desired=[
+            0.0,
+            1.0,
+            0.0,
+            13.000176429748535,
+            13.000176429748535,
+            7.011227130889893,
+            7.011227130889893,
+            0.0,
+            0.0,
+            1.0,
+            0.0,
+            0.0,
+            0.0,
+            14.00008773803711,
+            14.00008773803711,
+            8.005624771118164,
+            8.005624771118164,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.3: Dummy padding
+    assert_allclose(
+        actual=f[93:111],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.4: Dummy padding
+    assert_allclose(
+        actual=f[111:129],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.5: Dummy padding
+    assert_allclose(
+        actual=f[129:147],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 3: Arithmetic intensity
+    assert_allclose(
+        actual=f[147:157],
+        desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 4 & 5
+    assert_allclose(
+        actual=f[157:164],
+        desired=[
+            13.000176429748535,
+            11.000703811645508,
+            1.0,
+            11.000703811645508,
+            11.000703811645508,
+            1.5849624872207642,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+
+
+def test_gpu():
+    """Check the four per-store feature rows of a tiled, thread-bound CUDA schedule of
+    `matmul`: stores into A_shared, B_shared, C_local and finally C."""
+
+    def _create_schedule():
+        func = matmul
+        sch = tir.Schedule(func, debug_mask="all")
+        c = sch.get_block("C")
+        c_local = sch.cache_write(c, 0, "local")
+        i, j, k = sch.get_loops(c)
+        # pylint: disable=invalid-name
+        i0, i1, i2, i3, i4 = sch.split(i, factors=[None, 1, 16, 32, 1])  # outer: 1
+        j0, j1, j2, j3, j4 = sch.split(j, factors=[None, 4, 1, 1, 16])  # outer: 8
+        k0, k1, k2 = sch.split(k, factors=[None, 1, 2])  # outer: 256
+        # pylint: enable=invalid-name
+        # fmt: off
+        sch.reorder(
+            i0, j0,  # S
+            i1, j1,  # S
+            i2, j2,  # S
+            k0,      # R
+            k1,      # R
+            i3, j3,  # S
+            k2,      # R
+            i4, j4,  # S
+        )
+        # fmt: on
+        # thread binding
+        i0_j0 = sch.fuse(i0, j0)
+        i1_j1 = sch.fuse(i1, j1)
+        i2_j2 = sch.fuse(i2, j2)
+        sch.bind(i0_j0, "blockIdx.x")
+        sch.bind(i1_j1, "vthread.x")
+        sch.bind(i2_j2, "threadIdx.x")
+        # fusion
+        sch.reverse_compute_at(c_local, i2_j2)
+        # cache read 'A'
+        a_shared = sch.cache_read(c, 1, "shared")
+        sch.compute_at(a_shared, k0)
+        _, _, _, _, a_i, a_j = sch.get_loops(a_shared)
+        a_ij = sch.fuse(a_i, a_j)
+        _, a_j = sch.split(a_ij, factors=[None, 16])  # outer: 64
+        sch.bind(a_j, "threadIdx.x")
+        # cache read 'B'
+        b_shared = sch.cache_read(c, 2, "shared")
+        sch.compute_at(b_shared, k0)
+        _, _, _, _, b_i, b_j = sch.get_loops(b_shared)
+        b_ij = sch.fuse(b_i, b_j)
+        _, b_j = sch.split(b_ij, factors=[None, 16])  # outer: 8
+        sch.bind(b_j, "threadIdx.x")
+        # auto unroll
+        sch.annotate(i0_j0, "pragma_auto_unroll_max_step", tir.IntImm("int32", 1024))
+        sch.annotate(i0_j0, "pragma_unroll_explicit", tir.IntImm("int32", 1))
+        return sch
+
+    extractor = ms.feature_extractor.PerStoreFeature()
+    (feature,) = extractor.extract_from(
+        _make_context(tvm.target.Target("cuda")),
+        candidates=[_make_candidate(_create_schedule)],
+    )
+    feature = feature.numpy()
+    assert feature.shape == (4, N_FEATURES)
+    # Thread-binding extents shared by all rows, consistent with log2(1 + x):
+    # blockIdx.x = 1 * 8 -> 3.1699, threadIdx.x = 16 * 1 -> 4.0875, vthread = 1 * 4 -> 2.3219.
+    # The last feature, 10.001408 ~= log2(1 + 1024), reflects pragma_auto_unroll_max_step.
+    ### Check feature[0]: BufferStore(A_shared) <= A[...]
+    f = feature[0]
+    # Group 1.1: arith
+    assert_allclose(
+        actual=f[0:16],
+        desired=[
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            24.000000085991324,
+            24.000000085991324,
+            24.000000085991324,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.2: vectorize
+    assert_allclose(
+        actual=f[16:27],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.3: unroll
+    assert_allclose(
+        actual=f[27:38],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.4: parallel
+    assert_allclose(
+        actual=f[38:49],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread
+    assert_allclose(
+        actual=f[49:57],
+        desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.1: Buffer A
+    assert_allclose(
+        actual=f[57:75],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            25.000000042995662,
+            20.000001375860553,
+            23.00000017198264,
+            14.000088052430122,
+            1.0,
+            0.0,
+            0.0,
+            18.00000550343433,
+            20.00562591970089,
+            2.321928094887362,
+            23.00000017198264,
+            18.00000550343433,
+            21.000000687930438,
+            12.0003521774803,
+            12.0003521774803,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.2: Buffer A.shared
+    assert_allclose(
+        actual=f[75:93],
+        desired=[
+            0.0,
+            1.0,
+            0.0,
+            25.000000042995662,
+            12.0003521774803,
+            23.00000017198264,
+            9.002815015607053,
+            1.0,
+            0.0,
+            0.0,
+            6.022367813028454,
+            11.98049663618346,
+            8.005624549193879,
+            17.000011006847668,
+            4.087462841250339,
+            15.000044026886828,
+            1.584962500721156,
+            4.087462841250339,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.3: Dummy padding
+    assert_allclose(
+        actual=f[93:111],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.4: Dummy padding
+    assert_allclose(
+        actual=f[111:129],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.5: Dummy padding
+    assert_allclose(
+        actual=f[129:147],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 3: Arithmetic intensity
+    assert_allclose(
+        actual=f[147:157],
+        desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 4 & 5
+    assert_allclose(
+        actual=f[157:164],
+        desired=[
+            12.0003521774803,
+            27.000000010748916,
+            17.000011006847668,
+            6.022367813028454,
+            23.00000017198264,
+            2.584962500721156,
+            10.001408,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    ### Check feature[1]: BufferStore(B_shared) <= B[...]
+    f = feature[1]
+    # Group 1.1: arith
+    assert_allclose(
+        actual=f[0:16],
+        desired=[
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            22.00000034396526,
+            22.00000034396526,
+            21.000000687930438,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.2: vectorize
+    assert_allclose(
+        actual=f[16:27],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.3: unroll
+    assert_allclose(
+        actual=f[27:38],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.4: parallel
+    assert_allclose(
+        actual=f[38:49],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread
+    assert_allclose(
+        actual=f[49:57],
+        desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.1: Buffer B
+    assert_allclose(
+        actual=f[57:75],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            22.00000034396526,
+            20.000001375860553,
+            20.000001375860553,
+            14.000088052430122,
+            1.0,
+            0.0,
+            0.0,
+            15.000044026886828,
+            20.17555076886471,
+            2.321928094887362,
+            20.000001375860553,
+            18.00000550343433,
+            18.00000550343433,
+            12.0003521774803,
+            4.087462841250339,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.2: Buffer B.shared
+    assert_allclose(
+        actual=f[75:93],
+        desired=[
+            0.0,
+            1.0,
+            0.0,
+            22.00000034396526,
+            9.002815015607053,
+            20.000001375860553,
+            3.169925001442312,
+            1.0,
+            0.0,
+            0.0,
+            3.169925001442312,
+            10.001408194392809,
+            8.005624549193879,
+            14.000088052430122,
+            1.584962500721156,
+            12.0003521774803,
+            0.044394119358453436,
+            4.087462841250339,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.3: Dummy padding
+    assert_allclose(
+        actual=f[93:111],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.4: Dummy padding
+    assert_allclose(
+        actual=f[111:129],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.5: Dummy padding
+    assert_allclose(
+        actual=f[129:147],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 3: Arithmetic intensity
+    assert_allclose(
+        actual=f[147:157],
+        desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 4 & 5
+    assert_allclose(
+        actual=f[157:164],
+        desired=[
+            9.002815015607053,
+            24.000000085991324,
+            17.000011006847668,
+            3.169925001442312,
+            20.000001375860553,
+            2.584962500721156,
+            10.001408,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    ### Check feature[2]: BufferStore(C_local) <= C_local[...] + A_shared[...] * B_shared[...]
+    f = feature[2]
+    # Group 1.1: arith
+    assert_allclose(
+        actual=f[0:16],
+        desired=[
+            0.0,
+            27.000000010748916,
+            27.000000010748916,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            28.000000005374456,
+            28.000000005374456,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.2: vectorize
+    assert_allclose(
+        actual=f[16:27],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.3: unroll
+    assert_allclose(
+        actual=f[27:38],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.4: parallel
+    assert_allclose(
+        actual=f[38:49],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread
+    assert_allclose(
+        actual=f[49:57],
+        desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.1: Buffer B.shared
+    assert_allclose(
+        actual=f[57:75],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            29.00000000268723,
+            9.002815015607053,
+            23.00000017198264,
+            3.169925001442312,
+            1.0,
+            0.0,
+            0.0,
+            5.044394119358453,
+            7.651051691178929,
+            5.044394119358453,
+            24.000000085991324,
+            4.087462841250339,
+            18.00000550343433,
+            0.32192809488736235,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.2: Buffer C.local
+    assert_allclose(
+        actual=f[75:93],
+        desired=[
+            0.0,
+            0.0,
+            1.0,
+            29.00000000268723,
+            11.000704269011246,
+            23.00000017198264,
+            5.044394119358453,
+            1.0,
+            0.0,
+            0.0,
+            4.087462841250339,
+            7.05528243550119,
+            1.584962500721156,
+            28.000000005374456,
+            10.001408194392809,
+            22.00000034396526,
+            4.087462841250339,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.3: Buffer A.shared
+    assert_allclose(
+        actual=f[93:111],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            29.00000000268723,
+            12.0003521774803,
+            19.00000275171979,
+            9.002815015607053,
+            1.0,
+            0.0,
+            0.0,
+            1.0,
+            3.700439718141092,
+            4.087462841250339,
+            25.000000042995662,
+            8.005624549193879,
+            15.000044026886828,
+            5.044394119358453,
+            0.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.4: Dummy padding
+    assert_allclose(
+        actual=f[111:129],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.5: Dummy padding
+    assert_allclose(
+        actual=f[129:147],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 3: Arithmetic intensity
+    assert_allclose(
+        actual=f[147:157],
+        desired=[
+            0.7097842504665767,
+            0.7548801745187567,
+            0.8775907547541741,
+            0.9957389916154509,
+            1.2446737395193135,
+            1.493608487423176,
+            1.7093103019954263,
+            1.8031580276850985,
+            1.9841832691827785,
+            2.204648076869754,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 4 & 5
+    assert_allclose(
+        actual=f[157:164],
+        desired=[
+            11.000704269011246,
+            18.00000550343433,
+            9.002815015607053,
+            18.00000550343433,
+            27.000000010748916,
+            3.0,
+            10.001408,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    ### Check feature[3]: BufferStore(C) <= C_local[...]
+    f = feature[3]
+    # Group 1.1: arith
+    assert_allclose(
+        actual=f[0:16],
+        desired=[0.0] * 16,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.2: vectorize
+    assert_allclose(
+        actual=f[16:27],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.3: unroll
+    assert_allclose(
+        actual=f[27:38],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.4: parallel
+    assert_allclose(
+        actual=f[38:49],
+        desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread
+    assert_allclose(
+        actual=f[49:57],
+        desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.1: Buffer C
+    assert_allclose(
+        actual=f[57:75],
+        desired=[
+            0.0,
+            1.0,
+            0.0,
+            20.000001375860553,
+            20.000001375860553,
+            14.000088052430122,
+            14.000088052430122,
+            0.0,
+            0.0,
+            1.0,
+            0.0,
+            0.0,
+            0.0,
+            21.000000687930438,
+            21.000000687930438,
+            15.000044026886828,
+            15.000044026886828,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.2: Buffer C.local
+    assert_allclose(
+        actual=f[75:93],
+        desired=[
+            1.0,
+            0.0,
+            0.0,
+            20.000001375860553,
+            11.000704269011246,
+            14.000088052430122,
+            5.044394119358453,
+            1.0,
+            0.0,
+            0.0,
+            9.002815015607053,
+            12.0003521774803,
+            4.087462841250339,
+            16.00002201361136,
+            7.011227255423254,
+            10.001408194392809,
+            1.584962500721156,
+            1.0,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.3: Dummy padding
+    assert_allclose(
+        actual=f[93:111],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.4: Dummy padding
+    assert_allclose(
+        actual=f[111:129],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 2.5: Dummy padding
+    assert_allclose(
+        actual=f[129:147],
+        desired=[0.0] * 18,
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 3: Arithmetic intensity
+    assert_allclose(
+        actual=f[147:157],
+        desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+    # Group 4 & 5
+    assert_allclose(
+        actual=f[157:164],
+        desired=[
+            20.000001375860553,
+            18.00000550343433,
+            1.0,
+            18.00000550343433,
+            18.00000550343433,
+            2.584962500721156,
+            10.001408,
+        ],
+        rtol=1e-5,
+        atol=1e-5,
+    )
+
+
+if __name__ == "__main__":
+    # Run this file through pytest, forwarding any extra command-line arguments.
+    sys.exit(pytest.main([__file__, *sys.argv[1:]]))