junrushao1994 commented on code in PR #11793: URL: https://github.com/apache/tvm/pull/11793#discussion_r912552390
########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { + +/*! + * \brief Analyze the read and write accesses of the body statements, used by `LocalPadder`. + */ +class StorageAccessAnalyzer : public StmtExprVisitor { + private: + struct StorageType { + enum { kGlobal = 0, kShared, kLocal, kOthers }; + }; Review Comment: Instead let's just use enum class like: ```C++ enum class StorageType : int32_t { kGlobal = 0, kShared = 1, kLocal = 2, kOthers = 3, }; ``` ########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { + +/*! + * \brief Analyze the read and write accesses of the body statements, used by `LocalPadder`. + */ +class StorageAccessAnalyzer : public StmtExprVisitor { + private: + struct StorageType { + enum { kGlobal = 0, kShared, kLocal, kOthers }; + }; + + void VisitStmt_(const BufferStoreNode* op) final { + write_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode* op) final { + read_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + class AccessMarker { + public: + void SetStorageAccessMarker_(const Buffer& buf) { + if (buf.scope() == "global") { + bit_vector_[StorageType::kGlobal] = true; + } else if (buf.scope() == "shared") { + bit_vector_[StorageType::kShared] = true; + } else if (buf.scope() == "local") { + bit_vector_[StorageType::kLocal] = true; + } else { + bit_vector_[StorageType::kOthers] = true; + } Review Comment: Use `runtime::StorageScope::Create`: https://github.com/apache/tvm/blob/64c6405d8253fd9381822ecd5b5102cdc7c2e3f8/src/runtime/thread_storage_scope.h#L126 Example: 
https://github.com/apache/tvm/blob/64c6405d8253fd9381822ecd5b5102cdc7c2e3f8/src/tir/schedule/state.cc#L46 ########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { Review Comment: There is no need to introduce an anonymous namespace here; it might make stack traces less informative. ########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { + +/*! + * \brief Analyze the read and write accesses of the body statements, used by `LocalPadder`. + */ +class StorageAccessAnalyzer : public StmtExprVisitor { + private: + struct StorageType { + enum { kGlobal = 0, kShared, kLocal, kOthers }; + }; + + void VisitStmt_(const BufferStoreNode* op) final { + write_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode* op) final { + read_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + class AccessMarker { + public: + void SetStorageAccessMarker_(const Buffer& buf) { + if (buf.scope() == "global") { + bit_vector_[StorageType::kGlobal] = true; + } else if (buf.scope() == "shared") { + bit_vector_[StorageType::kShared] = true; + } else if (buf.scope() == "local") { + bit_vector_[StorageType::kLocal] = true; + } else { + bit_vector_[StorageType::kOthers] = true; + } + } + bool NoAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kShared] || + bit_vector_[StorageType::kLocal] || bit_vector_[StorageType::kOthers]); + } + bool OnlyGlobalAccesses() const { + return !(bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal] || + 
bit_vector_[StorageType::kOthers]) && + bit_vector_[StorageType::kGlobal]; + } + bool OnlyLocalOrSharedAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kOthers]) && + (bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal]); + } + + private: + std::array<bool, StorageType::kOthers + 1> bit_vector_ = {false}; + }; + AccessMarker read_marker_, write_marker_; + std::pair<AccessMarker, AccessMarker> Analyze_(const Stmt& stmt) { + VisitStmt(stmt); + return std::make_pair(read_marker_, write_marker_); + } + + friend class LocalPadder; +}; + +/*! + * \brief Verify that all local variables are initialized to the same constant expression. + */ +class InitChecker : public StmtVisitor { + private: + void VisitStmt_(const BufferStoreNode* op) final { + // Read the check the RHS values, make sure that they are the same constant for all the + // initialization statements. + CheckInitValue_<IntImmNode>(op->value); + CheckInitValue_<FloatImmNode>(op->value); + return StmtVisitor::VisitStmt_(op); + } + template <typename ImmNodeType> + void CheckInitValue_(const PrimExpr& rhs) { + if (const ImmNodeType* const rhs_val = rhs.as<ImmNodeType>()) { + if (init_constexpr_.defined()) { + if (const ImmNodeType* const init_val = init_constexpr_.as<ImmNodeType>()) { + if (rhs_val->value != init_val->value) { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = true; + init_constexpr_ = rhs; + } + } + } + void operator()(const Stmt& stmt) { + StmtVisitor::operator()(stmt); + if (!init_with_single_constexpr_) { + init_constexpr_ = PrimExpr(); + } + } + + bool init_with_single_constexpr_ = false; + PrimExpr init_constexpr_; + + friend class LocalPadder; +}; + +/*! + * \brief Split a predicate into inlinable and non-inlinable component. + * + * We refer to "inlinable predicate" as + * + * if (predicate) A = ...; + * ↓ + * A = predicate ? ... 
: init_constexpr; + * + * Note that not all predicates can be inlined. For example, if a predicate is there to guard + * against out-of-boundary accesses to local/shared variables, then it cannot be inlined. + */ +class PredicateInliner : public StmtExprVisitor { + private: + explicit PredicateInliner(const Stmt& body_stmt) : body_stmt_(body_stmt) {} + +#define VISIT_PREDICATE(OpType) \ + void VisitExpr_(const OpType##Node* op) final { \ Review Comment: nit: you may use `ContainerType` ```suggestion void VisitExpr_(const OpType::ContainerType* op) final { \ ``` ########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { + +/*! + * \brief Analyze the read and write accesses of the body statements, used by `LocalPadder`. 
+ */ +class StorageAccessAnalyzer : public StmtExprVisitor { + private: + struct StorageType { + enum { kGlobal = 0, kShared, kLocal, kOthers }; + }; + + void VisitStmt_(const BufferStoreNode* op) final { + write_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode* op) final { + read_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + class AccessMarker { + public: + void SetStorageAccessMarker_(const Buffer& buf) { + if (buf.scope() == "global") { + bit_vector_[StorageType::kGlobal] = true; + } else if (buf.scope() == "shared") { + bit_vector_[StorageType::kShared] = true; + } else if (buf.scope() == "local") { + bit_vector_[StorageType::kLocal] = true; + } else { + bit_vector_[StorageType::kOthers] = true; + } + } + bool NoAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kShared] || + bit_vector_[StorageType::kLocal] || bit_vector_[StorageType::kOthers]); + } + bool OnlyGlobalAccesses() const { + return !(bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal] || + bit_vector_[StorageType::kOthers]) && + bit_vector_[StorageType::kGlobal]; + } + bool OnlyLocalOrSharedAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kOthers]) && + (bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal]); + } + + private: + std::array<bool, StorageType::kOthers + 1> bit_vector_ = {false}; + }; + AccessMarker read_marker_, write_marker_; + std::pair<AccessMarker, AccessMarker> Analyze_(const Stmt& stmt) { + VisitStmt(stmt); + return std::make_pair(read_marker_, write_marker_); + } + + friend class LocalPadder; +}; + +/*! + * \brief Verify that all local variables are initialized to the same constant expression. 
+ */ +class InitChecker : public StmtVisitor { + private: + void VisitStmt_(const BufferStoreNode* op) final { + // Read the check the RHS values, make sure that they are the same constant for all the + // initialization statements. + CheckInitValue_<IntImmNode>(op->value); + CheckInitValue_<FloatImmNode>(op->value); + return StmtVisitor::VisitStmt_(op); + } + template <typename ImmNodeType> + void CheckInitValue_(const PrimExpr& rhs) { + if (const ImmNodeType* const rhs_val = rhs.as<ImmNodeType>()) { + if (init_constexpr_.defined()) { + if (const ImmNodeType* const init_val = init_constexpr_.as<ImmNodeType>()) { + if (rhs_val->value != init_val->value) { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = true; + init_constexpr_ = rhs; + } + } + } + void operator()(const Stmt& stmt) { + StmtVisitor::operator()(stmt); + if (!init_with_single_constexpr_) { + init_constexpr_ = PrimExpr(); + } + } + + bool init_with_single_constexpr_ = false; + PrimExpr init_constexpr_; + + friend class LocalPadder; +}; + +/*! + * \brief Split a predicate into inlinable and non-inlinable component. + * + * We refer to "inlinable predicate" as + * + * if (predicate) A = ...; + * ↓ + * A = predicate ? ... : init_constexpr; + * + * Note that not all predicates can be inlined. For example, if a predicate is there to guard + * against out-of-boundary accesses to local/shared variables, then it cannot be inlined. 
+ */ +class PredicateInliner : public StmtExprVisitor { + private: + explicit PredicateInliner(const Stmt& body_stmt) : body_stmt_(body_stmt) {} + +#define VISIT_PREDICATE(OpType) \ + void VisitExpr_(const OpType##Node* op) final { \ + OpType predicate = GetRef<OpType>(op); \ + if (CanInlinePredicate_<OpType##Node>(op)) { \ + inlinable_predicates_.push_back(predicate); \ + } else { \ + non_inlinable_residuals_.push_back(predicate); \ + } \ + } + VISIT_PREDICATE(LT) + VISIT_PREDICATE(LE) + VISIT_PREDICATE(GT) + VISIT_PREDICATE(GE) +#undef VISIT_PREDICATE + + void VisitStmt_(const BufferStoreNode* op) final { + if (op->indices.size() != 1) { + return StmtVisitor::VisitStmt_(op); + } + CHECK(op->buffer.scope() == "shared" || op->buffer.scope() == "local"); + if (StructuralEqual()(op->indices[0], predicate_lhs_)) { + predicate_inlinable_ = false; + } + return StmtVisitor::VisitStmt_(op); + } + /*! + * \brief Check if a predicate can be inlined. + */ + template <typename OpNodeType> + bool CanInlinePredicate_(const OpNodeType* op) { + predicate_inlinable_ = true; + predicate_lhs_ = op->a; + VisitStmt(body_stmt_); + return predicate_inlinable_; + } + + Stmt body_stmt_; + std::vector<PrimExpr> inlinable_predicates_, non_inlinable_residuals_; + bool predicate_inlinable_; + PrimExpr predicate_lhs_; + + friend class LocalPadder; +}; Review Comment: Generally, I would like to propose that we restructure the logic of this class a little. It looks like the analyzer is interested in the following patterns: ```python if A </<=/>=/> X: B[Y] = ... ``` If so, there isn't much reason to use the visitor pattern, because no recursion actually happens.
Instead, let's write it in a plainer, more readable style, for example: ```C++ // inputs: IfThenElse if_then_else; // extract the lhs & rhs of the if-condition PrimExpr predicate_lhs{nullptr}; PrimExpr predicate_rhs{nullptr}; if (const auto *op = if_then_else->condition.as<LENode>()) { predicate_lhs = op->a; predicate_rhs = op->a; } else if (...) { // use a macro or something to deal with LT, GE, GT } // then let's analyze the body statement const BufferStoreNode* buffer_store = if_then_else->then_case.as<BufferStoreNode>(); ICHECK(buffer_store); if (StructuralEqual()(buffer_store->indices[0], predicate_lhs)) { ... // some logic here } else { ... // some logic here } ``` ########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { + +/*! + * \brief Analyze the read and write accesses of the body statements, used by `LocalPadder`.
+ */ +class StorageAccessAnalyzer : public StmtExprVisitor { + private: + struct StorageType { + enum { kGlobal = 0, kShared, kLocal, kOthers }; + }; + + void VisitStmt_(const BufferStoreNode* op) final { + write_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode* op) final { + read_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + class AccessMarker { + public: + void SetStorageAccessMarker_(const Buffer& buf) { + if (buf.scope() == "global") { + bit_vector_[StorageType::kGlobal] = true; + } else if (buf.scope() == "shared") { + bit_vector_[StorageType::kShared] = true; + } else if (buf.scope() == "local") { + bit_vector_[StorageType::kLocal] = true; + } else { + bit_vector_[StorageType::kOthers] = true; + } + } + bool NoAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kShared] || + bit_vector_[StorageType::kLocal] || bit_vector_[StorageType::kOthers]); + } + bool OnlyGlobalAccesses() const { + return !(bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal] || + bit_vector_[StorageType::kOthers]) && + bit_vector_[StorageType::kGlobal]; + } + bool OnlyLocalOrSharedAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kOthers]) && + (bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal]); + } + + private: + std::array<bool, StorageType::kOthers + 1> bit_vector_ = {false}; + }; + AccessMarker read_marker_, write_marker_; + std::pair<AccessMarker, AccessMarker> Analyze_(const Stmt& stmt) { + VisitStmt(stmt); + return std::make_pair(read_marker_, write_marker_); + } + + friend class LocalPadder; +}; + +/*! + * \brief Verify that all local variables are initialized to the same constant expression. 
+ */ +class InitChecker : public StmtVisitor { + private: + void VisitStmt_(const BufferStoreNode* op) final { + // Read the check the RHS values, make sure that they are the same constant for all the + // initialization statements. + CheckInitValue_<IntImmNode>(op->value); + CheckInitValue_<FloatImmNode>(op->value); + return StmtVisitor::VisitStmt_(op); + } + template <typename ImmNodeType> + void CheckInitValue_(const PrimExpr& rhs) { + if (const ImmNodeType* const rhs_val = rhs.as<ImmNodeType>()) { + if (init_constexpr_.defined()) { + if (const ImmNodeType* const init_val = init_constexpr_.as<ImmNodeType>()) { + if (rhs_val->value != init_val->value) { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = true; + init_constexpr_ = rhs; + } + } + } + void operator()(const Stmt& stmt) { + StmtVisitor::operator()(stmt); + if (!init_with_single_constexpr_) { + init_constexpr_ = PrimExpr(); + } + } + + bool init_with_single_constexpr_ = false; + PrimExpr init_constexpr_; + + friend class LocalPadder; +}; + +/*! + * \brief Split a predicate into inlinable and non-inlinable component. + * + * We refer to "inlinable predicate" as + * + * if (predicate) A = ...; + * ↓ + * A = predicate ? ... : init_constexpr; + * + * Note that not all predicates can be inlined. For example, if a predicate is there to guard + * against out-of-boundary accesses to local/shared variables, then it cannot be inlined. 
+ */ +class PredicateInliner : public StmtExprVisitor { + private: + explicit PredicateInliner(const Stmt& body_stmt) : body_stmt_(body_stmt) {} + +#define VISIT_PREDICATE(OpType) \ + void VisitExpr_(const OpType##Node* op) final { \ + OpType predicate = GetRef<OpType>(op); \ + if (CanInlinePredicate_<OpType##Node>(op)) { \ + inlinable_predicates_.push_back(predicate); \ + } else { \ + non_inlinable_residuals_.push_back(predicate); \ + } \ + } + VISIT_PREDICATE(LT) + VISIT_PREDICATE(LE) + VISIT_PREDICATE(GT) + VISIT_PREDICATE(GE) +#undef VISIT_PREDICATE + + void VisitStmt_(const BufferStoreNode* op) final { + if (op->indices.size() != 1) { + return StmtVisitor::VisitStmt_(op); + } + CHECK(op->buffer.scope() == "shared" || op->buffer.scope() == "local"); + if (StructuralEqual()(op->indices[0], predicate_lhs_)) { + predicate_inlinable_ = false; + } Review Comment: So this means that in: ```python if X < 100: # `predicate_lhs_` is X A[X] = ... # `op->indices[0]` is X ``` we are not going to inline the predicate? ########## include/tvm/tir/transform.h: ########## @@ -117,11 +127,13 @@ TVM_DLL Pass LoopPartition(); /*! * \brief Lower vectorization loops. * - * \param enable_vectorize Whether vectorization is enabled. + * \param enable_vectorize Whether vectorization is enabled. + * \param enable_local_pad Whether local padding is enabled. Local padding can affect + * how vectorization is made. * * \return The pass. */ -TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true); +TVM_DLL Pass VectorizeLoop(const bool enable_vectorize = true, const bool enable_local_pad = false); Review Comment: nit: adding the `const` specifier makes no difference for scalar parameters passed by value ```suggestion TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true, bool enable_local_pad = false); ``` ########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { + +/*! + * \brief Analyze the read and write accesses of the body statements, used by `LocalPadder`. 
+ */ +class StorageAccessAnalyzer : public StmtExprVisitor { + private: + struct StorageType { + enum { kGlobal = 0, kShared, kLocal, kOthers }; + }; + + void VisitStmt_(const BufferStoreNode* op) final { + write_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode* op) final { + read_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + class AccessMarker { + public: + void SetStorageAccessMarker_(const Buffer& buf) { + if (buf.scope() == "global") { + bit_vector_[StorageType::kGlobal] = true; + } else if (buf.scope() == "shared") { + bit_vector_[StorageType::kShared] = true; + } else if (buf.scope() == "local") { + bit_vector_[StorageType::kLocal] = true; + } else { + bit_vector_[StorageType::kOthers] = true; + } + } + bool NoAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kShared] || + bit_vector_[StorageType::kLocal] || bit_vector_[StorageType::kOthers]); + } + bool OnlyGlobalAccesses() const { + return !(bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal] || + bit_vector_[StorageType::kOthers]) && + bit_vector_[StorageType::kGlobal]; + } + bool OnlyLocalOrSharedAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kOthers]) && + (bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal]); + } + + private: + std::array<bool, StorageType::kOthers + 1> bit_vector_ = {false}; + }; + AccessMarker read_marker_, write_marker_; + std::pair<AccessMarker, AccessMarker> Analyze_(const Stmt& stmt) { + VisitStmt(stmt); + return std::make_pair(read_marker_, write_marker_); + } + + friend class LocalPadder; +}; + +/*! + * \brief Verify that all local variables are initialized to the same constant expression. 
+ */ +class InitChecker : public StmtVisitor { + private: + void VisitStmt_(const BufferStoreNode* op) final { + // Read the check the RHS values, make sure that they are the same constant for all the + // initialization statements. + CheckInitValue_<IntImmNode>(op->value); + CheckInitValue_<FloatImmNode>(op->value); + return StmtVisitor::VisitStmt_(op); + } + template <typename ImmNodeType> + void CheckInitValue_(const PrimExpr& rhs) { + if (const ImmNodeType* const rhs_val = rhs.as<ImmNodeType>()) { + if (init_constexpr_.defined()) { + if (const ImmNodeType* const init_val = init_constexpr_.as<ImmNodeType>()) { Review Comment: There is no need for the extra `const` here: the pointers themselves are never reassigned, and the compiler can usually figure that out through control-flow analysis (CFA) ```suggestion if (const ImmNodeType* rhs_val = rhs.as<ImmNodeType>()) { if (init_constexpr_.defined()) { if (const ImmNodeType* init_val = init_constexpr_.as<ImmNodeType>()) { ``` ########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { + +/*! + * \brief Analyze the read and write accesses of the body statements, used by `LocalPadder`. + */ +class StorageAccessAnalyzer : public StmtExprVisitor { + private: + struct StorageType { + enum { kGlobal = 0, kShared, kLocal, kOthers }; + }; + + void VisitStmt_(const BufferStoreNode* op) final { + write_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode* op) final { + read_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + class AccessMarker { + public: + void SetStorageAccessMarker_(const Buffer& buf) { + if (buf.scope() == "global") { + bit_vector_[StorageType::kGlobal] = true; + } else if (buf.scope() == "shared") { + bit_vector_[StorageType::kShared] = true; + } else if (buf.scope() == "local") { + bit_vector_[StorageType::kLocal] = true; + } else { + bit_vector_[StorageType::kOthers] = true; + } + } + bool NoAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kShared] || + bit_vector_[StorageType::kLocal] || bit_vector_[StorageType::kOthers]); + } + bool OnlyGlobalAccesses() const { + return !(bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal] || + bit_vector_[StorageType::kOthers]) && + bit_vector_[StorageType::kGlobal]; + } + bool OnlyLocalOrSharedAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kOthers]) && + (bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal]); + } + + private: + std::array<bool, StorageType::kOthers + 1> bit_vector_ = {false}; + }; + AccessMarker read_marker_, write_marker_; + std::pair<AccessMarker, AccessMarker> 
Analyze_(const Stmt& stmt) { + VisitStmt(stmt); + return std::make_pair(read_marker_, write_marker_); + } + + friend class LocalPadder; +}; + +/*! + * \brief Verify that all local variables are initialized to the same constant expression. + */ +class InitChecker : public StmtVisitor { + private: + void VisitStmt_(const BufferStoreNode* op) final { + // Read the check the RHS values, make sure that they are the same constant for all the + // initialization statements. + CheckInitValue_<IntImmNode>(op->value); + CheckInitValue_<FloatImmNode>(op->value); + return StmtVisitor::VisitStmt_(op); + } + template <typename ImmNodeType> + void CheckInitValue_(const PrimExpr& rhs) { + if (const ImmNodeType* const rhs_val = rhs.as<ImmNodeType>()) { + if (init_constexpr_.defined()) { + if (const ImmNodeType* const init_val = init_constexpr_.as<ImmNodeType>()) { + if (rhs_val->value != init_val->value) { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = true; + init_constexpr_ = rhs; + } + } + } + void operator()(const Stmt& stmt) { + StmtVisitor::operator()(stmt); + if (!init_with_single_constexpr_) { + init_constexpr_ = PrimExpr(); + } + } + + bool init_with_single_constexpr_ = false; + PrimExpr init_constexpr_; Review Comment: Use explicit nullability ```suggestion Optional<PrimExpr> init_constexpr_; ``` ########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { + +/*! + * \brief Analyze the read and write accesses of the body statements, used by `LocalPadder`. + */ +class StorageAccessAnalyzer : public StmtExprVisitor { + private: + struct StorageType { + enum { kGlobal = 0, kShared, kLocal, kOthers }; + }; + + void VisitStmt_(const BufferStoreNode* op) final { + write_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode* op) final { + read_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } Review Comment: Besides `BufferLoad` and `BufferStore`, there is another buffer access pattern called opaque access - accessing a buffer with intrinsics, for example, tensor core's MMA intrinsics. To check if a buffer `b` is opaque-accessed, we only need to check if `b->data` (whose type is `tir::Var`) is visited in `StmtExprVisitor`. Example: https://github.com/apache/tvm/blob/f8186d8c7d3e4679a6dfd83d17521f20bfb3ca42/src/tir/schedule/primitive/compute_inline.cc#L223-L226. where we check if `inlined_buffer` is opaquely accessed. ########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { + +/*! + * \brief Analyze the read and write accesses of the body statements, used by `LocalPadder`. 
+ */ +class StorageAccessAnalyzer : public StmtExprVisitor { + private: + struct StorageType { + enum { kGlobal = 0, kShared, kLocal, kOthers }; + }; + + void VisitStmt_(const BufferStoreNode* op) final { + write_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode* op) final { + read_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + class AccessMarker { + public: + void SetStorageAccessMarker_(const Buffer& buf) { + if (buf.scope() == "global") { + bit_vector_[StorageType::kGlobal] = true; + } else if (buf.scope() == "shared") { + bit_vector_[StorageType::kShared] = true; + } else if (buf.scope() == "local") { + bit_vector_[StorageType::kLocal] = true; + } else { + bit_vector_[StorageType::kOthers] = true; + } + } + bool NoAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kShared] || + bit_vector_[StorageType::kLocal] || bit_vector_[StorageType::kOthers]); + } + bool OnlyGlobalAccesses() const { + return !(bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal] || + bit_vector_[StorageType::kOthers]) && + bit_vector_[StorageType::kGlobal]; + } + bool OnlyLocalOrSharedAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kOthers]) && + (bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal]); + } + + private: + std::array<bool, StorageType::kOthers + 1> bit_vector_ = {false}; + }; + AccessMarker read_marker_, write_marker_; + std::pair<AccessMarker, AccessMarker> Analyze_(const Stmt& stmt) { + VisitStmt(stmt); + return std::make_pair(read_marker_, write_marker_); + } + + friend class LocalPadder; +}; + +/*! + * \brief Verify that all local variables are initialized to the same constant expression. 
+ */ +class InitChecker : public StmtVisitor { + private: + void VisitStmt_(const BufferStoreNode* op) final { + // Read the check the RHS values, make sure that they are the same constant for all the + // initialization statements. + CheckInitValue_<IntImmNode>(op->value); + CheckInitValue_<FloatImmNode>(op->value); + return StmtVisitor::VisitStmt_(op); + } + template <typename ImmNodeType> + void CheckInitValue_(const PrimExpr& rhs) { Review Comment: nit: No need to add the underscore given it's already a private method ```suggestion void CheckInitValue(const PrimExpr& rhs) { ``` ########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { + +/*! + * \brief Analyze the read and write accesses of the body statements, used by `LocalPadder`. 
+ */ +class StorageAccessAnalyzer : public StmtExprVisitor { + private: + struct StorageType { + enum { kGlobal = 0, kShared, kLocal, kOthers }; + }; + + void VisitStmt_(const BufferStoreNode* op) final { + write_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode* op) final { + read_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + class AccessMarker { + public: + void SetStorageAccessMarker_(const Buffer& buf) { + if (buf.scope() == "global") { + bit_vector_[StorageType::kGlobal] = true; + } else if (buf.scope() == "shared") { + bit_vector_[StorageType::kShared] = true; + } else if (buf.scope() == "local") { + bit_vector_[StorageType::kLocal] = true; + } else { + bit_vector_[StorageType::kOthers] = true; + } + } + bool NoAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kShared] || + bit_vector_[StorageType::kLocal] || bit_vector_[StorageType::kOthers]); + } + bool OnlyGlobalAccesses() const { + return !(bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal] || + bit_vector_[StorageType::kOthers]) && + bit_vector_[StorageType::kGlobal]; + } + bool OnlyLocalOrSharedAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kOthers]) && + (bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal]); + } + + private: + std::array<bool, StorageType::kOthers + 1> bit_vector_ = {false}; + }; + AccessMarker read_marker_, write_marker_; + std::pair<AccessMarker, AccessMarker> Analyze_(const Stmt& stmt) { + VisitStmt(stmt); + return std::make_pair(read_marker_, write_marker_); + } + + friend class LocalPadder; +}; + +/*! + * \brief Verify that all local variables are initialized to the same constant expression. 
+ */ +class InitChecker : public StmtVisitor { + private: + void VisitStmt_(const BufferStoreNode* op) final { + // Read the check the RHS values, make sure that they are the same constant for all the + // initialization statements. + CheckInitValue_<IntImmNode>(op->value); + CheckInitValue_<FloatImmNode>(op->value); + return StmtVisitor::VisitStmt_(op); + } + template <typename ImmNodeType> + void CheckInitValue_(const PrimExpr& rhs) { + if (const ImmNodeType* const rhs_val = rhs.as<ImmNodeType>()) { + if (init_constexpr_.defined()) { + if (const ImmNodeType* const init_val = init_constexpr_.as<ImmNodeType>()) { + if (rhs_val->value != init_val->value) { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = true; + init_constexpr_ = rhs; + } + } + } + void operator()(const Stmt& stmt) { + StmtVisitor::operator()(stmt); + if (!init_with_single_constexpr_) { + init_constexpr_ = PrimExpr(); + } + } + + bool init_with_single_constexpr_ = false; + PrimExpr init_constexpr_; + + friend class LocalPadder; +}; + +/*! + * \brief Split a predicate into inlinable and non-inlinable component. + * + * We refer to "inlinable predicate" as + * + * if (predicate) A = ...; + * ↓ + * A = predicate ? ... : init_constexpr; + * + * Note that not all predicates can be inlined. For example, if a predicate is there to guard + * against out-of-boundary accesses to local/shared variables, then it cannot be inlined. 
+ */ +class PredicateInliner : public StmtExprVisitor { + private: + explicit PredicateInliner(const Stmt& body_stmt) : body_stmt_(body_stmt) {} + +#define VISIT_PREDICATE(OpType) \ + void VisitExpr_(const OpType##Node* op) final { \ + OpType predicate = GetRef<OpType>(op); \ + if (CanInlinePredicate_<OpType##Node>(op)) { \ + inlinable_predicates_.push_back(predicate); \ + } else { \ + non_inlinable_residuals_.push_back(predicate); \ + } \ + } + VISIT_PREDICATE(LT) + VISIT_PREDICATE(LE) + VISIT_PREDICATE(GT) + VISIT_PREDICATE(GE) +#undef VISIT_PREDICATE + + void VisitStmt_(const BufferStoreNode* op) final { + if (op->indices.size() != 1) { + return StmtVisitor::VisitStmt_(op); + } + CHECK(op->buffer.scope() == "shared" || op->buffer.scope() == "local"); Review Comment: ditto. Use `runtime::StorageScope` instead of bare strings ########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { + +/*! + * \brief Analyze the read and write accesses of the body statements, used by `LocalPadder`. + */ +class StorageAccessAnalyzer : public StmtExprVisitor { + private: + struct StorageType { + enum { kGlobal = 0, kShared, kLocal, kOthers }; + }; + + void VisitStmt_(const BufferStoreNode* op) final { + write_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode* op) final { + read_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + class AccessMarker { + public: + void SetStorageAccessMarker_(const Buffer& buf) { + if (buf.scope() == "global") { + bit_vector_[StorageType::kGlobal] = true; + } else if (buf.scope() == "shared") { + bit_vector_[StorageType::kShared] = true; + } else if (buf.scope() == "local") { + bit_vector_[StorageType::kLocal] = true; + } else { + bit_vector_[StorageType::kOthers] = true; + } + } + bool NoAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kShared] || + bit_vector_[StorageType::kLocal] || bit_vector_[StorageType::kOthers]); + } + bool OnlyGlobalAccesses() const { + return !(bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal] || + bit_vector_[StorageType::kOthers]) && + bit_vector_[StorageType::kGlobal]; + } + bool OnlyLocalOrSharedAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kOthers]) && + (bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal]); + } + + private: + std::array<bool, StorageType::kOthers + 1> bit_vector_ = {false}; + }; + AccessMarker read_marker_, write_marker_; + std::pair<AccessMarker, AccessMarker> 
Analyze_(const Stmt& stmt) { + VisitStmt(stmt); + return std::make_pair(read_marker_, write_marker_); + } + + friend class LocalPadder; +}; + +/*! + * \brief Verify that all local variables are initialized to the same constant expression. + */ +class InitChecker : public StmtVisitor { + private: + void VisitStmt_(const BufferStoreNode* op) final { + // Read the check the RHS values, make sure that they are the same constant for all the + // initialization statements. + CheckInitValue_<IntImmNode>(op->value); + CheckInitValue_<FloatImmNode>(op->value); + return StmtVisitor::VisitStmt_(op); + } + template <typename ImmNodeType> + void CheckInitValue_(const PrimExpr& rhs) { + if (const ImmNodeType* const rhs_val = rhs.as<ImmNodeType>()) { + if (init_constexpr_.defined()) { + if (const ImmNodeType* const init_val = init_constexpr_.as<ImmNodeType>()) { + if (rhs_val->value != init_val->value) { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = true; + init_constexpr_ = rhs; + } + } + } + void operator()(const Stmt& stmt) { + StmtVisitor::operator()(stmt); + if (!init_with_single_constexpr_) { + init_constexpr_ = PrimExpr(); + } + } + + bool init_with_single_constexpr_ = false; + PrimExpr init_constexpr_; + + friend class LocalPadder; +}; + +/*! + * \brief Split a predicate into inlinable and non-inlinable component. + * + * We refer to "inlinable predicate" as + * + * if (predicate) A = ...; + * ↓ + * A = predicate ? ... : init_constexpr; + * + * Note that not all predicates can be inlined. For example, if a predicate is there to guard + * against out-of-boundary accesses to local/shared variables, then it cannot be inlined. 
+ */ +class PredicateInliner : public StmtExprVisitor { + private: + explicit PredicateInliner(const Stmt& body_stmt) : body_stmt_(body_stmt) {} + +#define VISIT_PREDICATE(OpType) \ Review Comment: nit: usually when defining macros, we will have to find a longer name in case it conflicts. for example: ```suggestion #define TVM_TIR_TRANSFORM_LOCAL_PAD_VISIT_PREDICATE(OpType) \ ``` ########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { + +/*! + * \brief Analyze the read and write accesses of the body statements, used by `LocalPadder`. 
+ */ +class StorageAccessAnalyzer : public StmtExprVisitor { + private: + struct StorageType { + enum { kGlobal = 0, kShared, kLocal, kOthers }; + }; + + void VisitStmt_(const BufferStoreNode* op) final { + write_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode* op) final { + read_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + class AccessMarker { + public: + void SetStorageAccessMarker_(const Buffer& buf) { + if (buf.scope() == "global") { + bit_vector_[StorageType::kGlobal] = true; + } else if (buf.scope() == "shared") { + bit_vector_[StorageType::kShared] = true; + } else if (buf.scope() == "local") { + bit_vector_[StorageType::kLocal] = true; + } else { + bit_vector_[StorageType::kOthers] = true; + } + } + bool NoAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kShared] || + bit_vector_[StorageType::kLocal] || bit_vector_[StorageType::kOthers]); + } + bool OnlyGlobalAccesses() const { + return !(bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal] || + bit_vector_[StorageType::kOthers]) && + bit_vector_[StorageType::kGlobal]; + } + bool OnlyLocalOrSharedAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kOthers]) && + (bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal]); + } + + private: + std::array<bool, StorageType::kOthers + 1> bit_vector_ = {false}; + }; + AccessMarker read_marker_, write_marker_; + std::pair<AccessMarker, AccessMarker> Analyze_(const Stmt& stmt) { + VisitStmt(stmt); + return std::make_pair(read_marker_, write_marker_); + } + + friend class LocalPadder; +}; + +/*! + * \brief Verify that all local variables are initialized to the same constant expression. 
+ */ +class InitChecker : public StmtVisitor { + private: + void VisitStmt_(const BufferStoreNode* op) final { + // Read the check the RHS values, make sure that they are the same constant for all the + // initialization statements. + CheckInitValue_<IntImmNode>(op->value); + CheckInitValue_<FloatImmNode>(op->value); + return StmtVisitor::VisitStmt_(op); + } + template <typename ImmNodeType> + void CheckInitValue_(const PrimExpr& rhs) { + if (const ImmNodeType* const rhs_val = rhs.as<ImmNodeType>()) { + if (init_constexpr_.defined()) { + if (const ImmNodeType* const init_val = init_constexpr_.as<ImmNodeType>()) { + if (rhs_val->value != init_val->value) { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = true; + init_constexpr_ = rhs; + } + } + } + void operator()(const Stmt& stmt) { + StmtVisitor::operator()(stmt); + if (!init_with_single_constexpr_) { + init_constexpr_ = PrimExpr(); + } + } + + bool init_with_single_constexpr_ = false; + PrimExpr init_constexpr_; + + friend class LocalPadder; +}; + +/*! + * \brief Split a predicate into inlinable and non-inlinable component. + * + * We refer to "inlinable predicate" as + * + * if (predicate) A = ...; + * ↓ + * A = predicate ? ... : init_constexpr; + * + * Note that not all predicates can be inlined. For example, if a predicate is there to guard + * against out-of-boundary accesses to local/shared variables, then it cannot be inlined. 
+ */ +class PredicateInliner : public StmtExprVisitor { + private: + explicit PredicateInliner(const Stmt& body_stmt) : body_stmt_(body_stmt) {} + +#define VISIT_PREDICATE(OpType) \ + void VisitExpr_(const OpType##Node* op) final { \ + OpType predicate = GetRef<OpType>(op); \ + if (CanInlinePredicate_<OpType##Node>(op)) { \ + inlinable_predicates_.push_back(predicate); \ + } else { \ + non_inlinable_residuals_.push_back(predicate); \ + } \ Review Comment: Quick question: Do we expect to continue visiting `op->a` and `op->b`? ########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { + +/*! + * \brief Analyze the read and write accesses of the body statements, used by `LocalPadder`. 
+ */ +class StorageAccessAnalyzer : public StmtExprVisitor { + private: + struct StorageType { + enum { kGlobal = 0, kShared, kLocal, kOthers }; + }; + + void VisitStmt_(const BufferStoreNode* op) final { + write_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode* op) final { + read_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + class AccessMarker { + public: + void SetStorageAccessMarker_(const Buffer& buf) { + if (buf.scope() == "global") { + bit_vector_[StorageType::kGlobal] = true; + } else if (buf.scope() == "shared") { + bit_vector_[StorageType::kShared] = true; + } else if (buf.scope() == "local") { + bit_vector_[StorageType::kLocal] = true; + } else { + bit_vector_[StorageType::kOthers] = true; + } + } + bool NoAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kShared] || + bit_vector_[StorageType::kLocal] || bit_vector_[StorageType::kOthers]); + } + bool OnlyGlobalAccesses() const { + return !(bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal] || + bit_vector_[StorageType::kOthers]) && + bit_vector_[StorageType::kGlobal]; + } + bool OnlyLocalOrSharedAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kOthers]) && + (bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal]); + } + + private: + std::array<bool, StorageType::kOthers + 1> bit_vector_ = {false}; + }; + AccessMarker read_marker_, write_marker_; + std::pair<AccessMarker, AccessMarker> Analyze_(const Stmt& stmt) { + VisitStmt(stmt); + return std::make_pair(read_marker_, write_marker_); + } + + friend class LocalPadder; +}; + +/*! + * \brief Verify that all local variables are initialized to the same constant expression. 
+ */ +class InitChecker : public StmtVisitor { + private: + void VisitStmt_(const BufferStoreNode* op) final { + // Read the check the RHS values, make sure that they are the same constant for all the + // initialization statements. + CheckInitValue_<IntImmNode>(op->value); + CheckInitValue_<FloatImmNode>(op->value); + return StmtVisitor::VisitStmt_(op); + } + template <typename ImmNodeType> + void CheckInitValue_(const PrimExpr& rhs) { + if (const ImmNodeType* const rhs_val = rhs.as<ImmNodeType>()) { + if (init_constexpr_.defined()) { + if (const ImmNodeType* const init_val = init_constexpr_.as<ImmNodeType>()) { + if (rhs_val->value != init_val->value) { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = true; + init_constexpr_ = rhs; + } + } + } + void operator()(const Stmt& stmt) { + StmtVisitor::operator()(stmt); + if (!init_with_single_constexpr_) { + init_constexpr_ = PrimExpr(); + } + } + + bool init_with_single_constexpr_ = false; + PrimExpr init_constexpr_; + + friend class LocalPadder; +}; + +/*! + * \brief Split a predicate into inlinable and non-inlinable component. + * + * We refer to "inlinable predicate" as + * + * if (predicate) A = ...; + * ↓ + * A = predicate ? ... : init_constexpr; + * + * Note that not all predicates can be inlined. For example, if a predicate is there to guard + * against out-of-boundary accesses to local/shared variables, then it cannot be inlined. + */ +class PredicateInliner : public StmtExprVisitor { + private: + explicit PredicateInliner(const Stmt& body_stmt) : body_stmt_(body_stmt) {} Review Comment: Quick question: do we assume that `body_stmt` is a BufferStore? ########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { + +/*! + * \brief Analyze the read and write accesses of the body statements, used by `LocalPadder`. 
+ */ +class StorageAccessAnalyzer : public StmtExprVisitor { + private: + struct StorageType { + enum { kGlobal = 0, kShared, kLocal, kOthers }; + }; + + void VisitStmt_(const BufferStoreNode* op) final { + write_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode* op) final { + read_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + class AccessMarker { + public: + void SetStorageAccessMarker_(const Buffer& buf) { + if (buf.scope() == "global") { + bit_vector_[StorageType::kGlobal] = true; + } else if (buf.scope() == "shared") { + bit_vector_[StorageType::kShared] = true; + } else if (buf.scope() == "local") { + bit_vector_[StorageType::kLocal] = true; + } else { + bit_vector_[StorageType::kOthers] = true; + } + } + bool NoAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kShared] || + bit_vector_[StorageType::kLocal] || bit_vector_[StorageType::kOthers]); + } + bool OnlyGlobalAccesses() const { + return !(bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal] || + bit_vector_[StorageType::kOthers]) && + bit_vector_[StorageType::kGlobal]; + } + bool OnlyLocalOrSharedAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kOthers]) && + (bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal]); + } + + private: + std::array<bool, StorageType::kOthers + 1> bit_vector_ = {false}; + }; + AccessMarker read_marker_, write_marker_; + std::pair<AccessMarker, AccessMarker> Analyze_(const Stmt& stmt) { Review Comment: ditto. Usually we put underscore after variables but not methods ########## src/tir/transforms/local_pad.cc: ########## @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <tvm/meta_schedule/postproc.h> +#include <tvm/tir/op.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <array> +#include <utility> +#include <vector> + +namespace tvm { +namespace tir { +namespace transform { +namespace { + +/*! + * \brief Analyze the read and write accesses of the body statements, used by `LocalPadder`. 
+ */ +class StorageAccessAnalyzer : public StmtExprVisitor { + private: + struct StorageType { + enum { kGlobal = 0, kShared, kLocal, kOthers }; + }; + + void VisitStmt_(const BufferStoreNode* op) final { + write_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + void VisitExpr_(const BufferLoadNode* op) final { + read_marker_.SetStorageAccessMarker_(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + class AccessMarker { + public: + void SetStorageAccessMarker_(const Buffer& buf) { + if (buf.scope() == "global") { + bit_vector_[StorageType::kGlobal] = true; + } else if (buf.scope() == "shared") { + bit_vector_[StorageType::kShared] = true; + } else if (buf.scope() == "local") { + bit_vector_[StorageType::kLocal] = true; + } else { + bit_vector_[StorageType::kOthers] = true; + } + } + bool NoAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kShared] || + bit_vector_[StorageType::kLocal] || bit_vector_[StorageType::kOthers]); + } + bool OnlyGlobalAccesses() const { + return !(bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal] || + bit_vector_[StorageType::kOthers]) && + bit_vector_[StorageType::kGlobal]; + } + bool OnlyLocalOrSharedAccesses() const { + return !(bit_vector_[StorageType::kGlobal] || bit_vector_[StorageType::kOthers]) && + (bit_vector_[StorageType::kShared] || bit_vector_[StorageType::kLocal]); + } + + private: + std::array<bool, StorageType::kOthers + 1> bit_vector_ = {false}; + }; + AccessMarker read_marker_, write_marker_; + std::pair<AccessMarker, AccessMarker> Analyze_(const Stmt& stmt) { + VisitStmt(stmt); + return std::make_pair(read_marker_, write_marker_); + } + + friend class LocalPadder; +}; + +/*! + * \brief Verify that all local variables are initialized to the same constant expression. 
+ */ +class InitChecker : public StmtVisitor { + private: + void VisitStmt_(const BufferStoreNode* op) final { + // Read the check the RHS values, make sure that they are the same constant for all the + // initialization statements. + CheckInitValue_<IntImmNode>(op->value); + CheckInitValue_<FloatImmNode>(op->value); + return StmtVisitor::VisitStmt_(op); + } + template <typename ImmNodeType> + void CheckInitValue_(const PrimExpr& rhs) { + if (const ImmNodeType* const rhs_val = rhs.as<ImmNodeType>()) { + if (init_constexpr_.defined()) { + if (const ImmNodeType* const init_val = init_constexpr_.as<ImmNodeType>()) { + if (rhs_val->value != init_val->value) { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = false; + } + } else { + init_with_single_constexpr_ = true; + init_constexpr_ = rhs; + } + } + } + void operator()(const Stmt& stmt) { + StmtVisitor::operator()(stmt); + if (!init_with_single_constexpr_) { + init_constexpr_ = PrimExpr(); + } + } + + bool init_with_single_constexpr_ = false; + PrimExpr init_constexpr_; + + friend class LocalPadder; +}; + +/*! + * \brief Split a predicate into inlinable and non-inlinable component. + * + * We refer to "inlinable predicate" as + * + * if (predicate) A = ...; + * ↓ Review Comment: Just nitpicking... can we avoid using non-ASCII characters in comments in the codebase? I remember there used to be some weird compilation issues on Windows... ```suggestion * | ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
