This is an automated email from the ASF dual-hosted git repository.
sslyu pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new a763b22119 [Unity][Transform] Replace eligible operators with in-place
versions in dataflow blocks (#16129)
a763b22119 is described below
commit a763b2211936d319d2177bc2a06e22f5284f4bc4
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Wed Jan 17 17:05:54 2024 -0500
[Unity][Transform] Replace eligible operators with in-place versions in
dataflow blocks (#16129)
* Implement basic analyses
* Fix typo
* Add tests for analyses
* Include in-place analysis
* Return the lists instead
* Update python binding
* No need to assume *pure* functions capture all values ever passed to
them. Also use pointers instead of non-const refs
* Improve handling of tuples in mystery call case
* Corrections to inplace checking
* Add test case for mystery value
* typo
* Add inplace test case, correct minor issues
* Consider also using larger tensors to store smaller ones
* Check call args against any possible target sinfo, also check tensor
sinfo dtype
* Handle output vars and tuple get item
* Add legalization for in-place functions
* No need to update the NoAlias attribute, actually
* Fix TIR transformation, add tests for inline transformation
* Only find candidates from supported ops and list _all_ feasible argument
indices
* Implement basic transformation pass
* Use a module pass so wider changes are visible, reorganize
* Have an end-to-end test case for the in-place transformation
* Rebase fixes and use GetBoundValue instead of reimplementing it
* Let's just use 'inplace' everywhere
* Reorganize code and add more documentation
* Include proper bounds check
* Trailing whitespace
* Need a trailing newline
* Remove unused imports
* Add docstrings for exposed inner functions
* Reformat docstrings to appease the linter
* C++ stylistic changes
* Treat args as mystery values by default, do not allow overwriting
* Formatting
* Clarify pass description
* Add check to ensure that testing functions are used only in a testing
environment
* Improve size match check readability per review suggestions
* Improve the size match check per review suggestions (use PrimExprs)
* Treat non-dataflow vars as living past the end of the block in all cases
* Clarify notion of size in comment
* Remove commented-out code
* Assume any op that returns a tuple is returning a fresh one (exceptions
can be noted later)
* Add full structural equality check in large test case
* Fix parser roundtripping bug with call_tir_inplace
* Refactor tests to ensure maps are nonempty
* Use .empty() where it's more reasonable
* linting changes
* Flipped the check by accident
* Remove debug print
* Factor out data structure for representing matches and match opportunities
* Style fix
* Use the analyzer to handle dynamic cases too
* Whitespace
* Use BlockBuilder APIs more to avoid re-normalizing
* Check for expired vars at start of loop so that the use of continue does
not skip that step
---------
Co-authored-by: Eric Lunderberg <[email protected]>
---
include/tvm/relax/transform.h | 10 +
python/tvm/relax/__init__.py | 8 +-
python/tvm/relax/testing/transform.py | 98 +-
python/tvm/relax/transform/__init__.py | 1 +
python/tvm/relax/transform/transform.py | 16 +
src/relax/transform/dataflow_inplace.cc | 1040 ++++++++++++++++++++
src/script/printer/relax/call.cc | 19 +-
tests/python/relax/test_dataflow_inplace.py | 644 ++++++++++++
tests/python/relax/test_tvmscript_parser.py | 36 +
tests/python/relax/test_tvmscript_printer_relax.py | 25 +
10 files changed, 1894 insertions(+), 3 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 5376d99ee1..efe30e5cbb 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -572,6 +572,16 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2);
*/
TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
+/*!
+ * \brief Pass that changes calls to operators that can be done in-place
+ * (generally, these are elementwise operations) in dataflow blocks into
in-place implementations.
+ * Supported operators will be replaced by calls to `call_tir_inplace` that
invoke in-place
+ * PrimFunc implementations of those operators (which are based on the
legalizations of those
+ * operators).
+ * \return The pass.
+ */
+TVM_DLL Pass DataflowUseInplaceCalls();
+
/*!
* \brief Automatic mixed precision pass. Currently the pass assumes the input
module to be fp32
* only, and will automatically cast fp32 to fp16 for certain ops.
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index 5bc0d6c56e..23cfaf2935 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -63,7 +63,13 @@ from .ty import (
from .exec_builder import ExecBuilder
# Operator
-from .op.base import call_tir, call_pure_packed, call_dps_packed,
call_tir_with_grad
+from .op.base import (
+ call_tir,
+ call_tir_inplace,
+ call_pure_packed,
+ call_dps_packed,
+ call_tir_with_grad,
+)
# BlockBuilder
from .block_builder import BlockBuilder
diff --git a/python/tvm/relax/testing/transform.py
b/python/tvm/relax/testing/transform.py
index ccae38a138..42dbd37d29 100644
--- a/python/tvm/relax/testing/transform.py
+++ b/python/tvm/relax/testing/transform.py
@@ -17,14 +17,18 @@
# pylint: disable=unused-argument, invalid-name, no-else-return,
abstract-method, arguments-differ
"""Relax transformation passes for testing"""
+import logging
+import os
+from typing import Dict, List, Set, Tuple
import tvm
from tvm import ir, relax
from tvm.ir import transform
from tvm.ir.module import IRModule
from tvm.ir.transform import PassContext
from tvm.relax import PyExprMutator
-from tvm.relax.expr import Call
+from tvm.relax.expr import Call, DataflowBlock, Var
from tvm.relay.backend.te_compiler import select_implementation
+from tvm.runtime.object import Object
from tvm.target import Target
@@ -128,3 +132,95 @@ class LowerWithRelayOpStrategyPass(transform.Pass):
def ApplyEmptyCppMutator() -> tvm.ir.transform.Pass:
packed_func =
tvm.get_global_func("relax.testing.transform.ApplyEmptyCppMutator")
return packed_func()
+
+
+def dataflow_liveness_analysis(block: DataflowBlock) -> Dict[Var, Tuple[int,
int]]:
+ """
+ Inner function for the dataflow inplace transformation exposed for testing.
+ """
+ if "PYTEST_CURRENT_TEST" not in os.environ:
+ logging.warning("The function dataflow_liveness_analysis is exposed
for testing only.")
+
+ live_ranges =
tvm.get_global_func("relax.testing.transform.DataflowLivenessAnalysis")(
+ block
+ ) # type: ignore
+ ret = {}
+ for var, live_range in live_ranges.items():
+ ret[var] = tuple(live_range)
+ return ret # type: ignore
+
+
+def dataflow_alias_analysis(
+ block: DataflowBlock, inputs: List[Var]
+) -> Tuple[Dict[Var, Set[int]], Dict[int, List[Set[int]]]]:
+ """
+ Inner function for the dataflow inplace transformation exposed for testing.
+ """
+ if "PYTEST_CURRENT_TEST" not in os.environ:
+ logging.warning("The function dataflow_alias_analysis is exposed for
testing only.")
+
+ alias_sets, tuple_map =
tvm.get_global_func("relax.testing.transform.DataflowAliasAnalysis")(
+ block,
+ inputs,
+ ) # type: ignore
+ res_alias_sets = {}
+ res_tuple_map = {}
+ for var, alias_set in alias_sets.items():
+ res_alias_sets[var] = set(alias_set)
+ for idx, elem_alias_sets in tuple_map.items():
+ res_tuple_map[idx] = [set(alias_set) for alias_set in elem_alias_sets]
+ return res_alias_sets, res_tuple_map # type: ignore
+
+
+@tvm._ffi.register_object("relax.transform.InplaceOpportunity")
+class InplaceOpportunity(Object):
+ """
+ Represents an opportunity to make a binding in-place. Exposed only for
testing;
+ the constructor is not exposed.
+
+ Parameters:
+ -----------
+ binding_idx: int
+ Index of the binding within its block
+
+ arg_idxs: List[int]
+ Indices of arguments that are eligible to be used as in-place targets.
+ """
+
+ def __init__(self, _binding_idx, _arg_idxs):
+ raise NotImplementedError("Constructor for InplaceOpportunity not
exposed!")
+
+
+def dataflow_inplace_analysis(
+ block: DataflowBlock, inputs: List[Var], mod: IRModule
+) -> Tuple[List[Tuple[int, Set[int]]], List[Tuple[int, Set[int]]]]:
+ """
+ Inner function for the dataflow inplace transformation exposed for testing.
+ """
+ if "PYTEST_CURRENT_TEST" not in os.environ:
+ logging.warning("The function dataflow_inplace_analysis is exposed for
testing only.")
+ index_lists =
tvm.get_global_func("relax.testing.transform.DataflowInplaceAnalysis")(
+ block, inputs, mod
+ ) # type: ignore
+
+ def convert(opp_list):
+ return list(map(lambda opp: (int(opp.binding_idx), set(map(int,
opp.arg_idxs))), opp_list))
+
+ return (convert(index_lists[0]), convert(index_lists[1])) # type: ignore
+
+
+def dataflow_single_inplace_call(
+ mod: IRModule, call: Call, inplace_indices: List[int]
+) -> Tuple[Call, IRModule]:
+ """
+ Inner function for the dataflow inplace transformation exposed for testing.
+ """
+ if "PYTEST_CURRENT_TEST" not in os.environ:
+ logging.warning("The function dataflow_single_inplace_call is exposed
for testing only.")
+
+ ret = tvm.get_global_func("relax.testing.transform.SingleInplaceCall")(
+ mod,
+ call,
+ inplace_indices,
+ ) # type: ignore
+ return (ret[0], ret[1]) # type: ignore
diff --git a/python/tvm/relax/transform/__init__.py
b/python/tvm/relax/transform/__init__.py
index 9c2d02bf36..2cb6cd6b32 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -31,6 +31,7 @@ from .transform import (
ConvertLayout,
ConvertToDataflow,
DataflowBlockPass,
+ DataflowUseInplaceCalls,
DeadCodeElimination,
DecomposeOpsForInference,
DecomposeOpsForTraining,
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index c49e1fd13c..268210549e 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -252,6 +252,22 @@ def RemovePurityChecking() -> tvm.ir.transform.Pass:
return _ffi_api.RemovePurityChecking() # type: ignore
+def DataflowUseInplaceCalls() -> tvm.ir.transform.Pass:
+ """
+ Pass that changes calls to operators that can be done in-place
+ (generally, these are elementwise operations) in dataflow blocks into in-place implementations.
+ Supported operators will be replaced by calls to `call_tir_inplace` that
invoke
+ in-place PrimFunc implementations of those operators (which are based on
the legalizations of
+ those operators).
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ The pass
+ """
+ return _ffi_api.DataflowUseInplaceCalls()
+
+
def LambdaLift() -> tvm.ir.transform.Pass:
"""A pass that lifts local functions into global.
diff --git a/src/relax/transform/dataflow_inplace.cc
b/src/relax/transform/dataflow_inplace.cc
new file mode 100644
index 0000000000..755c5dbab4
--- /dev/null
+++ b/src/relax/transform/dataflow_inplace.cc
@@ -0,0 +1,1040 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ *
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/dataflow_inplace.cc
+ * \brief Pass that converts eligible operator calls in dataflow blocks
+ * into in-place versions.
+ */
+
+#include <tvm/ir/transform.h>
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/attrs/op.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/utils.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+// Perform liveness analysis on a dataflow block, returning a map of vars to
+// pairs of indices (the liveness interval, from the starting index to the end
index).
+// A starting index of -1 means the var is defined before the block starts and
an end index
+// of block->bindings.size() (one past the last index) means it is live after
the block ends.
+std::unordered_map<Var, std::pair<int, int>, ObjectPtrHash, ObjectPtrEqual>
AnalyzeLiveness(
+ const DataflowBlock& block) {
+ std::unordered_map<Var, std::pair<int, int>, ObjectPtrHash, ObjectPtrEqual>
ret;
+ for (int i = block->bindings.size() - 1; i >= 0; i--) {
+ Binding b = block->bindings[i];
+ Var defined_var = b->var;
+ Expr value = GetBoundValue(b);
+ Array<Var> used_vars;
+ // for a function literal, we consider only the free vars
+ // (those captured from the outer scope)
+ if (value.as<FunctionNode>()) {
+ used_vars = FreeVars(value);
+ } else if (value.as<TupleGetItemNode>()) {
+ // Special case: we do not consider a tuple index to be a "use."
+ // This is a bit of a hack but allows operations that
+ // create tuples to be done in-place (otherwise, any index of the tuple
+ // would be considered a use and so the tuple would be live later).
+ // Hence we keep the array empty.
+ } else {
+ used_vars = AllVars(value);
+ }
+
+ for (auto var : used_vars) {
+ int range_end = i;
+ // if the var is not a dataflow var, then it is live
+ // after the block (we are not checking later blocks)
+ if (!var.as<DataflowVarNode>()) {
+ range_end = block->bindings.size();
+ }
+ if (!ret.count(var)) {
+ ret[var] = {-1, range_end};
+ }
+ }
+
+ if (!ret.count(defined_var)) {
+ // if it's an output, then it lives past the end of the block
+ if (!defined_var.as<DataflowVarNode>()) {
+ ret[defined_var] = {i, block->bindings.size()};
+ } else {
+ // otherwise, it's live only here
+ ret[defined_var] = {i, i};
+ }
+ } else {
+ // this means the var is used later but we encountered its definition now
+ auto last_range = ret[defined_var];
+ CHECK_EQ(last_range.first, -1);
+ std::pair<int, int> new_range = {i, last_range.second};
+ ret[defined_var] = new_range;
+ }
+ }
+ return ret;
+}
+
+class AliasAnalyzer {
+ public:
+ AliasAnalyzer() : alias_map_(), tuple_map_(), mem_idx_(0) {}
+
+ // The analysis returns a map of vars to memory locations that it *could*
map to
+ // (any unique allocation = one memory location), plus a map of memory
locations
+ // that correspond to tuples (this maps to sets of memory locations for each
tuple element).
+ // Note: inputs are values that should be assumed not to be aliased and are
therefore
+ // (in the case of in-place ops) safe to overwrite. This may not be true of
function args.
+ std::pair<std::unordered_map<Var, std::unordered_set<int>, ObjectPtrHash,
ObjectPtrEqual>,
+ std::unordered_map<int, std::vector<std::unordered_set<int>>>>
+ Analyze(const DataflowBlock& block, const Array<Var>& inputs) {
+ for (auto input : inputs) {
+ int curr_idx = get_fresh_idx();
+ alias_map_[input] = {curr_idx};
+ if (auto* tup_info = GetStructInfoAs<TupleStructInfoNode>(input)) {
+ InsertFreshTuple(curr_idx, tup_info);
+ }
+ }
+
+ for (const Binding& binding : block->bindings) {
+ Var current_var = binding->var;
+ Expr value = GetBoundValue(binding);
+ alias_map_[current_var] = GetAliasSet(value, current_var);
+ }
+
+ return {alias_map_, tuple_map_};
+ }
+
+ private:
+ int get_fresh_idx() {
+ int ret = mem_idx_;
+ mem_idx_++;
+ return ret;
+ }
+
+ // Fresh tuple = each element is assumed to be a unique allocation
+ void InsertFreshTuple(int tup_idx, const TupleStructInfoNode* tup_info) {
+ std::vector<std::unordered_set<int>> tuple_set;
+ for (int i = 0; i < static_cast<int>(tup_info->fields.size()); i++) {
+ int curr_field = get_fresh_idx();
+ tuple_set.push_back({curr_field});
+ if (auto* nested_tup_info =
tup_info->fields[i].as<TupleStructInfoNode>()) {
+ InsertFreshTuple(curr_field, nested_tup_info);
+ }
+ }
+ tuple_map_[tup_idx] = tuple_set;
+ }
+
+ // given a tuple index, add the given memory location indices to each
component's
+ // alias set
+ void UpdateTupleComponents(int tup_idx, const std::unordered_set<int>&
insert_idxs) {
+ if (tuple_map_.count(tup_idx)) {
+ auto tuple_comps = tuple_map_[tup_idx];
+ for (size_t i = 0; i < tuple_comps.size(); i++) {
+ auto comp_set = tuple_comps[i];
+
+ // if a member is a tuple, update its components as well
+ for (int member : comp_set) {
+ if (tuple_map_.count(member)) {
+ UpdateTupleComponents(member, insert_idxs);
+ }
+ }
+
+ // update after iterating to avoid iterating over the inserted elements
+ tuple_map_[tup_idx][i].insert(insert_idxs.begin(), insert_idxs.end());
+ }
+ }
+ }
+
+ // capture the given index and also its tuple components (including
recursively)
+ // if they exist
+ void AddCapturedIndices(std::unordered_set<int>* captured_set, int idx) {
+ captured_set->insert(idx);
+ if (tuple_map_.count(idx)) {
+ for (auto comp_set : tuple_map_[idx]) {
+ for (auto tup_comp_idx : comp_set) {
+ AddCapturedIndices(captured_set, tup_comp_idx);
+ }
+ }
+ }
+ }
+
+ // Conservative extremely pessimistic assumption:
+ // assume that the result of a non-op call can be aliased to any argument
+ // or that it could be a newly allocated value.
+ // For tuples, assume all members are aliased. Yeah, it's bad.
+ // (`skip_first_arg` is for handling call_pure_packed, where the first arg is an ExternFunc
+ // that we should ignore)
+ std::unordered_set<int> HandleMysteryCall(const CallNode* call_node, const
Var& bound_var,
+ bool skip_first_arg = false) {
+ // the result may or may not be newly allocated
+ std::unordered_set<int> ret;
+ int res_idx = get_fresh_idx();
+ // the result may be a tuple
+ if (auto* tup_info_node = GetStructInfoAs<TupleStructInfoNode>(bound_var))
{
+ InsertFreshTuple(res_idx, tup_info_node);
+ }
+ AddCapturedIndices(&ret, res_idx);
+
+ for (size_t i = (skip_first_arg) ? 1 : 0; i < call_node->args.size(); i++)
{
+ auto arg = call_node->args[i];
+ auto arg_alias_set = GetAliasSet(arg, bound_var);
+ for (int alias_idx : arg_alias_set) {
+ AddCapturedIndices(&ret, alias_idx);
+ }
+ }
+ // if the result is a tuple, the components can also potentially be
aliased to any arg
+ // or, in fact, to each other
+ UpdateTupleComponents(res_idx, ret);
+ return ret;
+ }
+
+ // given the expression value, return the set of memory locations
corresponding to it
+ // (the var the expression is being bound to is needed for struct info)
+ std::unordered_set<int> GetAliasSet(const Expr& value, const Var& bound_var)
{
+ std::unordered_set<int> ret;
+
+ // cases for value:
+ // constant: it's a fresh index
+ // var: look up in alias map (-1 if not present)
+ // op call: assume it's fresh (may need to make list of exceptions)
+ // tuple: fresh entry in tuple index, recurse to determine indices for
values
+ // function/packed call: chaos reigns, alias with any other argument
+ // (if tuple is passed, assume also aliased with all members of the
tuple)
+ // tuple index: -1 if tuple is not in tuple map, otherwise look up
corresponding entry
+ // function constant: fresh index (TODO: handle in more detail if this is a case we need)
+ // prim value: fresh index
+ // if node: should not happen inside a dataflow block
+ if (value.as<ConstantNode>() || value.as<PrimValueNode>() ||
value.as<FunctionNode>()) {
+ // TODO(@slyubomirsky): We will probably want special handling for
closures
+ ret.insert(get_fresh_idx());
+ } else if (auto* target_var_node = value.as<VarNode>()) {
+ auto target_var = GetRef<Var>(target_var_node);
+ if (alias_map_.count(target_var)) {
+ ret.insert(alias_map_[target_var].begin(),
alias_map_[target_var].end());
+ } else {
+ ret.insert(-1);
+ }
+ } else if (auto* target_tuple = value.as<TupleNode>()) {
+ // fresh idx but we update the tuple map
+ int tup_idx = get_fresh_idx();
+ ret.insert(tup_idx);
+ std::vector<std::unordered_set<int>> new_tuple_map;
+ for (auto field : target_tuple->fields) {
+ new_tuple_map.push_back(GetAliasSet(field, bound_var));
+ }
+ tuple_map_[tup_idx] = new_tuple_map;
+ } else if (auto* target_tgi = value.as<TupleGetItemNode>()) {
+ std::unordered_set<int> tuple_set = GetAliasSet(target_tgi->tuple,
bound_var);
+ // if -1 is a member of the tuple set, then we have to assume the result
is -1
+ if (tuple_set.count(-1)) {
+ ret.insert(-1);
+ } else {
+ // otherwise, consider all members that are tuples of appropriate size
and index into them
+ // (this is safe because the type system will ensure we're not
indexing into a tuple
+ // of the wrong size)
+ for (int member : tuple_set) {
+ if (tuple_map_.count(member) &&
+ static_cast<int>(tuple_map_[member].size()) > target_tgi->index)
{
+ auto member_set = tuple_map_[member][target_tgi->index];
+ ret.insert(member_set.begin(), member_set.end());
+ }
+ }
+ }
+ } else if (auto* call_node = value.as<CallNode>()) {
+ if (auto* op_node = call_node->op.as<OpNode>()) {
+ // call_pure_packed: treat as non-op call
+ if (op_node->name == "relax.call_pure_packed") {
+ return HandleMysteryCall(call_node, bound_var, true);
+ } else if (op_node->name == "relax.call_tir") {
+ // call_tir: can potentially return a tuple
+ if (auto* tuple_struct_info =
call_node->sinfo_args[0].as<TupleStructInfoNode>()) {
+ int tup_idx = get_fresh_idx();
+ ret.insert(tup_idx);
+ InsertFreshTuple(tup_idx, tuple_struct_info);
+ } else {
+ ret.insert(get_fresh_idx());
+ }
+ } else {
+ // We are assuming most op calls return fresh values.
+ // We may have to track more exceptions
+
+ // If the returned value is a tuple, we'll assume it's a fresh tuple
+ // (there may be exceptions to this too)
+ if (auto* tup_info =
GetStructInfoAs<TupleStructInfoNode>(bound_var)) {
+ int tup_idx = get_fresh_idx();
+ ret.insert(tup_idx);
+ InsertFreshTuple(tup_idx, tup_info);
+ return ret;
+ }
+ ret.insert(get_fresh_idx());
+ }
+ } else {
+ // assume any non-op call can be extremely dangerous and do anything
+ return HandleMysteryCall(call_node, bound_var);
+ }
+ }
+
+ return ret;
+ }
+
+ std::unordered_map<Var, std::unordered_set<int>, ObjectPtrHash,
ObjectPtrEqual> alias_map_;
+ std::unordered_map<int, std::vector<std::unordered_set<int>>> tuple_map_;
+ int mem_idx_;
+};
+
+// given a shape, return the number of elements corresponding to it (product
of elements)
+PrimExpr NumElements(const ShapeExpr& shape) {
+ PrimExpr ret = IntImm(DataType::Int(64), 1);
+ for (auto dim : shape->values) {
+ ret *= dim;
+ }
+ return ret;
+}
+
+// Given the struct info of the result, return any struct info nested in it
+// that is eligible to be used for in-place computations (tensors are eligible
+// only if their dtype is known and their shape is a ShapeExpr; a tuple is a
+// candidate if at least one member is, and eligible members are also returned individually)
+std::unordered_set<StructInfo, ObjectPtrHash, ObjectPtrEqual>
GatherCandidateSinfo(
+ const StructInfo& result_sinfo) {
+ if (auto* tensor_info = result_sinfo.as<TensorStructInfoNode>()) {
+ // don't consider void dtype (don't know the size at compile time)
+ if (tensor_info->dtype.is_void()) {
+ return {};
+ }
+ // don't consider cases where we don't know the shape at compile time
+ // (we will use the analyzer to do best-effort analysis where there are
vars)
+ if (tensor_info->shape.as<ShapeExprNode>()) {
+ return {GetRef<TensorStructInfo>(tensor_info)};
+ } else {
+ return {};
+ }
+ } else if (auto* tuple_info = result_sinfo.as<TupleStructInfoNode>()) {
+ // we can see if the whole tuple matches or go for any of the components
+ std::unordered_set<StructInfo, ObjectPtrHash, ObjectPtrEqual> ret;
+ for (auto field : tuple_info->fields) {
+ auto field_candidates = GatherCandidateSinfo(field);
+ ret.insert(field_candidates.begin(), field_candidates.end());
+ }
+ // at least one field should be eligible to be done in-place
+ if (!ret.empty()) {
+ ret.insert(GetRef<StructInfo>(tuple_info));
+ }
+ return ret;
+ } else {
+ // don't consider any other types
+ return {};
+ }
+}
+
+// Given the two struct info, return a pair of bools where the first element is true if
+// the argument has the same dtype as the target and provably at least as many elements,
+// and the second element is true if the shapes also match _exactly_. Performs this check
+// recursively and ensures the stated condition is true for all tensor members of the struct
+// info (returns false if a single pair of corresponding tensors does not meet the condition).
+std::pair<bool, bool> SizeMatches(const StructInfo& target_info, const
StructInfo& arg_info,
+ const BlockBuilder& ctx) {
+ if (target_info.as<TensorStructInfoNode>() &&
arg_info.as<TensorStructInfoNode>()) {
+ auto target_tensor = Downcast<TensorStructInfo>(target_info);
+ auto arg_tensor = Downcast<TensorStructInfo>(arg_info);
+ if (target_tensor->shape.defined() &&
target_tensor->shape.as<ShapeExprNode>() &&
+ arg_tensor->shape.defined() && arg_tensor->shape.as<ShapeExprNode>()) {
+ if (target_tensor->dtype != arg_tensor->dtype) {
+ return {false, false};
+ }
+ auto target_shape = Downcast<ShapeExpr>(target_tensor->shape);
+ auto arg_shape = Downcast<ShapeExpr>(arg_tensor->shape);
+ PrimExpr target_size = NumElements(target_shape);
+ PrimExpr arg_size = NumElements(arg_shape);
+ if (!ctx->GetAnalyzer()->CanProve(arg_size >= target_size)) {
+ return {false, false};
+ }
+ // exact match: number of dims and each dim matches
+ if (target_shape->values.size() == arg_shape->values.size()) {
+ for (size_t i = 0; i < target_shape->values.size(); i++) {
+ if (!ctx->GetAnalyzer()->CanProveEqual(target_shape->values[i],
arg_shape->values[i])) {
+ return {true, false};
+ }
+ }
+ return {true, true};
+ }
+ return {true, false};
+ } else {
+ return {false, false};
+ }
+ } else if (target_info.as<TupleStructInfoNode>() &&
arg_info.as<TupleStructInfoNode>()) {
+ auto target_tup = Downcast<TupleStructInfo>(target_info);
+ auto arg_tup = Downcast<TupleStructInfo>(arg_info);
+ if (target_tup->fields.size() != arg_tup->fields.size()) {
+ return {false, false};
+ }
+ bool all_exact = true;
+ for (size_t i = 0; i < target_tup->fields.size(); i++) {
+ // if members aren't either tuples or tensors, simply skip them,
+ // since they don't matter for in-place computations
+ if (!(target_tup->fields[i].as<TensorStructInfoNode>() ||
+ target_tup->fields[i].as<TupleStructInfoNode>()) &&
+ !(arg_tup->fields[i].as<TensorStructInfoNode>() ||
+ arg_tup->fields[i].as<TupleStructInfoNode>())) {
+ continue;
+ }
+ auto [field_size_match, field_exact_match] =
+ SizeMatches(target_tup->fields[i], arg_tup->fields[i], ctx);
+ if (!field_size_match) {
+ return {false, false};
+ }
+ all_exact = all_exact && field_exact_match;
+ }
+ return {true, all_exact};
+ } else {
+ return {false, false};
+ }
+}
+
+// Given an alias index, check if it's a tuple and gather the sets of aliases
for the tuple
+// members if so (apply recursively if any of those members are tuples).
+// Return false if the alias set contains -1, meaning a reference to an
unknown or
+// possibly dangerous value (no checking we can do for that).
+bool GatherSetsToCheckForLiveness(
+ const std::unordered_map<Var, std::unordered_set<int>, ObjectPtrHash,
ObjectPtrEqual>&
+ alias_sets,
+ const std::unordered_map<int, std::vector<std::unordered_set<int>>>&
tuple_map,
+ std::vector<std::unordered_set<int>>* sets_to_check, int alias_idx) {
+ if (tuple_map.count(alias_idx)) {
+ for (auto member_set : tuple_map.at(alias_idx)) {
+ // contains -1 -> unknown and dangerous, we can short-circuit
+ if (member_set.count(-1)) {
+ return false;
+ }
+ sets_to_check->push_back(member_set);
+
+ // if a member can be a tuple, check it recursively
+ for (int member : member_set) {
+ if (tuple_map.count(member)) {
+ if (!GatherSetsToCheckForLiveness(alias_sets, tuple_map,
sets_to_check, member)) {
+ return false;
+ }
+ }
+ }
+ }
+ }
+ return true;
+}
+
+// Check that the target is not live past the index and that no alias of it is
live past the
+// binding index (if the target is a tuple, check the conditions recursively
for the members)
+bool InplaceConditionsMet(
+ const std::unordered_map<Var, std::pair<int, int>, ObjectPtrHash,
ObjectPtrEqual>& live_ranges,
+ const std::unordered_map<Var, std::unordered_set<int>, ObjectPtrHash,
ObjectPtrEqual>&
+ alias_sets,
+ const std::unordered_map<int, std::vector<std::unordered_set<int>>>&
tuple_map,
+ const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>&
currently_live,
+ const Expr& target, int binding_idx) {
+ if (auto* var_node = target.as<VarNode>()) {
+ auto current_var = GetRef<Var>(var_node);
+ // if the var is live past this point, we can't use it for in-place
computations anyway
+ if (live_ranges.count(current_var)) {
+ auto live_range = live_ranges.at(current_var);
+ if (live_range.second > binding_idx) {
+ return false;
+ }
+ }
+
+ // no entry for the current var -> it must be something external and we
have to assume the worst
+ if (!alias_sets.count(current_var)) {
+ return false;
+ }
+ auto alias_set = alias_sets.at(current_var);
+ // -1 -> an external value and we must assume the worst
+ if (alias_set.count(-1)) {
+ return false;
+ }
+ std::vector<std::unordered_set<int>> sets_to_check = {alias_set};
+ std::unordered_set<int> indices_checked;
+ // If a possible alias is a tuple, we will also check for aliases of the
members
+ // (possibly recursively)
+ for (int alias_idx : alias_set) {
+ if (!GatherSetsToCheckForLiveness(alias_sets, tuple_map, &sets_to_check,
alias_idx)) {
+ return false;
+ }
+ }
+
+ for (Var other_var : currently_live) {
+ if (other_var.same_as(target)) {
+ continue;
+ }
+ // not represented = spooky unknown value that should be modeled by -1
+ if (!alias_sets.count(other_var) || !live_ranges.count(other_var)) {
+ continue;
+ }
+ // var is not live past this point => don't need to worry
+ if (live_ranges.at(other_var).second <= binding_idx) {
+ continue;
+ }
+ auto other_alias_set = alias_sets.at(other_var);
+ for (int alias_idx : other_alias_set) {
+ for (auto check_set : sets_to_check) {
+ if (check_set.count(alias_idx)) {
+ return false;
+ }
+ }
+ }
+ }
+ return true;
+ } else if (auto* tup_node = target.as<TupleNode>()) {
+ for (auto field : tup_node->fields) {
+ if (!InplaceConditionsMet(live_ranges, alias_sets, tuple_map,
currently_live, field,
+ binding_idx)) {
+ return false;
+ }
+ }
+ return true;
+ } else {
+ return true;
+ }
+}
+
+// this is obviously not a complete list
+static std::unordered_set<std::string> SUPPORTED_OPS = {"relax.add",
"relax.subtract",
+ "relax.multiply",
"relax.divide",
+ "relax.nn.silu",
"relax.nn.relu"};
+bool OpSupportsInplace(const Op& op) { return SUPPORTED_OPS.count(op->name); }
+
+/*! \brief Corresponds to a binding where at least one argument meets the
conditions to be
+ * made in-place. Contains the binding index and indices of the applicable
arguments
+ */
+class InplaceOpportunityNode : public Object {
+ public:
+ // need to use Array for the benefit of the FFI
+ Integer binding_idx;
+ Array<Integer> arg_idxs;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("binding_idx", &binding_idx);
+ v->Visit("arg_idxs", &arg_idxs);
+ }
+
+ static constexpr const char* _type_key =
"relax.transform.InplaceOpportunity";
+ TVM_DECLARE_BASE_OBJECT_INFO(InplaceOpportunityNode, Object);
+};
+
+TVM_REGISTER_NODE_TYPE(InplaceOpportunityNode);
+
+/*! \brief Reference class for InplaceOpportunityNode.
+ *
+ * The constructor allocates a fresh InplaceOpportunityNode and stores the given
+ * binding index and argument indices in it.
+ */
+class InplaceOpportunity : public ObjectRef {
+ public:
+ TVM_DLL InplaceOpportunity(const Integer& binding_idx, const Array<Integer>&
arg_idxs) {
+ auto node = make_object<InplaceOpportunityNode>();
+ node->binding_idx = binding_idx;
+ node->arg_idxs = arg_idxs;
+ data_ = std::move(node);
+ }
+
+ TVM_DEFINE_OBJECT_REF_METHODS(InplaceOpportunity, ObjectRef,
InplaceOpportunityNode);
+};
+
+// Check for in-place eligibility:
+// 1. see if there's an arg big enough to hold the result
+// 2. see if the arg is live past the call
+// 3. see if the arg has an alias that's live past the call
+// If the conditions are met, record the index of that binding.
+// Only bindings whose value is a call to an operator accepted by OpSupportsInplace
+// are considered.
+// Returns a pair of lists of InplaceOpportunity objects:
+// 1. (first) Bindings where at least one argument meets the in-place conditions and the
+//    *size* matches the size of the result.
+// 2. (second) Bindings where at least one argument meets the in-place conditions
+//    and *exactly* matches the shape of the result.
+// In both lists, each InplaceOpportunity holds the index of the *binding* in the block
+// (binding_idx) and the indices of the *eligible arguments* in that call (arg_idxs).
+std::pair<std::vector<InplaceOpportunity>, std::vector<InplaceOpportunity>>
+FindInplaceOpportunities(const DataflowBlock& block, const Array<Var>& inputs,
+ const BlockBuilder& ctx) {
+ auto live_ranges = AnalyzeLiveness(block);
+ AliasAnalyzer analyzer;
+ auto alias_info = analyzer.Analyze(block, inputs);
+ auto alias_sets = alias_info.first;
+ auto tuple_map = alias_info.second;
+
+ std::vector<InplaceOpportunity> size_match_list;
+ std::vector<InplaceOpportunity> exact_match_list;
+
+ // sort the live ranges by starting index
+ std::vector<Var> live_order;
+ for (auto kv : live_ranges) {
+ live_order.push_back(kv.first);
+ }
+ std::sort(live_order.begin(), live_order.end(),
+ [&live_ranges](const Var& var1, const Var& var2) -> bool {
+ return live_ranges[var1].first < live_ranges[var2].first;
+ });
+
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> currently_live;
+ // position in live_order of the first var that has not been added to currently_live yet
+ int last_live = 0;
+
+ for (size_t i = 0; i < block->bindings.size(); i++) {
+ // include all vars that are currently live
+ for (int j = last_live; j < static_cast<int>(live_order.size()); j++) {
+ auto live_var = live_order[j];
+ auto live_range = live_ranges[live_var];
+ // live_order is sorted by start index, so no later var can have started either
+ if (live_range.first > static_cast<int>(i)) {
+ break;
+ }
+ currently_live.insert(live_var);
+ last_live++;
+ }
+ // remove vars whose range has come to an end
+ // (keep a separate set to avoid changing the set while iterating on it)
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> remove;
+ for (auto var : currently_live) {
+ auto live_range = live_ranges[var];
+ if (live_range.second < static_cast<int>(i)) {
+ remove.insert(var);
+ }
+ }
+ for (auto var : remove) {
+ currently_live.erase(var);
+ }
+
+ // if we reach a binding check the conditions
+ Binding b = block->bindings[i];
+ Var defined_var = b->var;
+ Expr value = GetBoundValue(b);
+
+ // only calls whose callee is an Op on the supported list are candidates
+ if (auto* call_node = value.as<CallNode>()) {
+ if (auto* op_node = call_node->op.as<OpNode>()) {
+ if (!OpSupportsInplace(GetRef<Op>(op_node))) {
+ continue;
+ }
+
+ // candidates: argument indices whose size matches some target;
+ // exact_match_candidates: the subset that also matches exactly
+ std::unordered_set<int> candidates;
+ std::unordered_set<int> exact_match_candidates;
+
+ auto target_sinfo = GatherCandidateSinfo(GetStructInfo(defined_var));
+ // can't be done in-place, ignore
+ if (target_sinfo.empty()) {
+ continue;
+ }
+
+ // Check that at least one argument matches size with the result
+ for (size_t j = 0; j < call_node->args.size(); j++) {
+ auto arg = call_node->args[j];
+ for (auto target : target_sinfo) {
+ auto [matches_size, matches_exactly] = SizeMatches(target,
GetStructInfo(arg), ctx);
+ if (matches_size) {
+ candidates.insert(static_cast<int>(j));
+ if (matches_exactly) {
+ exact_match_candidates.insert(static_cast<int>(j));
+ }
+ }
+ }
+ }
+ if (candidates.empty()) {
+ continue;
+ }
+
+ // Make sure at least one candidate is not live past this point and does not have an alias
+ // live past this point
+ std::unordered_set<int> remove_candidates;
+ for (auto candidate : candidates) {
+ if (!InplaceConditionsMet(live_ranges, alias_sets, tuple_map,
currently_live,
+ call_node->args[candidate], i)) {
+ remove_candidates.insert(candidate);
+ }
+ }
+ // (remove now to avoid modifying the list as we iterate on it)
+ for (auto candidate : remove_candidates) {
+ candidates.erase(candidate);
+ }
+
+ // if we have a candidate, then this can be made in-place. Report the appropriate candidates
+ if (candidates.empty()) {
+ continue;
+ }
+
+ // produce a list of candidates for this index
+ Array<Integer> size_candidate_list;
+ for (auto candidate : candidates) {
+ size_candidate_list.push_back(Integer(candidate));
+ }
+ size_match_list.push_back(InplaceOpportunity(Integer(i),
size_candidate_list));
+
+ // also gather up the exact match candidates if there are any
+ Array<Integer> exact_candidate_list;
+ for (auto candidate : candidates) {
+ if (!exact_match_candidates.count(candidate)) {
+ continue;
+ }
+ exact_candidate_list.push_back(Integer(candidate));
+ }
+ if (exact_candidate_list.empty()) {
+ continue;
+ }
+ exact_match_list.push_back(InplaceOpportunity(Integer(i),
exact_candidate_list));
+ }
+ }
+ }
+
+ return {size_match_list, exact_match_list};
+}
+
+// Replace buffers in a PrimFunc according to the mapping.
+// Every buffer that appears as a key in buffer_map is replaced by the mapped buffer in
+// buffer loads, stores, realizes, declarations, and in the alloc_buffers, reads, writes
+// and match_buffers of blocks. Buffers that are not keys of the map are left unchanged.
+tir::Stmt RemapBuffers(const tir::Stmt& stmt, const Map<tir::Buffer,
tir::Buffer>& buffer_map) {
+ class BufferMapper : public tir::StmtExprMutator {
+ public:
+ explicit BufferMapper(const Map<tir::Buffer, tir::Buffer>& buffer_map)
+ : buffer_map_(buffer_map) {}
+
+ // Entry point: run the mutator over the given statement.
+ tir::Stmt Remap(const tir::Stmt& stmt) { return VisitStmt(stmt); }
+
+ // Each override below first lets the base mutator rewrite the node's children,
+ // then swaps the node's own buffer via AttemptRemap.
+ PrimExpr VisitExpr_(const tir::BufferLoadNode* op) final {
+ auto node =
Downcast<tir::BufferLoad>(tir::StmtExprMutator::VisitExpr_(op));
+ auto* node_cow = node.CopyOnWrite();
+ node_cow->buffer = AttemptRemap(node->buffer);
+ return node;
+ }
+
+ tir::Stmt VisitStmt_(const tir::BufferStoreNode* op) final {
+ auto node =
Downcast<tir::BufferStore>(tir::StmtExprMutator::VisitStmt_(op));
+ auto* node_cow = node.CopyOnWrite();
+ node_cow->buffer = AttemptRemap(node->buffer);
+ return node;
+ }
+
+ tir::Stmt VisitStmt_(const tir::BufferRealizeNode* op) final {
+ auto node =
Downcast<tir::BufferRealize>(tir::StmtExprMutator::VisitStmt_(op));
+ auto* node_cow = node.CopyOnWrite();
+ node_cow->buffer = AttemptRemap(node->buffer);
+ return node;
+ }
+
+ tir::Stmt VisitStmt_(const tir::DeclBufferNode* op) final {
+ auto node =
Downcast<tir::DeclBuffer>(tir::StmtExprMutator::VisitStmt_(op));
+ auto* node_cow = node.CopyOnWrite();
+ node_cow->buffer = AttemptRemap(node->buffer);
+ return node;
+ }
+
+ // Blocks carry several buffer-bearing fields; all of them are remapped.
+ tir::Stmt VisitStmt_(const tir::BlockNode* op) final {
+ auto node = Downcast<tir::Block>(tir::StmtExprMutator::VisitStmt_(op));
+ auto* node_cow = node.CopyOnWrite();
+ // need the lambdas because class methods are not first-class (how ironic)
+ node_cow->alloc_buffers =
+ node->alloc_buffers.Map([this](const tir::Buffer& b) { return
AttemptRemap(b); });
+ node_cow->reads =
+ node->reads.Map([this](const tir::BufferRegion& br) { return
VisitBufferRegion(br); });
+ node_cow->writes =
+ node->writes.Map([this](const tir::BufferRegion& br) { return
VisitBufferRegion(br); });
+ node_cow->match_buffers = node->match_buffers.Map(
+ [this](const tir::MatchBufferRegion& mbr) { return
VisitMatchBufferRegion(mbr); });
+ return node;
+ }
+
+ private:
+ // Return the replacement for the buffer if the map has one, else the buffer itself.
+ tir::Buffer AttemptRemap(const tir::Buffer& buffer) {
+ if (buffer_map_.count(buffer)) {
+ return buffer_map_.at(buffer);
+ }
+ return buffer;
+ }
+
+ // Remap the buffer of a buffer region (the region is taken by value and updated).
+ tir::BufferRegion VisitBufferRegion(tir::BufferRegion region) {
+ auto* region_cow = region.CopyOnWrite();
+ region_cow->buffer = AttemptRemap(region_cow->buffer);
+ return region;
+ }
+
+ // Remap the `buffer` field of a match-buffer region.
+ tir::MatchBufferRegion VisitMatchBufferRegion(tir::MatchBufferRegion
region) {
+ auto* region_cow = region.CopyOnWrite();
+ region_cow->buffer = AttemptRemap(region_cow->buffer);
+ return region;
+ }
+
+ // Held by reference: the mapper is only used within this call to RemapBuffers.
+ const Map<tir::Buffer, tir::Buffer>& buffer_map_;
+ };
+
+ BufferMapper mapper(buffer_map);
+ auto ret = mapper.Remap(stmt);
+ return ret;
+}
+
+/*!
+ * \brief Rewrites eligible operator calls inside dataflow blocks into calls to
+ * relax.call_tir_inplace, adding "_inplace" variants of the legalized PrimFuncs to the module.
+ *
+ * Only the exact-match opportunities reported by FindInplaceOpportunities are used.
+ */
+class ModuleInplaceTransformer : public ExprMutator {
+ public:
+  explicit ModuleInplaceTransformer(const IRModule& mod) : mod_(mod) {
+    builder_ = BlockBuilder::Create(mod);
+  }
+
+  /*!
+   * \brief Visit every Relax function in the module and return the builder's module.
+   *
+   * Every GlobalVar recorded in legalizers_added (the PrimFuncs produced by legalization
+   * in CreateInplaceCall) is removed from the returned module.
+   * NOTE(review): if a recorded GlobalVar was already present in the input module before
+   * this pass ran, it would be removed as well -- confirm whether that can happen.
+   */
+  IRModule Transform() {
+    // visit every Relax function in the module
+    for (auto kv : mod_->functions) {
+      if (auto* func_node = kv.second.as<FunctionNode>()) {
+        auto gv = kv.first;
+        auto function = Downcast<Function>(VisitExpr(GetRef<Function>(func_node)));
+        builder_->UpdateFunction(gv, function);
+      }
+    }
+
+    auto ret = builder_->GetContextIRModule();
+    // clean up to avoid polluting the IRModule
+    for (auto gv : legalizers_added) {
+      ret->Remove(gv);
+    }
+    return ret;
+  }
+
+  // Track the parameters of the function being visited, restoring the previous
+  // value afterwards so that nested functions are handled correctly.
+  Expr VisitExpr_(const FunctionNode* op) override {
+    auto old_func_params = func_params;
+    func_params = op->params;
+    auto ret = ExprMutator::VisitExpr_(op);
+    func_params = old_func_params;
+    return ret;
+  }
+
+  // the only case we will override: we will visit all binding blocks
+  // and replace any valid calls in them
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override {
+    auto block = GetRef<DataflowBlock>(op);
+    auto old_idxs = inplace_idxs;
+
+    // For now, only handle exact match cases.
+    // Note: Not passing any input values for now, as we can't make any assumptions
+    // about them.
+    auto matches_found = FindInplaceOpportunities(block, {}, builder_);
+    Map<Binding, Array<Integer>> new_idxs;
+    for (auto match : matches_found.second) {
+      new_idxs.Set(block->bindings[match->binding_idx.IntValue()], match->arg_idxs);
+    }
+
+    // the map is consulted by VisitBinding_ while the block's bindings are visited
+    inplace_idxs = new_idxs;
+    auto ret = ExprMutator::VisitBindingBlock_(op);
+    inplace_idxs = old_idxs;
+    return ret;
+  }
+
+  // Build the normalized in-place replacement for the call bound by the given binding.
+  // The binding must be a key of inplace_idxs.
+  Expr ReplaceBoundCall(const Binding& binding) {
+    // can just pick the first index arbitrarily (only using one output for now too)
+    // now replace the binding appropriately
+    auto arg_idxs = inplace_idxs.at(binding);
+    auto target = Downcast<Call>(GetBoundValue(binding));
+    auto new_call = CreateInplaceCall(target, {arg_idxs[0]});
+    return builder_->Normalize(new_call);
+  }
+
+  // Bindings recorded in inplace_idxs are re-emitted with the in-place call as their
+  // value; all other bindings are handled by the default mutator.
+  void VisitBinding_(const VarBindingNode* binding) override {
+    auto binding_ref = GetRef<VarBinding>(binding);
+    if (!inplace_idxs.count(binding_ref)) {
+      ExprMutator::VisitBinding_(binding);
+      return;
+    }
+    Expr new_value = ReplaceBoundCall(binding_ref);
+    builder_->EmitNormalized(VarBinding(binding->var, new_value, binding->span));
+  }
+
+  // Same as the VarBinding case, but the struct info of the match cast is preserved.
+  void VisitBinding_(const MatchCastNode* binding) override {
+    auto binding_ref = GetRef<MatchCast>(binding);
+    if (!inplace_idxs.count(binding_ref)) {
+      ExprMutator::VisitBinding_(binding);
+      return;
+    }
+    Expr new_value = ReplaceBoundCall(binding_ref);
+    builder_->EmitNormalized(
+        MatchCast(binding->var, new_value, binding->struct_info, binding->span));
+  }
+
+  // Given the call and indices of arguments that could be done in-place,
+  // replace the call with a call to an in-place PrimFunc.
+  // The in-place PrimFunc is derived from the legalized PrimFunc and added to the
+  // builder's module under the legalized name plus the suffix "_inplace".
+  // (Made public for testing.)
+  Call CreateInplaceCall(const Call& call, const Array<Integer>& inplace_indices) {
+    static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
+    static const auto& call_tir_inplace_op = Op::Get("relax.call_tir_inplace");
+
+    auto op = Downcast<Op>(call->op);
+    auto legalized_call = Downcast<Call>(legalize_map[op](builder_, call));
+    auto* legalized_call_cow = legalized_call.CopyOnWrite();
+
+    // The legalized call should be call_tir. We will replace it with call_tir_inplace
+    // and replace the called PrimFunc with an inplace version
+    auto legal_op = Downcast<GlobalVar>(legalized_call->args[0]);
+    legalizers_added.push_back(legal_op);
+    auto inplace_legal_op_name = legal_op->name_hint + "_inplace";
+
+    auto mod = builder_->GetContextIRModule();
+    auto legal_primfunc = Downcast<tir::PrimFunc>(mod->Lookup(legal_op));
+    auto* legal_primfunc_cow = legal_primfunc.CopyOnWrite();
+    size_t num_outs = inplace_indices.size();
+    size_t num_params = legal_primfunc->params.size();
+
+    // the replacement we must make:
+    // 1. For each output var, replace its corresponding buffers with the corresponding
+    //    inplace index var's buffers
+    // 2. For each output var, replace its instances with the corresponding inplace index var
+    // 3. Do the same for the *buffer vars* corresponding to the output vars
+    // 4. Remove the output vars from the param list and buffer map
+    Map<tir::Buffer, tir::Buffer> buffer_subst_map;
+    Map<tir::Var, tir::Var> var_subst_map;
+    for (size_t i = 0; i < num_outs; i++) {
+      // we will substitute output i with the corresponding param indicated by inplace indices
+      // (the outputs are the last num_outs params of the legalized PrimFunc)
+      auto output_var = legal_primfunc->params[num_params - num_outs + i];
+      auto inplace_var = legal_primfunc->params[inplace_indices[i].IntValue()];
+      var_subst_map.Set(output_var, inplace_var);
+
+      // also do the same with the buffer vars
+      auto output_buffer = legal_primfunc->buffer_map.at(output_var);
+      auto inplace_buffer = legal_primfunc->buffer_map.at(inplace_var);
+      var_subst_map.Set(output_buffer->data, inplace_buffer->data);
+      buffer_subst_map.Set(output_buffer, inplace_buffer);
+    }
+
+    // apply substitutions
+    legal_primfunc_cow->body = RemapBuffers(legal_primfunc->body, buffer_subst_map);
+    legal_primfunc_cow->body = tir::Substitute(
+        legal_primfunc->body, [&var_subst_map](const tir::Var& v) -> Optional<PrimExpr> {
+          if (var_subst_map.count(v)) {
+            return var_subst_map.at(v);
+          }
+          return Optional<PrimExpr>();
+        });
+
+    // remove the now-unused outputs from the buffer map
+    auto buffer_map = legal_primfunc->buffer_map;
+    for (size_t i = 0; i < num_outs; i++) {
+      buffer_map.erase(legal_primfunc->params[num_params - num_outs + i]);
+    }
+    legal_primfunc_cow->buffer_map = buffer_map;
+
+    // now get rid of the last num_outputs arguments
+    // (couldn't do earlier or else it would have thrown off the indexing)
+    legal_primfunc_cow->params = Array<tir::Var>(
+        legal_primfunc->params.begin(), legal_primfunc->params.begin() + (num_params - num_outs));
+
+    // note: this might be a good time to get rid of the old legalized function, but we don't
+    // do it now because later ops might need the same one. Instead, we will clean up at the end
+    auto new_gv = builder_->AddFunction(legal_primfunc, inplace_legal_op_name);
+
+    // update the call (change the op, update the argument, change the attrs)
+    legalized_call_cow->op = call_tir_inplace_op;
+
+    Array<Expr> new_args(legalized_call->args.begin(), legalized_call->args.end());
+    new_args.Set(0, new_gv);
+    legalized_call_cow->args = new_args;
+
+    ObjectPtr<CallTIRInplaceAttrs> attrs = make_object<CallTIRInplaceAttrs>();
+    attrs->inplace_indices = inplace_indices;
+    legalized_call_cow->attrs = Attrs(attrs);
+
+    return legalized_call;
+  }
+
+  // Made public for testing.
+  IRModule CurrentMod() { return builder_->GetContextIRModule(); }
+
+ private:
+  const IRModule& mod_;
+  // Keep track of legalizers we add so we can clean up at the end.
+  Array<GlobalVar> legalizers_added;
+  // The params of the function currently being visited (maintained by VisitExpr_).
+  // They are not passed to the analysis at the moment: VisitBindingBlock_ gives
+  // FindInplaceOpportunities an empty input list.
+  Array<Var> func_params;
+  // map of eligible bindings to indices of arguments that can be used as the in-place target
+  Map<Binding, Array<Integer>> inplace_idxs;
+};
+
+namespace transform {
+
+// FFI-friendly wrapper around AnalyzeLiveness: each variable is mapped to a
+// two-element array holding the first and second component of its live range.
+Map<Var, Array<Integer>> DataflowLivenessAnalysis(const DataflowBlock& block) {
+  Map<Var, Array<Integer>> range_by_var;
+  for (const auto& [var, range] : AnalyzeLiveness(block)) {
+    range_by_var.Set(var, {range.first, range.second});
+  }
+  return range_by_var;
+}
+
+// FFI-friendly wrapper around AliasAnalyzer::Analyze. Returns a two-element array:
+// the alias sets (Var -> list of alias indices) and the tuple map
+// (index -> one list of alias indices per tuple member).
+Array<ObjectRef> DataflowAliasAnalysis(const DataflowBlock& block, Array<Var> inputs) {
+  // Copy a collection of ints into an Array<Integer>.
+  auto to_integer_array = [](const auto& ints) {
+    Array<Integer> converted;
+    for (auto value : ints) {
+      converted.push_back(value);
+    }
+    return converted;
+  };
+
+  AliasAnalyzer analyzer;
+  auto analysis = analyzer.Analyze(block, inputs);
+  auto alias_sets = analysis.first;
+  auto tuple_map = analysis.second;
+
+  Map<Var, Array<Integer>> alias_sets_out;
+  for (auto entry : alias_sets) {
+    alias_sets_out.Set(entry.first, to_integer_array(entry.second));
+  }
+
+  Map<Integer, Array<Array<Integer>>> tuple_map_out;
+  for (auto entry : tuple_map) {
+    Array<Array<Integer>> member_aliases;
+    for (auto alias_set : entry.second) {
+      member_aliases.push_back(to_integer_array(alias_set));
+    }
+    tuple_map_out.Set(entry.first, member_aliases);
+  }
+  return {alias_sets_out, tuple_map_out};
+}
+
+// Ideally this would be a dataflow block pass, but the transformation adds
+// new PrimFuncs, so it affects the whole module and is created as a module pass.
+tvm::transform::Pass DataflowUseInplaceCalls() {
+  auto rewrite_module = [](const IRModule& mod, const PassContext& ctx) -> IRModule {
+    return ModuleInplaceTransformer(mod).Transform();
+  };
+  return tvm::transform::CreateModulePass(rewrite_module, 0, "DataflowInsertInPlaceCalls", {},
+                                          false);
+}
+
+// FFI-friendly wrapper around FindInplaceOpportunities: returns the size-match list
+// followed by the exact-match list, each converted to an Array.
+Array<Array<InplaceOpportunity>> DataflowInplaceAnalysis(const DataflowBlock& block,
+                                                         const Array<Var>& inputs,
+                                                         const IRModule& mod) {
+  auto found = relax::FindInplaceOpportunities(block, inputs, BlockBuilder::Create(mod));
+  Array<InplaceOpportunity> size_matches(found.first.begin(), found.first.end());
+  Array<InplaceOpportunity> exact_matches(found.second.begin(), found.second.end());
+  return {size_matches, exact_matches};
+}
+
+// these are exposed only for testing
+// (all of them are registered under the "relax.testing.transform." prefix)
+TVM_REGISTER_GLOBAL("relax.testing.transform.DataflowLivenessAnalysis")
+ .set_body_typed(DataflowLivenessAnalysis);
+TVM_REGISTER_GLOBAL("relax.testing.transform.DataflowAliasAnalysis")
+ .set_body_typed(DataflowAliasAnalysis);
+TVM_REGISTER_GLOBAL("relax.testing.transform.DataflowInplaceAnalysis")
+ .set_body_typed(DataflowInplaceAnalysis);
+// Runs CreateInplaceCall on a single call and returns the new call together with
+// the transformer's current module (which contains the added PrimFuncs).
+TVM_REGISTER_GLOBAL("relax.testing.transform.SingleInplaceCall")
+ .set_body_typed([](const IRModule& mod, const Call& call,
+ const Array<Integer>& inplace_indices) ->
Array<ObjectRef> {
+ ModuleInplaceTransformer transformer(mod);
+ auto ret_call = transformer.CreateInplaceCall(call, inplace_indices);
+ return Array<ObjectRef>{ret_call, transformer.CurrentMod()};
+ });
+
+// actually exposed
+// (the pass itself, under "relax.transform.")
+TVM_REGISTER_GLOBAL("relax.transform.DataflowUseInplaceCalls")
+ .set_body_typed(DataflowUseInplaceCalls);
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc
index 785dc6d963..ef9438350c 100644
--- a/src/script/printer/relax/call.cc
+++ b/src/script/printer/relax/call.cc
@@ -97,11 +97,13 @@ ExprDoc PrintCallee(const relax::Expr& n, const ObjectPath&
n_p, const IRDocsifi
Optional<ExprDoc> PrintCallTIRDPSPacked(const relax::Call& n, const
ObjectPath& n_p,
const IRDocsifier& d) {
static const Op& call_tir_op = Op::Get("relax.call_tir");
+ static const Op& call_tir_inplace_op = Op::Get("relax.call_tir_inplace");
static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
static const Op& call_tir_with_grad_op = Op::Get("relax.call_tir_with_grad");
static const Op& call_tir_local_view =
Op::Get("relax.dist.call_tir_local_view");
if (!n->op.same_as(call_tir_op) && !n->op.same_as(call_dps_packed_op) &&
- !n->op.same_as(call_tir_with_grad_op) &&
!n->op.same_as(call_tir_local_view)) {
+ !n->op.same_as(call_tir_with_grad_op) &&
!n->op.same_as(call_tir_local_view) &&
+ !n->op.same_as(call_tir_inplace_op)) {
return NullOpt;
}
ICHECK(n->args.size() == 2 || n->args.size() == 3);
@@ -135,6 +137,19 @@ Optional<ExprDoc> PrintCallTIRDPSPacked(const relax::Call&
n, const ObjectPath&
kwargs_values.push_back(d->AsDoc<ExprDoc>(o_sinfo, o_sinfo_p));
}
+ // for call_tir_inplace, we also need to include the inplace args
+ if (n->op.same_as(call_tir_inplace_op)) {
+ kwargs_keys.push_back("inplace_indices");
+ Array<ExprDoc> index_fields;
+ if (auto* call_tir_inplace_attrs =
n->attrs.as<relax::CallTIRInplaceAttrs>()) {
+ for (auto inplace_index : call_tir_inplace_attrs->inplace_indices) {
+ index_fields.push_back(
+ LiteralDoc::Int(inplace_index.IntValue(),
n_p->Attr("attrs")->Attr("inplace_indices")));
+ }
+ }
+ kwargs_values.push_back(ListDoc(index_fields));
+ }
+
// start of specially handling call_tir_with_grad
if (const auto* call_tir_with_grad_attrs =
n->attrs.as<relax::CallTIRWithGradAttrs>()) {
kwargs_keys.push_back("te_grad_name");
@@ -163,6 +178,8 @@ Optional<ExprDoc> PrintCallTIRDPSPacked(const relax::Call&
n, const ObjectPath&
return Relax(d, "dist.call_tir_local_view")->Call(args, kwargs_keys,
kwargs_values);
} else if (is_dtensor) {
return Relax(d, "dist.call_tir")->Call(args, kwargs_keys, kwargs_values);
+ } else if (n->op.same_as(call_tir_inplace_op)) {
+ return Relax(d, "call_tir_inplace")->Call(args, kwargs_keys,
kwargs_values);
} else {
return Relax(d, "call_tir")->Call(args, kwargs_keys, kwargs_values);
}
diff --git a/tests/python/relax/test_dataflow_inplace.py
b/tests/python/relax/test_dataflow_inplace.py
new file mode 100644
index 0000000000..8d5eb07c78
--- /dev/null
+++ b/tests/python/relax/test_dataflow_inplace.py
@@ -0,0 +1,644 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List, Set, Tuple
+import tvm
+from tvm import relax, testing
+from tvm.relax.transform import DataflowUseInplaceCalls
+from tvm.relax.testing.transform import (
+ dataflow_liveness_analysis,
+ dataflow_alias_analysis,
+ dataflow_inplace_analysis,
+ dataflow_single_inplace_call,
+)
+from tvm.script.parser import ir as I, relax as R, tir as T
+
+import numpy as np
+
+
+def test_liveness_analysis():
+ """Check dataflow_liveness_analysis on a small dataflow block.
+
+ The ranges reported for each variable are compared, by variable name,
+ against hand-written (start, end) pairs.
+ """
+ @I.ir_module
+ class BasicLiveness:
+ @R.function
+ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ with R.dataflow():
+ y = R.const(1, dtype="int32")
+ z = R.add(x, y)
+ q = R.multiply(z, y)
+ p = R.add(z, q)
+ n = R.multiply(p, p)
+ R.output(n, p)
+ return n
+
+ block = BasicLiveness["main"].body.blocks[0]
+ live_ranges = dataflow_liveness_analysis(block)
+ # The block has five bindings (indices 0-4); an end of 5 is used below
+ # for values that are live past the block.
+ expected_ranges = {
+ # x is live past the binding block
+ "x": (-1, 5),
+ "y": (0, 2),
+ "z": (1, 3),
+ "q": (2, 3),
+ # exposed though ultimately not used
+ "p": (3, 5),
+ "n": (4, 5),
+ }
+ actual_ranges = {var.name_hint: live_range for var, live_range in
live_ranges.items()}
+ assert actual_ranges == expected_ranges
+
+
+def test_alias_analysis_basic():
+ """Check alias analysis for plain variable-to-variable bindings.
+
+ Variables bound directly to another variable are expected to share that
+ variable's alias set, while the result of R.add gets a new index.
+ No tuples are involved, so the tuple map is expected to be empty.
+ """
+ @I.ir_module
+ class BasicAliasAnalysis:
+ @R.function
+ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ with R.dataflow():
+ y = x # y is an alias of x
+ z = R.add(y, y) # fresh value
+ n = z # alias of z
+ R.output(n)
+ return n
+
+ block = BasicAliasAnalysis["main"].body.blocks[0]
+ alias_sets, tuple_map = dataflow_alias_analysis(block,
BasicAliasAnalysis["main"].params)
+ expected = {
+ "x": {0},
+ "y": {0},
+ "z": {1},
+ "n": {1},
+ }
+
+ for var, alias_set in alias_sets.items():
+ assert alias_set == expected[var.name_hint]
+ assert tuple_map == {}
+
+
+def test_alias_analysis_tuple():
+ """Check alias analysis through tuple construction and tuple indexing.
+
+ Indexing into the tuple t (directly or via its alias u) is expected to
+ yield the alias sets of the tuple's members, and the tuple map is expected
+ to record the member alias sets under the tuple's own index.
+ """
+ @I.ir_module
+ class AliasesWithTuples:
+ @R.function
+ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ with R.dataflow():
+ y = R.const(1, dtype="int32")
+ t = (x, y)
+ a = t[0]
+ b = t[1]
+ c = t[0]
+ d = t[1]
+ u = t
+ e = t[0]
+ f = t[1]
+ z = R.add(c, d)
+ n = z
+ R.output(n)
+ return n
+
+ block = AliasesWithTuples["main"].body.blocks[0]
+ alias_sets, tuple_map = dataflow_alias_analysis(block,
AliasesWithTuples["main"].params)
+ expected = {
+ "x": {0},
+ "y": {1},
+ "t": {2},
+ "a": {0},
+ "b": {1},
+ "c": {0},
+ "d": {1},
+ "u": {2},
+ "e": {0},
+ "f": {1},
+ "z": {3},
+ "n": {3},
+ }
+
+ actual_alias_sets = {var.name_hint: alias_set for var, alias_set in
alias_sets.items()}
+ assert expected == actual_alias_sets
+ # index 2 is the alias index of the tuple t (see expected["t"])
+ assert 2 in tuple_map
+ assert tuple_map[2] == [{0}, {1}]
+
+
+def test_alias_split():
+ """Check alias analysis for the tuple returned by R.split.
+
+ The split result t and each of its four members are expected to get
+ their own alias indices, with the members recorded in the tuple map
+ under t's index.
+ """
+ @I.ir_module
+ class AliasSplit:
+ @R.function
+ def main(x: R.Tensor((60,), "int32")) -> R.Tensor((15,), "int32"):
+ with R.dataflow():
+ t = R.split(x, 4)
+ y = t[0]
+ z = t[1]
+ q = t[2]
+ p = t[3]
+ n = z
+ R.output(n)
+ return n
+
+ block = AliasSplit["main"].body.blocks[0]
+ alias_sets, tuple_map = dataflow_alias_analysis(block,
AliasSplit["main"].params)
+ expected = {
+ "x": {0},
+ "t": {1},
+ "y": {2},
+ "z": {3},
+ "q": {4},
+ "p": {5},
+ "n": {3},
+ }
+
+ actual_alias_sets = {var.name_hint: alias_set for var, alias_set in
alias_sets.items()}
+ assert expected == actual_alias_sets
+ assert len(tuple_map) == 1
+ # index 1 is the alias index of the split result t (see expected["t"])
+ assert 1 in tuple_map
+ assert tuple_map[1] == [{2}, {3}, {4}, {5}]
+
+
+def test_alias_call_tir():
+ # call TIR can yield either a single tensor or a tuple
+ @I.ir_module
+ class AliasCallTir:
+ @T.prim_func
+ def tir_id(x: T.handle, y: T.handle) -> None:
+ T.func_attr({"global_symbol": "tir_id"})
+ m = T.int32()
+ n = T.int32()
+ A = T.match_buffer(x, (m, n))
+ B = T.match_buffer(y, (m, n))
+
+ for i, j in T.grid(m, n):
+ with T.block("id"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj]
+
+ @T.prim_func
+ def tir_id2(x: T.handle, y: T.handle, z: T.handle) -> None:
+ T.func_attr({"global_symbol": "tir_id"})
+ m = T.int32()
+ n = T.int32()
+ A = T.match_buffer(x, (m, n))
+ B = T.match_buffer(y, (m, n))
+ C = T.match_buffer(z, (m, n))
+
+ for i, j in T.grid(m, n):
+ with T.block("id"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj]
+ C[vi, vj] = A[vi, vj]
+
+ @R.function
+ def main(x: R.Tensor((10, 10), "int32")) -> R.Tensor((10, 10),
"int32"):
+ with R.dataflow():
+ cls = AliasCallTir
+ y = R.call_tir(cls.tir_id, (x,), out_sinfo=R.Tensor((10, 10),
"int32"))
+ t = R.call_tir(
+ cls.tir_id2,
+ (y,),
+ out_sinfo=[R.Tensor((10, 10), "int32"), R.Tensor((10, 10),
"int32")],
+ )
+ z = y
+ p = t[0]
+ q = t[1]
+ u = t
+ m = u[0]
+ n = u[1]
+ v = n
+ R.output(v)
+ return v
+
+ block = AliasCallTir["main"].body.blocks[0]
+ alias_sets, tuple_map = dataflow_alias_analysis(block,
AliasCallTir["main"].params)
+ expected = {
+ "x": {0},
+ "y": {1},
+ "t": {2},
+ "z": {1},
+ "p": {3},
+ "q": {4},
+ "u": {2},
+ "m": {3},
+ "n": {4},
+ "v": {4},
+ }
+
+ actual_alias_sets = {var.name_hint: alias_set for var, alias_set in
alias_sets.items()}
+ assert expected == actual_alias_sets
+ assert len(tuple_map) == 1
+ assert 2 in tuple_map
+ assert tuple_map[2] == [{3}, {4}]
+
+
+def test_mystery_calls():
+ """Check alias analysis for calls whose aliasing behavior is unknown.
+
+ Covers calls to another Relax function (identity) and to a packed
+ function ("chaos") that takes and returns a tuple. The expected alias
+ sets show the results of such calls possibly aliasing their arguments
+ in addition to a new index.
+ """
+ @I.ir_module
+ class AliasChaosCalls:
+ @R.function
+ def identity(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ return x
+
+ @R.function
+ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ with R.dataflow():
+ cls = AliasChaosCalls
+ y = cls.identity(x)
+ z = cls.identity(y)
+ m = R.const(1, dtype="int32")
+ n = R.const(2, dtype="int32")
+ t = (m, n)
+ a = R.call_pure_packed(
+ "chaos", t, sinfo_args=R.Tuple(R.Tensor((), "int32"),
R.Tensor((), "int32"))
+ )
+ b = a[0]
+ c = a[1]
+ R.output(c)
+ return c
+
+ block = AliasChaosCalls["main"].body.blocks[0]
+ alias_sets, tuple_map = dataflow_alias_analysis(block,
AliasChaosCalls["main"].params)
+ expected = {
+ "x": {0},
+ "y": {0, 1},
+ "z": {0, 1, 2},
+ "m": {3},
+ "n": {4},
+ "t": {5},
+ "a": {3, 4, 5, 6, 7, 8}, # either t or a fresh tuple
+ "b": {3, 4, 5, 6, 7, 8}, # the tuple components can be aliased to any member...
+ "c": {3, 4, 5, 6, 7, 8}, # the tuple components can be aliased to any member...
+ # (in principle, we can use type information to narrow down the aliasing)
+ }
+
+ actual_alias_sets = {var.name_hint: alias_set for var, alias_set in
alias_sets.items()}
+ assert expected == actual_alias_sets
+ # two tuples are tracked: t (index 5) and the tuple under index 6
+ assert len(tuple_map) == 2
+ assert 5 in tuple_map
+ assert tuple_map[5] == [{3}, {4}]
+ assert 6 in tuple_map
+ assert tuple_map[6] == [{3, 4, 5, 6, 7, 8}, {3, 4, 5, 6, 7, 8}]
+
+
+def test_alias_external_value():
+ """Check alias analysis for values defined outside the dataflow block.
+
+ Values bound before the dataflow block (y and t1) are not part of the
+ analyzed block; everything derived from them is expected to carry the
+ alias index -1.
+ """
+ @I.ir_module
+ class AliasExternalValue:
+ @R.function
+ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+ y = R.const(1, dtype="int32") # not in DF block, treated as external
+ t1 = (y, y) # not in DF block, treated as external
+ with R.dataflow():
+ z = y # mystery value
+ a = R.const(2, dtype="int32")
+ t2 = (z, a)
+ b = t2[0]
+ c = t1[1] # tuple index into external value
+ R.output(b)
+ return b
+
+ # blocks[0] holds the bindings before the dataflow block; blocks[1] is the dataflow block
+ block = AliasExternalValue["main"].body.blocks[1]
+ alias_sets, tuple_map = dataflow_alias_analysis(block,
AliasExternalValue["main"].params)
+ expected = {
+ "x": {0},
+ "z": {-1},
+ "a": {1},
+ "t2": {2},
+ "b": {-1},
+ "c": {-1},
+ }
+
+ actual_alias_sets = {var.name_hint: alias_set for var, alias_set in
alias_sets.items()}
+ assert expected == actual_alias_sets
+ assert len(tuple_map) == 1
+ assert 2 in tuple_map
+ assert tuple_map[2] == [{-1}, {1}]
+
+
+def test_inplace_simple_case():
+ """Check dataflow_inplace_analysis on a chain of elementwise operators.
+
+ Each expected entry is a pair of (binding index, set of argument indices
+ that may be used as the in-place target). All tensors have the same shape,
+ so the size-match and exact-match results are expected to be identical.
+ """
+ @I.ir_module
+ class InplaceBasic:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")
+ ) -> R.Tensor((2, 3), "int32"):
+ with R.dataflow():
+ z = R.add(x, y) # cannot be done inplace: x and y are live later
+ p = R.add(z, z) # can be done inplace: z is not used later
+ r = p # alias of p
+ m = R.multiply(p, p) # p is not used later but r is, so can't do inplace
+ n = R.add(m, r) # can be done inplace: r is not used again
+ ret = R.subtract(n, m) # can be done inplace: neither is used again
+ R.output(ret)
+ return ret
+
+ block = InplaceBasic["main"].body.blocks[0]
+ size_match, exact_match = dataflow_inplace_analysis(
+ block, InplaceBasic["main"].params, InplaceBasic
+ )
+
+ # order does not matter for the listing of candidates, so we have to implement as sets
+ def assert_candidate_list(
+ actual: List[Tuple[int, Set[int]]], expected: List[Tuple[int,
Set[int]]]
+ ) -> None:
+ # same number of entries, same binding index per entry, and the same
+ # number of candidate indices, each of which must be expected
+ assert len(actual) == len(expected)
+ for i in range(len(actual)):
+ assert actual[i][0] == expected[i][0]
+ assert len(expected[i][1]) == len(actual[i][1])
+ for idx in actual[i][1]:
+ assert idx in expected[i][1]
+
+ assert_candidate_list(size_match, [(1, {0, 1}), (4, {1}), (5, {0, 1})])
+ # TODO(@slyubomirsky): I couldn't think of an easy example where sizes don't match,
+ # but broadcasting might cause it to happen
+ assert_candidate_list(exact_match, [(1, {0, 1}), (4, {1}), (5, {0, 1})])
+
+
+def test_inplace_single_call():
+    """Check `dataflow_single_inplace_call` on a single `add` and a single `silu` call.
+
+    For each call the test verifies that:
+      * the module returned by the helper contains the expected in-place PrimFunc
+        (one that writes its result back into its first buffer `A`),
+      * the returned call is a `relax.call_tir_inplace` that targets that PrimFunc,
+      * the call's arguments are the original call's arguments, in order,
+      * the in-place index list is exactly `[0]`.
+    """
+
+    @I.ir_module
+    class TestModule:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            z = R.add(x, y)
+            q = R.nn.silu(z)
+            return q
+
+    def assert_inplace_call(new_call, orig_call, expected_func_name):
+        # Shared checks for the rewritten call. These were previously bare
+        # comparison expressions (no `assert`), so they never verified anything.
+        assert new_call.op.name == "relax.call_tir_inplace"
+        assert new_call.args[0].name_hint == expected_func_name
+        new_args = new_call.args[1].fields
+        assert len(new_args) == len(orig_call.args)
+        for new_arg, orig_arg in zip(new_args, orig_call.args):
+            tvm.ir.assert_structural_equal(new_arg, orig_arg)
+        # Convert element-wise to Python ints: comparing the attribute array
+        # directly against a Python list is not a reliable equality test.
+        assert [int(idx) for idx in new_call.attrs.inplace_indices] == [0]
+
+    add_call = TestModule["main"].body.blocks[0].bindings[0].value
+    new_add, new_mod = dataflow_single_inplace_call(TestModule, add_call, [0])
+
+    @T.prim_func(private=True)
+    def expected_add(
+        A: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+        B: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+    ):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+            with T.block("T_add"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
+                T.writes(A[v_ax0, v_ax1])
+                A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]
+
+    tvm.ir.assert_structural_equal(new_mod["add_inplace"], expected_add)
+    assert_inplace_call(new_add, add_call, "add_inplace")
+
+    @T.prim_func(private=True)
+    def expected_silu(A: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        compute = T.alloc_buffer((T.int64(2), T.int64(3)))
+        for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+            with T.block("compute"):
+                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                T.reads(A[v_i0, v_i1])
+                T.writes(compute[v_i0, v_i1])
+                compute[v_i0, v_i1] = T.sigmoid(A[v_i0, v_i1])
+        for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+            with T.block("T_multiply"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(A[v_ax0, v_ax1], compute[v_ax0, v_ax1])
+                T.writes(A[v_ax0, v_ax1])
+                A[v_ax0, v_ax1] = A[v_ax0, v_ax1] * compute[v_ax0, v_ax1]
+
+    silu_call = TestModule["main"].body.blocks[0].bindings[1].value
+    new_silu, new_mod = dataflow_single_inplace_call(TestModule, silu_call, [0])
+
+    tvm.ir.assert_structural_equal(new_mod["silu_inplace"], expected_silu)
+    assert_inplace_call(new_silu, silu_call, "silu_inplace")
+
+
+def test_insert_inplace_calls():
+ # End-to-end test of the DataflowUseInplaceCalls pass on static shapes:
+ # (1) the transformed module must be structurally equal to `Expected`, and
+ # (2) the transformed module is built for llvm, run on the Relax VM, and
+ # must produce an all-zero (2, 3) result.
+ @I.ir_module
+ class EndToEndTest:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1, 3),
dtype="float32")
+ ) -> R.Tensor((2, 3), dtype="float32"):
+ with R.dataflow():
+ z = R.add(x, y) # broadcast happens here
+ # Cannot be done in-place because x is an argument.
+ a = R.add(z, y) # this one can be done in-place
+ q = R.multiply(a, y) # broadcast again, a is eligible
+ r = R.subtract(y, y) # cannot be done in-place because y is
an argument
+ s = R.subtract(r, r) # No broadcast. Can be done in-place
+ m = R.multiply(q, s) # should give us all zeros
+ R.output(m)
+ return m
+
+ # Expected output of the pass: the first add and the first subtract stay as
+ # ordinary operators, every other binding becomes an R.call_tir_inplace
+ # that targets one of the PrimFuncs below.
+ @I.ir_module
+ class Expected:
+ # Broadcasting add that writes the result back into A (the first argument).
+ @T.prim_func(private=True)
+ def add_inplace(
+ A: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+ B: T.Buffer((T.int64(1), T.int64(3)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[v_ax0, v_ax1], B[T.int64(0), v_ax1])
+ T.writes(A[v_ax0, v_ax1])
+ A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[T.int64(0), v_ax1]
+
+ # Broadcasting multiply that writes the result back into A. It is called
+ # twice in `main` below (for `q` and for `m`).
+ @T.prim_func(private=True)
+ def multiply_inplace(
+ A: T.Buffer((T.int64(2), T.int64(3)), "float32"),
+ B: T.Buffer((T.int64(1), T.int64(3)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[v_ax0, v_ax1], B[T.int64(0), v_ax1])
+ T.writes(A[v_ax0, v_ax1])
+ A[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[T.int64(0), v_ax1]
+
+ # Unlike add/multiply above, this subtract writes into B (the second
+ # argument); the matching call in `main` uses inplace_indices=[1].
+ @T.prim_func(private=True)
+ def subtract_inplace(
+ A: T.Buffer((T.int64(1), T.int64(3)), "float32"),
+ B: T.Buffer((T.int64(1), T.int64(3)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for ax0, ax1 in T.grid(T.int64(1), T.int64(3)):
+ with T.block("T_subtract"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
+ T.writes(B[v_ax0, v_ax1])
+ B[v_ax0, v_ax1] = A[v_ax0, v_ax1] - B[v_ax0, v_ax1]
+
+ @R.function
+ def main(
+ x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1, 3),
dtype="float32")
+ ) -> R.Tensor((2, 3), dtype="float32"):
+ cls = Expected
+ with R.dataflow():
+ z: R.Tensor((2, 3), dtype="float32") = R.add(x, y)
+ a: R.Tensor((2, 3), dtype="float32") = R.call_tir_inplace(
+ cls.add_inplace,
+ (z, y),
+ inplace_indices=[0],
+ out_sinfo=[
+ R.Tensor((2, 3), dtype="float32"),
+ ],
+ )
+ q: R.Tensor((2, 3), dtype="float32") = R.call_tir_inplace(
+ cls.multiply_inplace,
+ (a, y),
+ inplace_indices=[0],
+ out_sinfo=[
+ R.Tensor((2, 3), dtype="float32"),
+ ],
+ )
+ r: R.Tensor((1, 3), dtype="float32") = R.subtract(y, y)
+ s: R.Tensor((1, 3), dtype="float32") = R.call_tir_inplace(
+ cls.subtract_inplace,
+ (r, r),
+ inplace_indices=[1],
+ out_sinfo=[
+ R.Tensor((1, 3), dtype="float32"),
+ ],
+ )
+ m: R.Tensor((2, 3), dtype="float32") = R.call_tir_inplace(
+ cls.multiply_inplace,
+ (q, s),
+ inplace_indices=[0],
+ out_sinfo=[
+ R.Tensor((2, 3), dtype="float32"),
+ ],
+ )
+ R.output(m)
+ return m
+
+ # Part 1: structural check of the transformed module.
+ transform_pass = DataflowUseInplaceCalls()
+ new_mod = transform_pass(EndToEndTest)
+ tvm.ir.assert_structural_equal(new_mod, Expected)
+
+ # Part 2: numerical check. Inputs are random; the final multiply involves
+ # s = r - r, so the expected output is all zeros regardless of x and y.
+ x = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+ y = tvm.nd.array(np.random.rand(1, 3).astype("float32"))
+ expected = np.zeros((2, 3), dtype="float32")
+
+ target = tvm.target.Target("llvm")
+ ex = relax.build(new_mod, target)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ res = vm["main"](x, y)
+ assert (expected == res.numpy()).all()
+
+
+def test_dynamic():
+ # Same structure as the static end-to-end test, but both inputs have the
+ # symbolic shape ("a", "b"). The second add and the subtract are expected
+ # to be rewritten to R.call_tir_inplace; the first add is not.
+ @I.ir_module
+ class DynamicTestCase:
+ @R.function
+ def main(
+ x: R.Tensor(("a", "b"), dtype="float32"), y: R.Tensor(("a", "b"),
dtype="float32")
+ ) -> R.Tensor(("a", "b"), dtype="float32"):
+ with R.dataflow():
+ z = R.add(x, y)
+ # Cannot be done in-place because x and y are arguments
+ a = R.add(z, y) # this one can be done in-place
+ s = R.subtract(a, a) # No broadcast. Can be done in-place
+ R.output(s)
+ return s
+
+ # the result should be all zeroes
+ transform_pass = DataflowUseInplaceCalls()
+ new_mod = transform_pass(DynamicTestCase)
+
+ @I.ir_module
+ class Expected:
+ # Elementwise add over an (a, b) buffer that writes the result into A.
+ @T.prim_func(private=True)
+ def add_inplace(var_A: T.handle, var_B: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ a, b = T.int64(), T.int64()
+ A = T.match_buffer(var_A, (a, b))
+ B = T.match_buffer(var_B, (a, b))
+ for ax0, ax1 in T.grid(a, b):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
+ T.writes(A[v_ax0, v_ax1])
+ A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]
+
+ # Elementwise subtract that writes the result into B (the second
+ # argument); the matching call in `main` uses inplace_indices=[1].
+ @T.prim_func(private=True)
+ def subtract_inplace(var_A: T.handle, var_B: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ a, b = T.int64(), T.int64()
+ A = T.match_buffer(var_A, (a, b))
+ B = T.match_buffer(var_B, (a, b))
+ for ax0, ax1 in T.grid(a, b):
+ with T.block("T_subtract"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
+ T.writes(B[v_ax0, v_ax1])
+ B[v_ax0, v_ax1] = A[v_ax0, v_ax1] - B[v_ax0, v_ax1]
+
+ @R.function
+ def main(
+ x: R.Tensor(("a", "b"), dtype="float32"), y: R.Tensor(("a", "b"),
dtype="float32")
+ ) -> R.Tensor(("a", "b"), dtype="float32"):
+ a = T.int64()
+ b = T.int64()
+ cls = Expected
+ with R.dataflow():
+ z = R.add(x, y)
+ # The Relax variable is named a_1 here; the name `a` is already
+ # bound to the symbolic dimension declared above.
+ a_1 = R.call_tir_inplace(
+ cls.add_inplace,
+ (z, y),
+ out_sinfo=R.Tensor((a, b), dtype="float32"),
+ inplace_indices=[0],
+ )
+ s = R.call_tir_inplace(
+ cls.subtract_inplace,
+ (a_1, a_1),
+ out_sinfo=R.Tensor((a, b), dtype="float32"),
+ inplace_indices=[1],
+ )
+ R.output(s)
+ return s
+
+ # NOTE(review): map_free_vars=True is passed here, unlike in the static
+ # test; presumably needed so the symbolic shape variables of the two
+ # modules can be matched up -- confirm.
+ tvm.ir.assert_structural_equal(new_mod, Expected, map_free_vars=True)
+ # Numerical check with a concrete (2, 3) shape bound to (a, b).
+ x = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+ y = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+ expected = np.zeros((2, 3), dtype="float32")
+
+ target = tvm.target.Target("llvm")
+ ex = relax.build(new_mod, target)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ res = vm["main"](x, y)
+ assert (expected == res.numpy()).all()
+
+
+def test_dynamic_mismatch():
+ # Negative test: x has symbolic shape (a, b) and y has (c, d), i.e. four
+ # distinct symbolic dimensions. The pass output is compared against the
+ # unmodified input module.
+ # cannot statically prove the shapes to be equal so the module should be
unchanged
+ @I.ir_module
+ class DynamicMistmatchTestCase:
+ @R.function
+ def main(
+ x: R.Tensor(("a", "b"), dtype="float32"), y: R.Tensor(("c", "d"),
dtype="float32")
+ ):
+ with R.dataflow():
+ z = R.add(x, y)
+ # Cannot be done in-place because x and y are arguments
+ a = R.add(z, y) # cannot conclude that shapes match
+ R.output(a)
+ return a
+
+ # The transformed module must be structurally equal to the input module.
+ transform_pass = DataflowUseInplaceCalls()
+ new_mod = transform_pass(DynamicMistmatchTestCase)
+ tvm.ir.assert_structural_equal(new_mod, DynamicMistmatchTestCase)
+
+
+if __name__ == "__main__":
+ testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index ce6fd8e042..3ef75b4b49 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -986,6 +986,42 @@ def test_call_tir_with_grad():
_check(Module)
+def test_call_tir_inplace():
+ # Parser test for R.call_tir_inplace: defines a module whose `main` calls a
+ # PrimFunc through R.call_tir_inplace and hands the module to `_check`.
+ # NOTE(review): `_check` is defined outside this hunk; presumably it
+ # round-trips the module through the TVMScript printer/parser -- confirm.
+ @tvm.script.ir_module
+ class Module:
+ @T.prim_func
+ def copy(
+ A: T.Buffer((2, 3), "int32"),
+ B: T.Buffer((2, 3), "int32"),
+ out1: T.Buffer((2, 3), "int32"),
+ ):
+ # copies the contents of B into A and out1
+ T.func_attr({"tir.noalias": True})
+ for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("T_zeros"):
+ ax0, ax1 = T.axis.remap("SS", [i0, i1])
+ T.reads(B[ax0, ax1])
+ T.writes(A[ax0, ax1], out1[ax0, ax1])
+ A[ax0, ax1] = B[ax0, ax1]
+ out1[ax0, ax1] = B[ax0, ax1]
+
+ @R.function
+ def main(
+ x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")
+ ) -> R.Tuple(
+ R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2,
3), "int32")
+ ):
+ # Positional arguments: callee, argument tuple, in-place indices,
+ # then the output struct info list (two outputs).
+ # NOTE(review): presumably index 0 means the first output reuses
+ # argument 0 (x) and -1 means the second output does not reuse any
+ # argument -- confirm against the call_tir_inplace documentation.
+ res = R.call_tir_inplace(
+ Module.copy,
+ (x, y),
+ [0, -1],
+ [R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")],
+ )
+ return res
+
+ _check(Module)
+
+
def test_local_function():
@R.function
def main(
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py
b/tests/python/relax/test_tvmscript_printer_relax.py
index dc3334f216..530e45e610 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -399,6 +399,31 @@ R.call_tir_with_grad(tir_func, (v0,),
out_sinfo=R.Tensor((54, 96), dtype="float3
)
+def test_call_tir_inplace():
+ # Printer test for relax.call_tir_inplace: builds a call node with two
+ # tensor arguments, two outputs and one TIR variable, then compares its
+ # printed TVMScript form with the expected text via `_assert_print`.
+ x = relax.Var("x", R.Tensor((32, 32), dtype="int32"))
+ y = relax.Var("y", R.Tensor((32, 32), dtype="int32"))
+ t = tir.Var("t", dtype="int64")
+ call = relax.call_tir_inplace(
+ relax.GlobalVar("tir_func"),
+ (
+ x,
+ y,
+ ),
+ inplace_indices=[-1, 0],
+ out_sinfo=[R.Tensor((32, 32), dtype="int32"), R.Tensor((32, 32),
dtype="int32")],
+ tir_vars=[t],
+ )
+ # In the expected text, out_sinfo is printed before inplace_indices (the
+ # reverse of the keyword order used above) and tir_vars appears as
+ # R.shape([t]).
+ _assert_print(
+ call,
+ """
+x: R.Tensor((32, 32), dtype="int32")
+y: R.Tensor((32, 32), dtype="int32")
+t = T.int64()
+R.call_tir_inplace(tir_func, (x, y), out_sinfo=[R.Tensor((32, 32),
dtype="int32"), R.Tensor((32, 32), dtype="int32")], inplace_indices=[-1, 0],
tir_vars=R.shape([t]))
+ """,
+ )
+
+
def test_seq_expr():
x = tir.Var("x", "int64")
a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32"))