junrushao1994 commented on a change in pull request #9689:
URL: https://github.com/apache/tvm/pull/9689#discussion_r765920384
##########
File path: include/tvm/tir/analysis.h
##########
@@ -26,12 +26,14 @@
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
+#include <tvm/target/se_scope.h>
Review comment:
nit: probably we don't need these two includes?
##########
File path: src/tir/analysis/device_constraint_utils.h
##########
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/device_constraint_utils.cc
+ * \brief Utilities for extracting and applying device-related constraints to
\p PrimFunc
+ * parameters.
+ *
+ * These utilities are used by the \p PlanDevices pass to extract memory (aka
'storage') scope
+ * information from \p PrimFuncs and convert them back into \p SEScope form
w.r.t. the original
+ * Relay type of the \p PrimFunc (ie before flattening of tuple
arguments/results and conversion
+ * to destination-passing style aka DPS).
+ *
+ * A utility is also supplied to go the other way: impose memory scopes on \p
PrimFunc parameters.
+ * However that's still in EXPERIMENTAL form.
+ *
+ * We may extend these utilities to also gather/apply layout information
should we add that to
+ * \p SEScope.
+ */
+
+#ifndef TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_
+#define TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_
+
+#include <tvm/target/se_scope.h>
+#include <tvm/tir/function.h>
+
+namespace tvm {
+namespace tir {
+
+/*
+ * A Relay Function with type:
+ * \code
+ * fn((Tensor[...], Tensor[...]), Tensor[...]) -> (Tensor[...], Tensor[...])
+ * ^ ^ ^ ^ ^
+ * a b c d e
+ * \endcode
+ * will be represented by a TIR PrimFunc in flattened and DPS form with at
least 5 argument a..e.
+ * Each such PrimFunc argument will have a type annotation for a PointerType
to the underlying
+ * tensor's buffer. The PrimFunc may have additional non-pointer arguments,
for example to represent
Review comment:
> The PrimFunc may have additional non-pointer arguments
Yeah that's correct. Another example is that a PrimFunc can take scalars as
inputs too, and those have no counterpart in Relay either.
##########
File path: src/tir/analysis/device_constraint_utils.cc
##########
@@ -0,0 +1,523 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/apply_device_constraints.cc
+ * \brief Applies device-related constraints to \p PrimFunc parameters.
+ *
+ * This is used by the \p PlanDevices pass to flow device-constraints *into*
\p PrimFuncs.
+ *
+ * Currently only applies memory scope constraints into \p Buffer data pointer
+ * storage scopes. Aliased ('matched') buffers take on any scope introduced on
+ * the buffer they alias. However currently does not attempt to flow
constraints into
+ * allocated buffers.
+ */
+
+#include "./device_constraint_utils.h"
+
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/target/se_scope.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+namespace {
+
+/*!
+ * \brief Returns the \p PointerTypeNode for \p buffer, or nullptr if \p
buffer does not describe a
+ * pointer.
+ */
+const PointerTypeNode* PointerInBuffer(const tir::Buffer& buffer) {
+ return buffer->data->type_annotation.defined()
+ ? buffer->data->type_annotation.as<PointerTypeNode>()
+ : nullptr;
+}
+
+/*!
+ * \brief Returns the parameter variable and corresponding buffer at or after
\p
+ * *current_primfunc_param_index in \p prim_func. Will skip over any
non-pointer parameters. This
+ * can be used to find the parameter matching a tensor type in a flattened
Relay function parameter
+ * or result.
+ */
+std::pair<tir::Var, tir::Buffer> FindPointerParam(const tir::PrimFunc&
prim_func,
+ size_t*
current_primfunc_param_index) {
+ while (true) {
+ ICHECK_LT(*current_primfunc_param_index, prim_func->params.size());
+ const tir::Var& param = prim_func->params[*current_primfunc_param_index];
+ auto itr = prim_func->buffer_map.find(param);
+ if (itr == prim_func->buffer_map.end()) {
+ VLOG(2) << "no buffer map entry for '" << param->name_hint << "'";
+ ++*current_primfunc_param_index;
+ continue;
+ }
+ const auto* pointer_type_node = PointerInBuffer((*itr).second);
+ if (pointer_type_node == nullptr) {
+ VLOG(2) << "not a pointer type for '" << param->name_hint << "'";
+ ++*current_primfunc_param_index;
+ continue;
+ }
+ VLOG(2) << "using PrimFunc param '" << param->name_hint << "'";
+ return *itr;
+ }
+}
+
+/*!
+ * \brief Check fails if any parameter at or after \p
*current_primfunc_param_index in \p prim_func
+ * is for a pointer type. This can be used to check all \p prim_func
parameters have been accounted
+ * for when using \p FindPointerParam above.
+ */
+void CheckNoRemainingPointerParams(const tir::PrimFunc& prim_func,
+ size_t* current_primfunc_param_index) {
+ while (*current_primfunc_param_index < prim_func->params.size()) {
+ const tir::Var& param = prim_func->params[*current_primfunc_param_index];
+ auto itr = prim_func->buffer_map.find(param);
+ if (itr == prim_func->buffer_map.end()) {
+ VLOG(1) << "no buffer map entry for '" << param->name_hint << "'";
+ ++*current_primfunc_param_index;
+ continue;
+ }
+ const auto* pointer_type_node = PointerInBuffer((*itr).second);
+ ICHECK(pointer_type_node == nullptr);
+ ++*current_primfunc_param_index;
+ }
+}
+
+/*!
+ * \brief Returns the (consistent) constraint to use for a Relay parameter of
\p type,
+ * using \p prim_func parameters at or after \p *current_primfunc_param_index.
Currently
+ * only memory scope is extracted. Fails if constraints are not consistent, ie
\p type is a tuple
+ * type and the \p prim_func is attempting to map different fields of that
tuple to different memory
+ * scopes. Returns the fully unconstrained \p SEScope if no memory scopes
constraints arise from
+ * the \p prim_func, ie all storage scope strings in pointer types are empty.
+ */
+SEScope ConsistentParamConstraint(const tir::PrimFunc& prim_func, const Type&
type,
+ size_t* current_primfunc_param_index) {
+ std::string memory_scope; // default empty => no constraint
+ for (size_t i = 0; i < relay::FlattenTupleType(type).size(); ++i) {
+ std::pair<tir::Var, tir::Buffer> kv = FindPointerParam(prim_func,
current_primfunc_param_index);
+ const tir::Buffer& buffer = kv.second;
+ const auto* pointer_type_node =
buffer->data->type_annotation.as<PointerTypeNode>();
+ const MemoryScope& buffer_memory_scope = pointer_type_node->storage_scope;
+ if (memory_scope.empty()) {
+ memory_scope = buffer_memory_scope;
+ } else if (buffer_memory_scope.empty()) {
+ // No constraint.
+ } else {
+ // Tuples must be homogenous on their SEScope and thus memory scope.
+ ICHECK_EQ(buffer_memory_scope, memory_scope);
+ }
+ ++*current_primfunc_param_index;
+ }
+ return SEScope::ForMemoryScope(memory_scope);
+}
+
+/*!
+ * \brief Insert into param_constraints an entry for each parameter of \p
prim_func starting from
+ * \p *current_primfunc_param_index for the flattened form of a Rleay
parameters of \p type. Each
+ * entry maps to \p se_scope.
+ */
+void InsertParamConstraints(const tir::PrimFunc& prim_func, const Type& type,
+ const SEScope& se_scope, size_t*
current_primfunc_param_index,
+ std::unordered_map<const tir::VarNode*, SEScope>*
param_constraints) {
+ for (size_t i = 0; i < relay::FlattenTupleType(type).size(); ++i) {
+ std::pair<tir::Var, tir::Buffer> kv = FindPointerParam(prim_func,
current_primfunc_param_index);
+ param_constraints->emplace(kv.first.get(), se_scope);
+ ++*current_primfunc_param_index;
+ }
+}
+
+/*!
+ * \brief Apply the memory scope constraints to the \p Buffers and data \p
Vars of a \p PrimFunc.
+ *
+ * All definitional occurrences of buffer Vars are rewritten to capture memory
scopes in their
+ * PointerTypes:
+ * - Buffer::data (if the buffer itself is a definitional occurrence)
+ * - AllocateNode::buffer_var
+ * - FUTURE: LetStmtNode::var if aliasing a buffer data var.
+ *
+ * All referential occurrences of buffer Vars are replaced with their new
definitions:
+ * - LoadNode::buffer_var
+ * - StoreNode::buffer_var
+ *
+ * Similarly all definitional occurrences of Buffers are rewritten to account
for any new memory
+ * scopes:
+ * - PrimFuncNode::buffer_map keys.
+ * - BlockNode::match_buffers.buffer
+ * - FUTURE: BlockNode::alloc_buffers?
+ *
+ * And all referential occurrences of Buffers are replaced with their new
definitions:
+ * - BufferLoadNode::buffer
+ * - BufferStoreNode::buffer
+ * - BufferRealizeNode::buffer
+ * - PrefetchNode::buffer
+ * - BufferRegionNode:buffer
+ * - BlockNode.match_buffers.source.buffer
+ * - BlockNode::{reads, writes}.buffer
+ *
+ * CAUTION: We assume strict sharing of Buffer objects and do not attempt to
rewrite the bodies
+ * of referential buffers.
+ *
+ * CAUTION: EXPERIMENTAL: We don't yet account for all buffers and pointer
types.
+ */
+class ApplyDeviceConstraintsMutator : public StmtExprMutator {
+ public:
+ ApplyDeviceConstraintsMutator() = default;
+
+ /*!
+ * \brief Returns \p prim_func written to capture the memory scope
constraints in \p
+ * param_constraints for each pointer \p prim_func parameter. Returns \p
prim_func unchanged if no
+ * memory scopes needed to change.
+ */
+ PrimFunc Rewrite(const PrimFunc& prim_func, const FuncType& relay_func_type,
+ const Array<SEScope>& arg_and_result_se_scopes) {
+ size_t current_primfunc_param_index = 0;
+ std::unordered_map<const tir::VarNode*, SEScope> param_constraints;
+
+ // For each Relay function parameter...
+ for (size_t i = 0; i < relay_func_type->arg_types.size(); ++i) {
+ const Type& param_type = relay_func_type->arg_types[i];
+ const SEScope& param_se_scope = arg_and_result_se_scopes[i];
+ InsertParamConstraints(prim_func, param_type, param_se_scope,
&current_primfunc_param_index,
+ &param_constraints);
+ }
+
+ // For the Relay function result...
+ const Type& ret_type = relay_func_type->ret_type;
+ const SEScope& ret_se_scope = arg_and_result_se_scopes.back();
+ InsertParamConstraints(prim_func, ret_type, ret_se_scope,
&current_primfunc_param_index,
+ &param_constraints);
+
+ // Make sure we accounted for all prim_func parameters.
+ CheckNoRemainingPointerParams(prim_func, &current_primfunc_param_index);
+
+ // Start with a copy of the current prim_func buffer map.
+ Map<Var, Buffer> new_buffer_map(prim_func->buffer_map.begin(),
prim_func->buffer_map.end());
+ bool any_change = false;
+
+ // For each constrained parameter...
+ for (const auto& kv : param_constraints) {
+ const tir::Var param = GetRef<tir::Var>(kv.first);
+ const SEScope& se_scope = kv.second;
+ const tir::Buffer& buffer = prim_func->buffer_map[param];
+ // Rewrite the buffer to account for constraint.
+ const Buffer new_buffer = RewriteBuffer(buffer, se_scope);
+ if (!new_buffer.same_as(buffer)) {
+ any_change = true;
+ }
+ new_buffer_map.Set(param, new_buffer);
+ }
+ // Make sure we have accounted for all prim_func parameters.
+ CheckNoRemainingPointerParams(prim_func, &current_primfunc_param_index);
+
+ // Apply data variable and buffer substitutions to the prim_func body.
These will have been
+ // accumulated from processing the parameters above.
+ Stmt new_body = VisitStmt(prim_func->body);
+ if (!new_body.same_as(prim_func->body)) {
+ any_change = true;
+ }
+
+ // We are done with the substitutions.
+ var_subst_.clear();
+ buffer_subst_.clear();
+
+ if (any_change) {
+ return PrimFunc(prim_func->params, std::move(new_body),
prim_func->ret_type,
+ std::move(new_buffer_map), prim_func->attrs,
prim_func->span);
+ } else {
+ return prim_func;
+ }
+ }
+
+ private:
+ PrimExpr VisitExpr_(const VarNode* var_node) final { return Subst(var_node);
}
+
+ PrimExpr VisitExpr_(const LoadNode* load_node) final {
+ Load new_load = Downcast<Load>(StmtExprMutator::VisitExpr_(load_node));
+ Var new_buffer_var = Subst(new_load->buffer_var.get());
+ if (!new_buffer_var.same_as(new_load->buffer_var)) {
+ return Load(load_node->dtype, new_buffer_var, load_node->index,
load_node->predicate);
+ }
+ return new_load;
+ }
+
+ PrimExpr VisitExpr_(const BufferLoadNode* buffer_load_node) final {
+ BufferLoad new_buffer_load =
+ Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(buffer_load_node));
+ Buffer new_buffer = Subst(new_buffer_load->buffer.get());
+ if (!new_buffer.same_as(new_buffer_load->buffer)) {
+ return BufferLoad(new_buffer, new_buffer_load->indices,
new_buffer_load->span);
+ }
+ return new_buffer_load;
+ }
+
+ Stmt VisitStmt_(const LetStmtNode* let_stmt_node) final {
+ // TODO(mbs): If the let-bound var is aliasing an existing buffer data var
we need to
+ // rewrite it.
+ return StmtExprMutator::VisitStmt_(let_stmt_node);
+ }
+
+ Stmt VisitStmt_(const AttrStmtNode* attr_stmt_node) final {
+ AttrStmt new_attr_stmt =
Downcast<AttrStmt>(StmtExprMutator::VisitStmt_(attr_stmt_node));
+ // remap node if a var
+ if (const auto* var_node = new_attr_stmt->node.as<VarNode>()) {
+ Var new_var = Subst(var_node);
+ if (!new_var.same_as(new_attr_stmt->node)) {
+ return AttrStmt(new_var, new_attr_stmt->attr_key, new_attr_stmt->value,
+ new_attr_stmt->body);
+ }
+ }
+ return new_attr_stmt;
+ }
+
+ // ForNode default ok since loop_var never of PointerType
+
+ // WhileNode default ok
+
+ Stmt VisitStmt_(const AllocateNode* allocate_node) final {
+ // TODO(mbs): What memory scope should we assign to the new pointer?
+ return StmtExprMutator::VisitStmt_(allocate_node);
+ }
+
+ Stmt VisitStmt_(const StoreNode* store_node) final {
+ Store new_store = Downcast<Store>(StmtExprMutator::VisitStmt_(store_node));
+ Var new_buffer_var = Subst(new_store->buffer_var.get());
+ if (!new_buffer_var.same_as(new_store->buffer_var)) {
+ Store(new_buffer_var, new_store->value, new_store->index,
new_store->predicate);
+ }
+ return new_store;
+ }
+
+ Stmt VisitStmt_(const BufferStoreNode* buffer_store_node) final {
+ BufferStore new_buffer_store =
+ Downcast<BufferStore>(StmtExprMutator::VisitStmt_(buffer_store_node));
+ Buffer new_buffer = Subst(new_buffer_store->buffer.get());
+ if (!new_buffer.same_as(new_buffer_store->buffer)) {
+ return BufferStore(new_buffer, new_buffer_store->value,
new_buffer_store->indices,
+ new_buffer_store->span);
+ }
+ return new_buffer_store;
+ }
+
+ Stmt VisitStmt_(const BufferRealizeNode* buffer_realize_node) final {
+ BufferRealize new_buffer_realize =
+
Downcast<BufferRealize>(StmtExprMutator::VisitStmt_(buffer_realize_node));
+ Buffer new_buffer = Subst(new_buffer_realize->buffer.get());
+ if (!new_buffer.same_as(new_buffer_realize->buffer)) {
+ return BufferRealize(new_buffer, new_buffer_realize->bounds,
new_buffer_realize->condition,
+ new_buffer_realize->body, new_buffer_realize->span);
+ }
+ return new_buffer_realize;
+ }
+
+ // IfThenElseNode default ok
+ // AssertStmtNode default ok
+ // ProducerStoreNode default ok (though does not visit producer)
+ // ProducerRealizeNode default ok (though does not visit producer)
+
+ Stmt VisitStmt_(const PrefetchNode* prefetch_node) final {
+ Prefetch new_prefetch =
Downcast<Prefetch>(StmtExprMutator::VisitStmt_(prefetch_node));
+ Buffer new_buffer = Subst(new_prefetch->buffer.get());
+ if (!new_buffer.same_as(new_prefetch->buffer)) {
+ return Prefetch(new_buffer, prefetch_node->bounds, prefetch_node->span);
+ }
+ return new_prefetch;
+ }
+
+ // SeqStmtNode default ok
+ // EvaluateNode default ok
+
+ BufferRegion VisitItem(const BufferRegionNode* buffer_region_node) {
+ Buffer new_buffer = Subst(buffer_region_node->buffer.get());
+ if (!new_buffer.same_as(buffer_region_node->buffer)) {
+ return BufferRegion(new_buffer, buffer_region_node->region);
+ }
+ return GetRef<BufferRegion>(buffer_region_node);
+ }
+
+ MatchBufferRegion VisitItem(const MatchBufferRegionNode*
match_buffer_region_node) {
+ // The source field has a referential occurrence of the buffer. Apply the
buffer substitution
+ // to that.
+ BufferRegion new_source =
VisitItem(match_buffer_region_node->source.get());
+ // The buffer field however is a definitional occurrence, aliased on top
of the source.
+ // Transfer any memory scope from the source to the destination.
+ Optional<SEScope> opt_se_scope = GetBufferConstraint(new_source->buffer);
+ tir::Buffer new_buffer;
+ if (opt_se_scope.defined()) {
+ new_buffer = RewriteBuffer(match_buffer_region_node->buffer,
opt_se_scope.value());
+ } else {
+ new_buffer = match_buffer_region_node->buffer;
+ }
+ if (!new_buffer.same_as(match_buffer_region_node->buffer) ||
+ !new_source.same_as(match_buffer_region_node->source)) {
+ return MatchBufferRegion(new_buffer, new_source);
+ }
+ return GetRef<MatchBufferRegion>(match_buffer_region_node);
+ }
+
+ template <typename T>
+ Array<T> VisitItems(Array<T> items) {
Review comment:
What's the diff between this method and Array's MutateByApply API?
##########
File path: python/tvm/tir/analysis/analysis.py
##########
@@ -196,3 +197,70 @@ def detect_buffer_access_lca(func: PrimFunc) ->
Dict[Buffer, Stmt]:
Map from buffer to the LCA of all access to it.
"""
return _ffi_api.detect_buffer_access_lca(func) # type: ignore # pylint:
disable=no-member
+
+
+# NOTE: relay_func_type in the following two functions should be
relay.FuncType however that would
+# introduce a cycling dependency. We make do with Object.
+
+
+def get_prim_func_arg_and_result_memory_constraints(
+ func: PrimFunc, relay_func_type: Object
+) -> List[AnyStr]:
Review comment:
QQ: Why did we use List[AnyStr] instead of List[str]?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]