This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new dd57556660 [Unity][Op] Introduce `call_inplace_packed` as a
counterpart to `call_tir_inplace` (#15878)
dd57556660 is described below
commit dd575566606830925995875e8fed6d2bb2ac92c3
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Mon Oct 9 09:35:26 2023 -0400
[Unity][Op] Introduce `call_inplace_packed` as a counterpart to
`call_tir_inplace` (#15878)
* Add call_inplace_packed operator
* Whitespace
---
include/tvm/relax/attrs/op.h | 14 +++
python/tvm/relax/op/__init__.py | 1 +
python/tvm/relax/op/base.py | 66 ++++++++++++++
python/tvm/script/ir_builder/relax/ir.py | 2 +
src/relax/op/op.cc | 119 ++++++++++++++++++++++++++
src/relax/transform/remove_purity_checking.cc | 7 ++
tests/python/relax/test_relax_operators.py | 87 ++++++++++++++++++-
7 files changed, 295 insertions(+), 1 deletion(-)
diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h
index e419a42111..8e3e9d92f5 100644
--- a/include/tvm/relax/attrs/op.h
+++ b/include/tvm/relax/attrs/op.h
@@ -57,6 +57,20 @@ struct CallTIRInplaceAttrs : public
tvm::AttrsNode<CallTIRInplaceAttrs> {
}
}; // struct CallTIRInplaceAttrs
+/*!
+ * \brief Attributes used in call_inplace_packed.
+ *
+ * Holds the mapping from outputs of the call to the inputs that are reused to store them.
+ */
+struct CallInplacePackedAttrs : public tvm::AttrsNode<CallInplacePackedAttrs> {
+ /*!
+  * \brief Per-output in-place index: if member `i` is `k` >= 0, input `k` stores output `i`;
+  * if member `i` is -1, output `i` is newly allocated.
+  */
+ Array<Integer> inplace_indices;
+
+ TVM_DECLARE_ATTRS(CallInplacePackedAttrs,
"relax.attrs.CallInplacePackedAttrs") {
+ TVM_ATTR_FIELD(inplace_indices)
+ .describe(
+ "Indices that describe which input corresponds to which output. If
the `i`th member "
+ "has the value `k` >= 0, then that means that input `k` should be
used to store the "
+ "`i`th output. If an element has the value -1, that means the
output will be newly "
+ "allocated.");
+ }
+}; // struct CallInplacePackedAttrs
+
/*! \brief Attributes used in to_vdevice */
struct ToVDeviceAttrs : public tvm::AttrsNode<ToVDeviceAttrs> {
VDevice dst_vdevice;
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 77f1d0ff44..60a4332d83 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -25,6 +25,7 @@ from .base import (
assert_op,
call_builtin_with_ctx,
call_dps_packed,
+ call_inplace_packed,
call_pure_packed,
call_tir,
call_tir_inplace,
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index 603148d0cf..b363dc6952 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -604,6 +604,72 @@ def shape_to_tensor(expr: Expr) -> Expr:
return _ffi_api.shape_to_tensor(expr) # type: ignore # pylint:
disable=no-member
+@args_converter.auto
+def call_inplace_packed(
+    func: Union[str, ExternFunc, GlobalVar],
+    *args: Expr,
+    inplace_indices: Union[int, List[int]],
+    sinfo_args: Union[StructInfo, List[StructInfo]],
+) -> Expr:
+    """
+    Construct a call to a packed function that consumes some of its arguments "in-place"
+    and returns the mutated arguments (aliased), but should be considered to be otherwise pure.
+    The `inplace_indices` argument indicates which of the outputs are mutated arguments.
+
+    The resulting call will have the same semantics as calling the packed function directly.
+
+    Note: This should be used for cases when the user knows that calling the packed function
+    with these arguments will **in reality** not cause any other side effects.
+    If it is used for a call that **does** result in other side effects, then the compiler
+    may end up removing, reordering, or repeating that call, with no guarantees
+    made about any side effects from the callee.
+
+    Warning: This operator is treated as pure by the type system even though it *is* performing
+    side effects (mutating some arguments). It is therefore incumbent upon the user to ensure
+    that it is being used safely (viz., that mutated arguments are not live after the mutation,
+    that they do not alias values live after the mutation).
+
+    Parameters
+    ----------
+    func : Union[str, ExternFunc, GlobalVar]
+        The name (global symbol) for a PackedFunc or an ExternFunc node.
+        NOTE(review): a GlobalVar is accepted by the annotation but is handed to
+        `ExternFunc(...)` unchanged -- confirm that this is supported.
+
+    args: Expr
+        The arguments for the PackedFunc.
+
+    inplace_indices : Union[int, List[int]]
+        Specify which arguments should be used for in-place computations.
+        If `inplace_indices` is a single integer, it will be made into a singleton list.
+        Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
+        will be an alias of `args[j]`.
+        If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
+        At least one member of `inplace_indices` must not be -1.
+
+    sinfo_args: Union[StructInfo, List[StructInfo]]
+        The list of structure info arguments (giving the structural info for the returned value).
+
+    Returns
+    -------
+    result : Expr
+        A Relax call to the `call_inplace_packed` operator, built from
+        `ExternFunc(func)`, `args`, `inplace_indices` and `sinfo_args`.
+
+    Raises
+    ------
+    ValueError
+        If `sinfo_args` is None.
+    """
+    if isinstance(func, ExternFunc):
+        func = func.global_symbol
+
+    op = ExternFunc(func)
+    if sinfo_args is None:
+        raise ValueError("R.call_inplace_packed is required to have sinfo_args")
+    if isinstance(sinfo_args, tuple):  # type: ignore
+        sinfo_args = list(sinfo_args)
+    elif not isinstance(sinfo_args, list):
+        sinfo_args = [sinfo_args]
+    # Normalize the indices the same way as sinfo_args: a tuple becomes a list,
+    # a bare integer becomes a singleton list.
+    if isinstance(inplace_indices, tuple):
+        inplace_indices = list(inplace_indices)
+    elif not isinstance(inplace_indices, list):
+        inplace_indices = [inplace_indices]
+
+    return _ffi_api.call_inplace_packed(  # type: ignore # pylint: disable=no-member
+        op, args, inplace_indices, sinfo_args
+    )
+
+
@args_converter.auto
def call_pure_packed(
func: Union[str, ExternFunc, GlobalVar],
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 2166827f48..142d0e6d96 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -52,6 +52,7 @@ from tvm.relax.op import (
broadcast_to,
builtin,
call_builtin_with_ctx,
+ call_inplace_packed,
call_pure_packed,
call_tir,
call_tir_inplace,
@@ -650,6 +651,7 @@ __all__ = [
"bitwise_xor",
"broadcast_to",
"builtin",
+ "call_inplace_packed",
"call_packed",
"call_pure_packed",
"call_tir",
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 326a219ed5..01d0d04be0 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -121,6 +121,125 @@ Expr MakeCallPurePacked(const Expr& callee, Array<Expr>
args, const Attrs& attrs
TVM_REGISTER_GLOBAL("relax.op.call_pure_packed").set_body_typed(MakeCallPurePacked);
+// call_inplace_packed
+
+/*!
+ * \brief Infer the StructInfo of a call to relax.call_inplace_packed.
+ *
+ * Validates that the callee (args[0]) is an opaque function, that every in-place index is
+ * either -1 or a unique valid position among the callee's arguments (with at least one index
+ * other than -1), and that the derived return StructInfo is a base of the StructInfo of each
+ * argument designated as an in-place output.
+ *
+ * \param call The call_inplace_packed call; args[0] is the callee, the rest are its arguments.
+ * \param ctx The block builder used to report diagnostics and to obtain the analyzer.
+ * \return The return StructInfo derived from the callee's function StructInfo.
+ */
+StructInfo InferStructInfoCallInplacePacked(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() <= 1) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "call_inplace_packed must be called with at least two arguments"
+        << " (the packed call and at least one argument to the packed call;"
+        << " if the packed call does not need arguments, use call_pure_packed instead)");
+  }
+
+  // the callee must be an opaque function
+  auto callee = call->args[0];
+  ICHECK(!callee.as<OpNode>()) << "call_inplace_packed cannot be used with an op node";
+  auto opt = MatchStructInfo<FuncStructInfo>(callee);
+  ICHECK(opt) << "Callee must have a function struct info";
+  FuncStructInfo finfo = opt.value();
+  ICHECK(finfo->IsOpaque()) << "call_inplace_packed must be called with an opaque function, but "
+                            << callee << " is not opaque";
+
+  // check the range for inplace indices, make sure at least one is not -1, ensure they're unique
+  const auto* attrs = call->attrs.as<CallInplacePackedAttrs>();
+  ICHECK(attrs) << "call_inplace_packed must be constructed with CallInplacePackedAttrs";
+  size_t num_args = call->args.size() - 1;
+  std::unordered_set<int> encountered;
+  for (size_t i = 0; i < attrs->inplace_indices.size(); i++) {
+    int index = attrs->inplace_indices[i].IntValue();
+    if (index < -1 || index >= static_cast<int>(num_args)) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "In-place index " << i << " is out of range (must be between -1 and "
+                       << (num_args - 1) << ", inclusive, but is " << index << ")");
+    }
+    if (index != -1) {
+      if (encountered.count(index)) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "All in-place indices must be unique, but index " << index
+                         << " appears more than once.");
+      }
+      encountered.insert(index);
+    }
+  }
+  if (encountered.empty()) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "At least one index must have a value other than "
+                                                "-1 (or else simply use call_pure_packed)");
+  }
+
+  // same logic as from DeriveCallRetStructInfo for ordinary calls
+  StructInfo ret;
+  if (finfo->derive_func.defined()) {
+    // derive using custom derivation function.
+    ret = finfo->derive_func.value()(call, ctx);
+  } else {
+    // directly return the normal value.
+    ret = finfo->ret;
+  }
+
+  // make sure that the derived return struct info matches that of the in-place args
+  // (note: arg 0 is the packed func, so we add 1 to the arg index)
+  if (attrs->inplace_indices.size() == 1) {
+    // a single index cannot be -1 here: `encountered` was checked to be non-empty above
+    auto arg_idx = attrs->inplace_indices[0].IntValue() + 1;
+    auto arg_sinfo = GetStructInfo(call->args[arg_idx]);
+    if (!IsBaseOf(ret, arg_sinfo, ctx->GetAnalyzer())) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "The derived return StructInfo does not match that for "
+                       << "the in-place argument at index " << (arg_idx - 1) << ": " << ret
+                       << " vs " << arg_sinfo);
+    }
+  } else {
+    auto* tup_info = ret.as<TupleStructInfoNode>();
+    if (!tup_info) {
+      ctx->ReportFatal(Diagnostic::Error(call) << "Multiple outputs given via the inplace indices "
+                                                  "but the derived StructInfo is not a tuple");
+    }
+    // each in-place index is matched against the tuple field at the same position, so the
+    // tuple must have at least as many fields as there are indices
+    if (attrs->inplace_indices.size() > tup_info->fields.size()) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "The inplace indices describe " << attrs->inplace_indices.size()
+                       << " outputs but the derived tuple StructInfo has only "
+                       << tup_info->fields.size() << " fields");
+    }
+    for (size_t i = 0; i < attrs->inplace_indices.size(); i++) {
+      if (attrs->inplace_indices[i].IntValue() == -1) {
+        continue;
+      }
+      auto arg_idx = attrs->inplace_indices[i].IntValue() + 1;
+      auto arg_sinfo = GetStructInfo(call->args[arg_idx]);
+      auto ret_sinfo = tup_info->fields[i];
+      if (!IsBaseOf(ret_sinfo, arg_sinfo, ctx->GetAnalyzer())) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "The derived return StructInfo does not match that for "
+                         << "the in-place argument at index " << (arg_idx - 1) << ": " << ret_sinfo
+                         << " vs " << arg_sinfo);
+      }
+    }
+  }
+
+  return ret;
+}
+
+TVM_REGISTER_NODE_TYPE(CallInplacePackedAttrs);
+
+// Registers the relax.call_inplace_packed operator: it carries CallInplacePackedAttrs, uses
+// InferStructInfoCallInplacePacked for StructInfo inference, and is marked as pure.
+RELAY_REGISTER_OP("relax.call_inplace_packed")
+ .set_num_inputs(-1)
+ .set_attrs_type<CallInplacePackedAttrs>()
+ .add_argument("args", "Array<Expr>",
+ "The first argument is the function being called. The rest
are the "
+ "arguments to that function.")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoCallInplacePacked)
+ // Warning: considered pure, but it has the potential to create visible effects!
+ // This should only be used if it has been *checked* that it is safe (no aliases,
+ // in-place arguments will no longer be live) and the user believes the packed func
+ // to have no side effects other than modifying the arguments specified as "inplace".
+ .set_attr<Bool>("FPurity", Bool(true));
+
+/*!
+ * \brief Build a call to relax.call_inplace_packed.
+ *
+ * \param func The packed function being called; it becomes the first argument of the call.
+ * \param args The arguments passed on to `func`.
+ * \param inplace_indices The in-place indices stored in the call's CallInplacePackedAttrs.
+ * \param sinfo_args The StructInfo arguments attached to the call.
+ * \return The constructed Call node.
+ */
+Expr MakeCallInplacePacked(Expr func, Array<Expr> args, Array<Integer> inplace_indices,
+                           Array<StructInfo> sinfo_args) {
+  static const Op& inplace_packed_op = Op::Get("relax.call_inplace_packed");
+
+  auto attrs_node = make_object<CallInplacePackedAttrs>();
+  attrs_node->inplace_indices = Array<Integer>(inplace_indices.begin(), inplace_indices.end());
+
+  // the callee goes first, followed by the arguments it will receive
+  Array<Expr> full_args;
+  full_args.push_back(func);
+  for (const Expr& arg : args) {
+    full_args.push_back(arg);
+  }
+  return Call(inplace_packed_op, full_args, Attrs(attrs_node), sinfo_args);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInplacePacked);
+
// call_tir
StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) {
diff --git a/src/relax/transform/remove_purity_checking.cc
b/src/relax/transform/remove_purity_checking.cc
index f190968d9d..7e3a31d0ba 100644
--- a/src/relax/transform/remove_purity_checking.cc
+++ b/src/relax/transform/remove_purity_checking.cc
@@ -52,6 +52,12 @@ class PurityRemover : public ExprMutator {
call->attrs, call->sinfo_args);
return VisitExpr(ret);
}
+ if (call->op == call_inplace_packed_op_) {
+ // call_inplace_packed has its own attrs so we don't pass those down
+ auto ret = Call(call->args[0], Array<Expr>(call->args.begin() + 1,
call->args.end()),
+ tvm::Attrs(), call->sinfo_args);
+ return VisitExpr(ret);
+ }
if (call->op == invoke_pure_closure_op_) {
auto ret = Call(invoke_closure_op_, call->args, call->attrs,
call->sinfo_args);
return VisitExpr(ret);
@@ -66,6 +72,7 @@ class PurityRemover : public ExprMutator {
private:
const Op& call_pure_packed_op_ = Op::Get("relax.call_pure_packed");
+ const Op& call_inplace_packed_op_ = Op::Get("relax.call_inplace_packed");
const Op& invoke_pure_closure_op_ = Op::Get("relax.invoke_pure_closure");
const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
};
diff --git a/tests/python/relax/test_relax_operators.py
b/tests/python/relax/test_relax_operators.py
index 632ac96ff4..a278b09167 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -39,7 +39,9 @@ def run_cpu(mod, func_name, *input):
target = tvm.target.Target("llvm")
ex = relax.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
- return vm[func_name](*input)
+ vm.set_input(func_name, *input)
+ vm.invoke_stateful(func_name)
+ return vm.get_outputs(func_name)
def test_unique():
@@ -248,6 +250,89 @@ def test_op_call_pure_packed():
assert (copy_found.numpy() == arr).all()
+def test_op_call_inplace_packed():
+    """Check R.call_inplace_packed with one in-place output and with a tuple of outputs."""
+
+    # NOTE(review): this module is only defined (so the script is parsed); it is never
+    # built or run below -- confirm whether a run_cpu check was intended here.
+    @tvm.script.ir_module
+    class CallInplaceTest:
+        @R.function
+        def pure_copy(x: R.Tensor((3, 4), "float32")):
+            z = R.call_inplace_packed(
+                "vm.builtin.copy",
+                x,
+                inplace_indices=0,
+                sinfo_args=(R.Tensor((3, 4), dtype="float32")),
+            )
+            return z
+
+    @tvm.register_func("test.inplace.add")
+    def inplace_add(a, b):
+        # compute the element-wise sum on the host and write it back into `a`
+        a.copyfrom(a.numpy() + b.numpy())
+        return a
+
+    @tvm.script.ir_module
+    class CallInplaceAddTest:
+        @R.function
+        def inplace_add(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
+            z = R.call_inplace_packed(
+                "test.inplace.add",
+                x,
+                y,
+                inplace_indices=0,
+                sinfo_args=(R.Tensor((3, 4), dtype="float32")),
+            )
+            return z
+
+    np.random.seed(1)  # to avoid flakiness
+    arr_a = np.random.rand(3, 4).astype("float32")
+    arr_b = np.random.rand(3, 4).astype("float32")
+    expected = arr_a + arr_b
+    tvm_arr_a = tvm.nd.array(arr_a)
+    result = run_cpu(CallInplaceAddTest, "inplace_add", tvm_arr_a, tvm.nd.array(arr_b))
+    # the result must be the first input itself, now holding the sum
+    assert result == tvm_arr_a
+    assert (result.numpy() == expected).all()
+
+    @tvm.register_func("test.inplace.tuple_add")
+    def inplace_tuple_add(a, b):
+        # first output: `a` overwritten with the sum; second output: a new array with the sum
+        total = a.numpy() + b.numpy()
+        c = tvm.nd.array(total)
+        a.copyfrom(total)
+        return tvm.runtime.container.ADT(0, [a, c])
+
+    @tvm.script.ir_module
+    class CallInplaceTuple:
+        @R.function
+        def inplace_tuple(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
+            z = R.call_inplace_packed(
+                "test.inplace.tuple_add",
+                x,
+                y,
+                inplace_indices=[0, -1],
+                sinfo_args=(R.Tensor((3, 4), dtype="float32"), R.Tensor((3, 4), dtype="float32")),
+            )
+            return z
+
+    np.random.seed(2)  # to avoid flakiness
+    arr_a = np.random.rand(3, 4).astype("float32")
+    arr_b = np.random.rand(3, 4).astype("float32")
+    expected = arr_a + arr_b
+    tvm_arr_a = tvm.nd.array(arr_a)
+    tvm_arr_b = tvm.nd.array(arr_b)
+    result = run_cpu(CallInplaceTuple, "inplace_tuple", tvm_arr_a, tvm_arr_b)
+    assert result[0] == tvm_arr_a
+    assert (result[0].numpy() == expected).all()
+    # the second output (index -1) must not be either of the inputs
+    assert result[1] != tvm_arr_a and result[1] != tvm_arr_b
+    assert (result[1].numpy() == expected).all()
+
+
def test_op_to_device():
@tvm.script.ir_module
class CallToDevice: