This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new dd57556660 [Unity][Op] Introduce `call_inplace_packed` as a 
counterpart to `call_tir_inplace` (#15878)
dd57556660 is described below

commit dd575566606830925995875e8fed6d2bb2ac92c3
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Mon Oct 9 09:35:26 2023 -0400

    [Unity][Op] Introduce `call_inplace_packed` as a counterpart to 
`call_tir_inplace` (#15878)
    
    * Add call_inplace_packed operator
    
    * Whitespace
---
 include/tvm/relax/attrs/op.h                  |  14 +++
 python/tvm/relax/op/__init__.py               |   1 +
 python/tvm/relax/op/base.py                   |  66 ++++++++++++++
 python/tvm/script/ir_builder/relax/ir.py      |   2 +
 src/relax/op/op.cc                            | 119 ++++++++++++++++++++++++++
 src/relax/transform/remove_purity_checking.cc |   7 ++
 tests/python/relax/test_relax_operators.py    |  87 ++++++++++++++++++-
 7 files changed, 295 insertions(+), 1 deletion(-)

diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h
index e419a42111..8e3e9d92f5 100644
--- a/include/tvm/relax/attrs/op.h
+++ b/include/tvm/relax/attrs/op.h
@@ -57,6 +57,20 @@ struct CallTIRInplaceAttrs : public 
tvm::AttrsNode<CallTIRInplaceAttrs> {
   }
 };  // struct CallTIRInplaceAttrs
 
+/*! \brief Attributes used in call_inplace_packed */
+struct CallInplacePackedAttrs : public tvm::AttrsNode<CallInplacePackedAttrs> {
+  Array<Integer> inplace_indices;
+
+  TVM_DECLARE_ATTRS(CallInplacePackedAttrs, 
"relax.attrs.CallInplacePackedAttrs") {
+    TVM_ATTR_FIELD(inplace_indices)
+        .describe(
+            "Indices that describe which input corresponds to which output. If 
the `i`th member "
+            "has the value `k` >= 0, then that means that input `k` should be 
used to store the "
+            "`i`th output. If an element has the value -1, that means the 
output will be newly "
+            "allocated.");
+  }
+};  // struct CallInplacePackedAttrs
+
 /*! \brief Attributes used in to_vdevice */
 struct ToVDeviceAttrs : public tvm::AttrsNode<ToVDeviceAttrs> {
   VDevice dst_vdevice;
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 77f1d0ff44..60a4332d83 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -25,6 +25,7 @@ from .base import (
     assert_op,
     call_builtin_with_ctx,
     call_dps_packed,
+    call_inplace_packed,
     call_pure_packed,
     call_tir,
     call_tir_inplace,
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index 603148d0cf..b363dc6952 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -604,6 +604,72 @@ def shape_to_tensor(expr: Expr) -> Expr:
     return _ffi_api.shape_to_tensor(expr)  # type: ignore # pylint: 
disable=no-member
 
 
+@args_converter.auto
+def call_inplace_packed(
+    func: Union[str, ExternFunc, GlobalVar],
+    *args: Expr,
+    inplace_indices: Union[int, List[int]],
+    sinfo_args: Union[StructInfo, List[StructInfo]],
+) -> Expr:
+    """
+    Construct a call to a packed function that consumes some of its arguments 
"in-place"
+    and returns the mutated arguments (aliased), but should be considered to 
be otherwise pure.
+    The `inplace_indices` argument indicates which of the outputs are mutated 
arguments.
+
+    The resulting call will have the same semantics as calling the packed 
function directly.
+
+    Note: This should be used for cases when the user knows that calling the 
packed function
+    with these arguments will **in reality** not cause any other side effects.
+    If it is used for a call that **does** result in other side effects, then 
the compiler
+    may end up removing, reordering, or repeating that call, with no guarantees
+    made about any side effects from the callee.
+
+    Warning: This operator is treated as pure by the type system even though 
it *is* performing
+    side effects (mutating some arguments). It is therefore incumbent upon the 
user to ensure
+    that it is being used safely (viz., that mutated arguments are not live 
after the mutation,
+    that they do not alias values live after the mutation).
+
+    Parameters
+    ----------
+    func : Union[str, ExternFunc, GlobalVar]
+      The name (global symbol) for a PackedFunc or an ExternFunc node.
+
+    args: Expr
+      The arguments for the PackedFunc.
+
+    inplace_indices : Union[int, List[int]]
+      Specify which arguments should be used for in-place computations.
+      If `inplace_indices` is a single integer, it will be made into a singleton 
list.
+      Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
+      will be an alias of `args[j]`.
+      If `inplace_indices[i] = -1`, then the `i`th output will be a freshly 
allocated tensor.
+      At least one member of `inplace_indices` must not be -1.
+
+    sinfo_args: Union[StructInfo, List[StructInfo]]
+        The list of structure info arguments (giving the structural info for 
the returned value).
+
+    Returns
+    -------
+    result : Expr
+      A Relax call, corresponding to
+      `call_inplace_packed(ExternFunc(func), args, DictAttrs(kwargs), sinfo_args)`
+    """
+    if isinstance(func, ExternFunc):
+        func = func.global_symbol
+
+    op = ExternFunc(func)
+    if sinfo_args is None:
+        raise ValueError("R.call_inplace_packed is required to have sinfo_args")
+    if isinstance(sinfo_args, tuple):  # type: ignore
+        sinfo_args = list(sinfo_args)
+    elif not isinstance(sinfo_args, list):
+        sinfo_args = [sinfo_args]
+    if not isinstance(inplace_indices, list):
+        inplace_indices = [inplace_indices]
+
+    return _ffi_api.call_inplace_packed(op, args, inplace_indices, sinfo_args) 
 # type: ignore # pylint: disable=no-member
+
+
 @args_converter.auto
 def call_pure_packed(
     func: Union[str, ExternFunc, GlobalVar],
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index 2166827f48..142d0e6d96 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -52,6 +52,7 @@ from tvm.relax.op import (
     broadcast_to,
     builtin,
     call_builtin_with_ctx,
+    call_inplace_packed,
     call_pure_packed,
     call_tir,
     call_tir_inplace,
@@ -650,6 +651,7 @@ __all__ = [
     "bitwise_xor",
     "broadcast_to",
     "builtin",
+    "call_inplace_packed",
     "call_packed",
     "call_pure_packed",
     "call_tir",
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 326a219ed5..01d0d04be0 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -121,6 +121,125 @@ Expr MakeCallPurePacked(const Expr& callee, Array<Expr> 
args, const Attrs& attrs
 
 
TVM_REGISTER_GLOBAL("relax.op.call_pure_packed").set_body_typed(MakeCallPurePacked);
 
+// call_inplace_packed
+
+StructInfo InferStructInfoCallInplacePacked(const Call& call, const 
BlockBuilder& ctx) {
+  if (call->args.size() <= 1) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "call_inplace_packed must be called with at least two arguments"
+        << " (the packed call and at least one argument to the packed call;"
+        << " if the packed call does not need arguments, use call_pure_packed 
instead)");
+  }
+
+  // the callee must be an opaque function
+  auto callee = call->args[0];
+  ICHECK(!callee.as<OpNode>()) << "call_inplace_packed cannot be used with an op 
node";
+  auto opt = MatchStructInfo<FuncStructInfo>(callee);
+  ICHECK(opt) << "Callee must have a function struct info";
+  FuncStructInfo finfo = opt.value();
+  ICHECK(finfo->IsOpaque()) << "call_inplace_packed must be called with an opaque 
function, but "
+                            << callee << " is not opaque";
+
+  // check the range for inplace indices, make sure at least one is not -1, 
ensure they're unique
+  const auto* attrs = call->attrs.as<CallInplacePackedAttrs>();
+  size_t num_args = call->args.size() - 1;
+  std::unordered_set<int> encountered;
+  for (size_t i = 0; i < attrs->inplace_indices.size(); i++) {
+    int index = attrs->inplace_indices[i].IntValue();
+    if (index < -1 || index >= static_cast<int>(num_args)) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "In-place index " << i << " is out of range (must be 
between -1 and "
+                       << (num_args - 1) << ", inclusive, but is " << index << 
")");
+    }
+    if (index != -1) {
+      if (encountered.count(index)) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "All in-place indices must be unique, but index " 
<< index
+                         << " appears more than once.");
+      }
+      encountered.insert(index);
+    }
+  }
+  if (encountered.empty()) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "At least one index must have 
a value other than "
+                                                "-1 (or else simply use 
call_pure_packed)");
+  }
+
+  // same logic as from DeriveCallRetStructInfo for ordinary calls
+  StructInfo ret;
+  if (finfo->derive_func.defined()) {
+    // derive using custom derivation function.
+    ret = finfo->derive_func.value()(call, ctx);
+  } else {
+    // directly return the normal value.
+    ret = finfo->ret;
+  }
+
+  // make sure that the derived return struct info matches that of the 
in-place args
+  // (note: arg 0 is the packed func, so we add 1 to the arg index)
+  if (attrs->inplace_indices.size() == 1) {
+    auto arg_idx = attrs->inplace_indices[0].IntValue() + 1;
+    auto arg_sinfo = GetStructInfo(call->args[arg_idx]);
+    if (!IsBaseOf(ret, arg_sinfo, ctx->GetAnalyzer())) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "The derived return StructInfo does not match that 
for "
+                       << "the in-place argument at index " << (arg_idx - 1) 
<< ": " << ret
+                       << " vs " << arg_sinfo);
+    }
+  } else {
+    auto* tup_info = ret.as<TupleStructInfoNode>();
+    if (!tup_info) {
+      ctx->ReportFatal(Diagnostic::Error(call) << "Multiple outputs given via 
the inplace indices "
+                                                  "but the derived StructInfo 
is not a tuple");
+    }
+    for (size_t i = 0; i < attrs->inplace_indices.size(); i++) {
+      if (attrs->inplace_indices[i] == -1) {
+        continue;
+      }
+      auto arg_idx = attrs->inplace_indices[i].IntValue() + 1;
+      auto arg_sinfo = GetStructInfo(call->args[arg_idx]);
+      auto ret_sinfo = tup_info->fields[i];
+      if (!IsBaseOf(ret_sinfo, arg_sinfo, ctx->GetAnalyzer())) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "The derived return StructInfo does not match that 
for "
+                         << "the in-place argument at index " << (arg_idx - 1) 
<< ": " << ret_sinfo
+                         << " vs " << arg_sinfo);
+      }
+    }
+  }
+
+  return ret;
+}
+
+TVM_REGISTER_NODE_TYPE(CallInplacePackedAttrs);
+
+RELAY_REGISTER_OP("relax.call_inplace_packed")
+    .set_num_inputs(-1)
+    .set_attrs_type<CallInplacePackedAttrs>()
+    .add_argument("args", "Array<Expr>",
+                  "The first argument is the function being called. The rest 
are the "
+                  "arguments to that function.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoCallInplacePacked)
+    // Warning: considered pure, but it has the potential to create visible 
effects!
+    // This should only be used if it has been *checked* that it is safe (no 
aliases, in-place
+    // arguments will no longer be live) and the user believes the packed func 
to have no
+    // side effects other than modifying the arguments specified as "inplace"
+    .set_attr<Bool>("FPurity", Bool(true));
+
+Expr MakeCallInplacePacked(Expr func, Array<Expr> args, Array<Integer> 
inplace_indices,
+                           Array<StructInfo> sinfo_args) {
+  ObjectPtr<CallInplacePackedAttrs> attrs = 
make_object<CallInplacePackedAttrs>();
+  attrs->inplace_indices = Array<Integer>(inplace_indices.begin(), 
inplace_indices.end());
+
+  static const Op& op = Op::Get("relax.call_inplace_packed");
+  Array<Expr> call_args = {func};
+  call_args.insert(call_args.end(), args.begin(), args.end());
+  return Call(op, call_args, Attrs(attrs), sinfo_args);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInplacePacked);
+
 // call_tir
 
 StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) {
diff --git a/src/relax/transform/remove_purity_checking.cc 
b/src/relax/transform/remove_purity_checking.cc
index f190968d9d..7e3a31d0ba 100644
--- a/src/relax/transform/remove_purity_checking.cc
+++ b/src/relax/transform/remove_purity_checking.cc
@@ -52,6 +52,12 @@ class PurityRemover : public ExprMutator {
                       call->attrs, call->sinfo_args);
       return VisitExpr(ret);
     }
+    if (call->op == call_inplace_packed_op_) {
+      // call_inplace_packed has its own attrs so we don't pass those down
+      auto ret = Call(call->args[0], Array<Expr>(call->args.begin() + 1, 
call->args.end()),
+                      tvm::Attrs(), call->sinfo_args);
+      return VisitExpr(ret);
+    }
     if (call->op == invoke_pure_closure_op_) {
       auto ret = Call(invoke_closure_op_, call->args, call->attrs, 
call->sinfo_args);
       return VisitExpr(ret);
@@ -66,6 +72,7 @@ class PurityRemover : public ExprMutator {
 
  private:
   const Op& call_pure_packed_op_ = Op::Get("relax.call_pure_packed");
+  const Op& call_inplace_packed_op_ = Op::Get("relax.call_inplace_packed");
   const Op& invoke_pure_closure_op_ = Op::Get("relax.invoke_pure_closure");
   const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
 };
diff --git a/tests/python/relax/test_relax_operators.py 
b/tests/python/relax/test_relax_operators.py
index 632ac96ff4..a278b09167 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -39,7 +39,9 @@ def run_cpu(mod, func_name, *input):
     target = tvm.target.Target("llvm")
     ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, tvm.cpu())
-    return vm[func_name](*input)
+    vm.set_input(func_name, *input)
+    vm.invoke_stateful(func_name)
+    return vm.get_outputs(func_name)
 
 
 def test_unique():
@@ -248,6 +250,89 @@ def test_op_call_pure_packed():
     assert (copy_found.numpy() == arr).all()
 
 
+def test_op_call_inplace_packed():
+    # in this case we can use the same test as above
+    @tvm.script.ir_module
+    class CallInplaceTest:
+        @R.function
+        def pure_copy(x: R.Tensor((3, 4), "float32")):
+            z = R.call_inplace_packed(
+                "vm.builtin.copy",
+                x,
+                inplace_indices=0,
+                sinfo_args=(R.Tensor((3, 4), dtype="float32")),
+            )
+            return z
+
+    @tvm.register_func("test.inplace.add")
+    def inplace_add(a, b):
+        arr_a = a.numpy()
+        arr_b = b.numpy()
+        for i in range(len(arr_a)):
+            for j in range(len(arr_a[i])):
+                arr_a[i][j] = arr_a[i][j] + arr_b[i][j]
+        a.copyfrom(arr_a)
+        return a
+
+    @tvm.script.ir_module
+    class CallInplaceAddTest:
+        @R.function
+        def inplace_add(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), 
"float32")):
+            z = R.call_inplace_packed(
+                "test.inplace.add",
+                x,
+                y,
+                inplace_indices=0,
+                sinfo_args=(R.Tensor((3, 4), dtype="float32")),
+            )
+            return z
+
+    np.random.seed(1)  # to avoid flakiness
+    arr_a = np.random.rand(3, 4).astype("float32")
+    arr_b = np.random.rand(3, 4).astype("float32")
+    sum = arr_a + arr_b
+    tvm_arr_a = tvm.nd.array(arr_a)
+    result = run_cpu(CallInplaceAddTest, "inplace_add", tvm_arr_a, 
tvm.nd.array(arr_b))
+    assert result == tvm_arr_a
+    assert (result.numpy() == sum).all()
+
+    @tvm.register_func("test.inplace.tuple_add")
+    def inplace_tuple_add(a, b):
+        arr_a = a.numpy()
+        arr_b = b.numpy()
+        c = tvm.nd.array(arr_a + arr_b)
+        for i in range(len(arr_a)):
+            for j in range(len(arr_a[i])):
+                arr_a[i][j] = arr_a[i][j] + arr_b[i][j]
+        a.copyfrom(arr_a)
+        return tvm.runtime.container.ADT(0, [a, c])
+
+    @tvm.script.ir_module
+    class CallInplaceTuple:
+        @R.function
+        def inplace_tuple(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), 
"float32")):
+            z = R.call_inplace_packed(
+                "test.inplace.tuple_add",
+                x,
+                y,
+                inplace_indices=[0, -1],
+                sinfo_args=(R.Tensor((3, 4), dtype="float32"), R.Tensor((3, 
4), dtype="float32")),
+            )
+            return z
+
+    np.random.seed(2)  # to avoid flakiness
+    arr_a = np.random.rand(3, 4).astype("float32")
+    arr_b = np.random.rand(3, 4).astype("float32")
+    sum = arr_a + arr_b
+    tvm_arr_a = tvm.nd.array(arr_a)
+    tvm_arr_b = tvm.nd.array(arr_b)
+    result = run_cpu(CallInplaceTuple, "inplace_tuple", tvm_arr_a, tvm_arr_b)
+    assert result[0] == tvm_arr_a
+    assert (result[0].numpy() == sum).all()
+    assert result[1] != tvm_arr_a and result[1] != tvm_arr_b
+    assert (result[1].numpy() == sum).all()
+
+
 def test_op_to_device():
     @tvm.script.ir_module
     class CallToDevice:

Reply via email to