This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d705cb2de4 [Relax][IR] Skip in-place multiply when two operands are views of the same tensor (#19644)
d705cb2de4 is described below

commit d705cb2de4aef1ed8a5594eaf92f7c597e77ff8d
Author: ConvolutedDog <[email protected]>
AuthorDate: Sun May 31 23:50:18 2026 +0800

    [Relax][IR] Skip in-place multiply when two operands are views of the same tensor (#19644)
    
    This PR fixes https://github.com/apache/tvm/issues/19577.
    
    In this issue, the IRModule before applying any pass looks like:
    
    ```
      %x: Tensor[(4,), float32]  // function param
      with R.dataflow():
        %lv  = expand_dims(%x, axis=1)     // (4, 1)
        %lv1 = expand_dims(%x, axis=1)     // (4, 1) second call, new Var
        %lv2 = multiply(%lv, %lv1)         // (4, 1)
        %lv3 = concat(%lv2, %lv1, axis=1)  // (4, 2)
        ...
    ```
    
    When the user manually applies the `DataflowUseInplaceCalls` pass, it
    rewrites the statement `%lv2 = multiply(%lv, %lv1)` into
    `%lv = multiply(%lv, %lv1); %lv3 = concat(%lv, %lv1, axis=1)`, reusing
    the %lv buffer to avoid allocating new storage.
    
    However, this rewrite overwrites the contents of %lv, and in the
    LLVM-generated code %lv1 shares the same storage as %lv. So by the time
    `%lv3 = concat(%lv, %lv1, axis=1)` executes, the contents of %lv1 have
    also been overwritten with the result of `multiply(%lv, %lv1)`. The
    failure is therefore caused by two different views of the same tensor %x
    sharing storage.
    
    During execution, %lv1 holds `x^2` instead of `x` after `multiply`.
    `concat` reads %lv1 for the right column, so the result is
    [[1,1],[4,4],[9,9],[16,16]] instead of [[1,1],[4,2],[9,3],[16,4]] (the
    correct result has `x^2` in the left column and `x` in the right column).
    
    Change: view-like ops (expand_dims, squeeze, reshape, permute_dims,
    flatten, nn.batch_flatten, memory.view, ensure_zero_offset) now take the
    input's alias set in alias analysis instead of a fresh id, so %lv and
    %lv1 share an alias with %x. The pass then rejects in-place rewriting of
    `multiply(%lv, %lv1)`: %lv and %lv1 are different vars but their alias
    ids intersect, so neither operand may be reused in place.
---
 include/tvm/runtime/tensor.h                |  20 ++
 src/relax/transform/dataflow_inplace.cc     |  67 ++++-
 src/runtime/tensor.cc                       |  30 ++-
 tests/python/relax/test_dataflow_inplace.py | 390 ++++++++++++++++++++++++++++
 4 files changed, 505 insertions(+), 2 deletions(-)

diff --git a/include/tvm/runtime/tensor.h b/include/tvm/runtime/tensor.h
index 33a78a48d6..d3497c8ff7 100644
--- a/include/tvm/runtime/tensor.h
+++ b/include/tvm/runtime/tensor.h
@@ -183,6 +183,26 @@ class Tensor : public tvm::ffi::Tensor {
    */
   TVM_RUNTIME_DLL static void CopyFromBytes(const DLTensor* to, void* from, 
size_t nbytes,
                                             TVMStreamHandle stream = nullptr);
+
+  /*!
+   * \brief Check whether two tensors share (alias) the same underlying storage.
+   *
+   * The check only inspects the device, data pointer, byte offset and data size of the two
+   * tensors; it does not track allocation provenance. A true result therefore does not imply
+   * that either tensor was created by CreateView, and a false result should not be read as a
+   * proof that the tensors are independent.
+   *
+   * \param a The first tensor. Must not be null.
+   * \param b The second tensor. Must not be null.
+   * \return True if the tensors are detected to share storage.
+   */
+  TVM_RUNTIME_DLL static bool IsStorageShared(const DLTensor* a, const DLTensor* b);
+
+  /*!
+   * \brief Tensor overload of IsStorageShared.
+   *
+   * \note Like the DLTensor overload this is defined out of line in the runtime library, so it
+   *       carries TVM_RUNTIME_DLL to be exported (and thus callable) from other modules.
+   * \param a The first tensor. Must be defined.
+   * \param b The second tensor. Must be defined.
+   * \return True if the tensors are detected to share storage.
+   */
+  TVM_RUNTIME_DLL static bool IsStorageShared(const Tensor& a, const Tensor& b);
 };
 
 /*!
diff --git a/src/relax/transform/dataflow_inplace.cc 
b/src/relax/transform/dataflow_inplace.cc
index 8072ee5d14..c3ed7ef0b6 100644
--- a/src/relax/transform/dataflow_inplace.cc
+++ b/src/relax/transform/dataflow_inplace.cc
@@ -39,6 +39,67 @@
 namespace tvm {
 namespace relax {
 
+// Ops that may return a tensor sharing storage with their first argument.
+// These ops have been verified to share storage with the first argument in
+// tests/python/relax/test_dataflow_inplace.py.
+//
+// Alias analysis gives the result of these ops the alias set of their first argument
+// instead of a fresh id. Listing an op here can only make two vars look *more* aliased;
+// NOTE(review): that should only make in-place planning more conservative -- confirm
+// against InplaceConditionsMet when extending the list.
+//
+// File-local helper: `static` gives it internal linkage so the generic name cannot clash
+// with a same-named symbol in another translation unit.
+static bool IsViewMemoryOp(const OpNode* op_node) {
+  // TODO: Consider adding more ops that may return a tensor sharing storage with the
+  // first argument.
+  static const std::unordered_set<std::string> kViewOps = {
+      "relax.expand_dims", "relax.squeeze",
+      "relax.reshape",     "relax.permute_dims",
+      "relax.flatten",     "relax.nn.batch_flatten",
+      "relax.memory.view", "relax.memory.ensure_zero_offset",
+  };
+  return kViewOps.count(op_node->name) != 0;
+}
+
+// Look up the alias ids of a call argument. Only Var arguments are expected in dataflow
+// blocks; any other expression, or a Var that has no entry in `alias_sets`, yields {-1}.
+// `static`: file-local helper, internal linkage.
+static std::unordered_set<int> GetVarAliasSetFromExpr(
+    const Expr& arg, const std::unordered_map<Var, std::unordered_set<int>>& alias_sets) {
+  if (const auto* var_node = arg.as<VarNode>()) {
+    // Single hash lookup instead of count() followed by at().
+    auto it = alias_sets.find(ffi::GetRef<Var>(var_node));
+    if (it != alias_sets.end()) {
+      return it->second;
+    }
+  }
+  return {-1};
+}
+
+// In-place on arg `candidate` is invalid if another distinct operand may alias the same
+// storage (e.g. two expand_dims views of x bound to different vars), so reject on any
+// shared alias id.
+//
+// Returns true only if the candidate argument is a Var with known alias ids (no -1) that
+// are disjoint from the alias ids of every other, distinct operand. A -1 in another
+// operand's set (a non-Var operand or an untracked var) does not by itself cause a
+// rejection, but the remaining ids of that set are still checked. Passing the very same
+// expression twice (e.g. add(z, z)) is allowed.
+//
+// `static`: file-local helper, internal linkage.
+static bool InplaceArgDisjointFromOtherCallArgs(
+    const CallNode* call_node, int candidate,
+    const std::unordered_map<Var, std::unordered_set<int>>& alias_sets) {
+  const Expr& cand_arg = call_node->args[candidate];
+  if (!cand_arg.as<VarNode>()) {
+    return false;
+  }
+  const std::unordered_set<int> cand_set = GetVarAliasSetFromExpr(cand_arg, alias_sets);
+  // Unknown aliasing for the candidate itself: be conservative.
+  if (cand_set.count(-1)) {
+    return false;
+  }
+  for (size_t j = 0; j < call_node->args.size(); ++j) {
+    const Expr& other_arg = call_node->args[j];
+    // Skip the candidate slot itself and repeated uses of the same expression.
+    if (static_cast<int>(j) == candidate || other_arg.same_as(cand_arg)) {
+      continue;
+    }
+    for (int alias_idx : GetVarAliasSetFromExpr(other_arg, alias_sets)) {
+      if (cand_set.count(alias_idx)) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
 // Perform liveness analysis on a dataflow block, returning a map of vars to
 // pairs of indices (the liveness interval, from the starting index to the end 
index).
 // A starting index of -1 means the var is defined before the block starts and 
an end index
@@ -274,6 +335,9 @@ class AliasAnalyzer {
           } else {
             ret.insert(get_fresh_idx());
           }
+        } else if (IsViewMemoryOp(op_node) && !call_node->args.empty()) {
+          // View-like ops may share storage with their input (and with other 
views of it).
+          return GetAliasSet(call_node->args[0], bound_var);
         } else {
           // We are assuming most op calls return fresh values.
           // We may have to track more exceptions
@@ -654,7 +718,8 @@ FindInplaceOpportunities(const DataflowBlock& block, const 
ffi::Array<Var>& inpu
         std::unordered_set<int> remove_candidates;
         for (auto candidate : candidates) {
           if (!InplaceConditionsMet(live_ranges, alias_sets, tuple_map, 
currently_live,
-                                    call_node->args[candidate], i)) {
+                                    call_node->args[candidate], i) ||
+              !InplaceArgDisjointFromOtherCallArgs(call_node, candidate, 
alias_sets)) {
             remove_candidates.insert(candidate);
           }
         }
diff --git a/src/runtime/tensor.cc b/src/runtime/tensor.cc
index 2b694b1742..887d576537 100644
--- a/src/runtime/tensor.cc
+++ b/src/runtime/tensor.cc
@@ -29,6 +29,8 @@
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/tensor.h>
 
+#include <algorithm>
+
 #include "../support/base64.h"
 #include "../support/bytes_io.h"
 #include "tvm/runtime/data_type.h"
@@ -217,6 +219,30 @@ Tensor Tensor::CopyTo(const Device& dev, 
ffi::Optional<ffi::String> mem_scope) c
   return ret;
 }
 
+// The tensor's data pointer advanced by its byte_offset. This is a true address only on
+// devices whose `data` is an addressable pointer; NOTE(review): on devices where `data` is an
+// opaque handle the result is only meaningful for comparison, not for dereferencing.
+// `static`: file-local helper with a generic name, so give it internal linkage.
+static inline char* StorageBegin(const DLTensor* tensor) {
+  TVM_FFI_ICHECK(tensor != nullptr);
+  return static_cast<char*>(tensor->data) + tensor->byte_offset;
+}
+
+// One past the last byte of the tensor's data: StorageBegin(tensor) plus the data size
+// reported by ffi::GetDataSize. Subject to the same opaque-handle caveat as StorageBegin.
+// `static`: file-local helper with a generic name, so give it internal linkage.
+static inline char* StorageEnd(const DLTensor* tensor) {
+  TVM_FFI_ICHECK(tensor != nullptr);
+  return StorageBegin(tensor) + ffi::GetDataSize(*tensor);
+}
+
+// Report whether `a` and `b` alias the same bytes.
+//  - Different devices never alias.
+//  - Same data handle: the byte ranges [byte_offset, byte_offset + size) are compared
+//    relative to that handle and any overlap counts as sharing. This also catches views of
+//    a sub-range of an allocation, and needs no pointer arithmetic on the handle, so it is
+//    valid even where `data` is an opaque device handle.
+//  - Different data handles: only an identical byte range is reported as shared.
+// Every pair reported as shared by the previous exact-range check is still reported as shared.
+bool Tensor::IsStorageShared(const DLTensor* a, const DLTensor* b) {
+  TVM_FFI_ICHECK(a != nullptr && b != nullptr);
+  if (a->device.device_type != b->device.device_type ||
+      a->device.device_id != b->device.device_id) {
+    return false;
+  }
+  if (a->data == b->data) {
+    const uint64_t a_lo = a->byte_offset;
+    const uint64_t a_hi = a_lo + ffi::GetDataSize(*a);
+    const uint64_t b_lo = b->byte_offset;
+    const uint64_t b_hi = b_lo + ffi::GetDataSize(*b);
+    // Identical ranges (including two identical empty views) always count as shared;
+    // otherwise any overlap of the half-open ranges means the tensors alias.
+    return (a_lo == b_lo && a_hi == b_hi) || std::max(a_lo, b_lo) < std::min(a_hi, b_hi);
+  }
+  return StorageBegin(a) == StorageBegin(b) && StorageEnd(a) == StorageEnd(b);
+}
+
+// Tensor-handle convenience overload: both handles must be defined; the actual check is
+// delegated to the DLTensor overload.
+bool Tensor::IsStorageShared(const Tensor& a, const Tensor& b) {
+  TVM_FFI_ICHECK(a.defined() && b.defined());
+  const DLTensor* lhs = a.operator->();
+  const DLTensor* rhs = b.operator->();
+  return Tensor::IsStorageShared(lhs, rhs);
+}
+
 void Tensor::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle 
stream) {
   size_t from_size = ffi::GetDataSize(*from);
   size_t to_size = ffi::GetDataSize(*to);
@@ -270,5 +296,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       .def("runtime.TVMTensorCopyToBytes",
            [](DLTensor* arr, void* data, size_t nbytes) { 
Tensor::CopyToBytes(arr, data, nbytes); })
       .def("runtime.TVMTensorCopyFromTo",
-           [](DLTensor* from, DLTensor* to) { Tensor::CopyFromTo(from, to); });
+           [](DLTensor* from, DLTensor* to) { Tensor::CopyFromTo(from, to); })
+      .def("runtime.TVMTensorIsStorageShared",
+           [](Tensor a, Tensor b) { return Tensor::IsStorageShared(a, b); });
 }
diff --git a/tests/python/relax/test_dataflow_inplace.py 
b/tests/python/relax/test_dataflow_inplace.py
index 61791b2b32..1b23e14482 100644
--- a/tests/python/relax/test_dataflow_inplace.py
+++ b/tests/python/relax/test_dataflow_inplace.py
@@ -18,9 +18,12 @@
 
 
 import numpy as np
+import pytest
+import torch
 
 import tvm
 from tvm import relax, testing
+from tvm.relax import VMInstrumentReturnKind
 from tvm.relax.testing.transform import (
     dataflow_alias_analysis,
     dataflow_inplace_analysis,
@@ -643,5 +646,392 @@ def test_dynamic_mismatch():
     tvm.ir.assert_structural_equal(new_mod, DynamicMistmatchTestCase)
 
 
+class TestViewOpSharedStorageAndNoInplace:
+    storage_ptr_x_1d = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
+    storage_ptr_x_2d = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
+    storage_ptr_x_squeeze = np.array([[[1.0], [2.0], [3.0], [4.0]]], 
dtype=np.float32)
+    storage_ptr_x_ensure_zero_offset = np.array([[1.0], [2.0], [3.0], [4.0]], 
dtype=np.float32)
+
+    @I.ir_module
+    class _SharedStorageExpandDimsModule:
+        @R.function
+        def main(x: R.Tensor((4,), dtype="float32")) -> R.Tensor((4, 1), 
dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 1), dtype="float32") = R.expand_dims(x, 
axis=[1])
+                lv1: R.Tensor((4, 1), dtype="float32") = R.expand_dims(x, 
axis=[1])
+                gv: R.Tensor((4, 1), dtype="float32") = R.add(lv, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class _SharedStorageSqueezeModule:
+        @R.function
+        def main(x: R.Tensor((1, 4, 1), dtype="float32")) -> R.Tensor((4, 1), 
dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 1), dtype="float32") = R.squeeze(x, axis=[0])
+                lv1: R.Tensor((4, 1), dtype="float32") = R.squeeze(x, axis=[0])
+                gv: R.Tensor((4, 1), dtype="float32") = R.add(lv, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class _SharedStorageReshapeModule:
+        @R.function
+        def main(x: R.Tensor((4,), dtype="float32")) -> R.Tensor((4, 1), 
dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 1), dtype="float32") = R.reshape(x, (4, 1))
+                lv1: R.Tensor((4, 1), dtype="float32") = R.reshape(x, (4, 1))
+                gv: R.Tensor((4, 1), dtype="float32") = R.add(lv, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class _SharedStoragePermuteDimsModule:
+        @R.function
+        def main(x: R.Tensor((1, 4), dtype="float32")) -> R.Tensor((4, 1), 
dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 1), dtype="float32") = R.permute_dims(x, 
axes=[1, 0])
+                lv1: R.Tensor((4, 1), dtype="float32") = R.permute_dims(x, 
axes=[1, 0])
+                gv: R.Tensor((4, 1), dtype="float32") = R.add(lv, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class _SharedStorageViewModule:
+        @R.function
+        def main(x: R.Tensor((4,), dtype="float32")) -> R.Tensor((1, 4), 
dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 4), dtype="float32") = R.memory.view(
+                    x, R.shape([1, 4]), R.tuple(), R.tuple()
+                )
+                lv1: R.Tensor((1, 4), dtype="float32") = R.memory.view(
+                    x, R.shape([1, 4]), R.tuple(), R.tuple()
+                )
+                gv: R.Tensor((1, 4), dtype="float32") = R.add(lv, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class _SharedStorageBatchFlattenModule:
+        @R.function
+        def main(x: R.Tensor((1, 4), dtype="float32")) -> R.Tensor((1, 4), 
dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 4), dtype="float32") = R.nn.batch_flatten(x)
+                lv1: R.Tensor((1, 4), dtype="float32") = R.nn.batch_flatten(x)
+                gv: R.Tensor((1, 4), dtype="float32") = R.add(lv, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class _SharedStorageFlattenModule:
+        @R.function
+        def main(x: R.Tensor((1, 4), dtype="float32")) -> R.Tensor((4,), 
dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4,), dtype="float32") = R.flatten(x)
+                lv1: R.Tensor((4,), dtype="float32") = R.flatten(x)
+                gv: R.Tensor((4,), dtype="float32") = R.add(lv, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class _SharedStorageEnsureZeroOffsetModule:
+        @R.function
+        def main(x: R.Tensor((4, 1), dtype="float32")) -> R.Tensor((4, 1), 
dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 1), dtype="float32") = 
R.memory.ensure_zero_offset(x)
+                lv1: R.Tensor((4, 1), dtype="float32") = 
R.memory.ensure_zero_offset(x)
+                gv: R.Tensor((4, 1), dtype="float32") = R.add(lv, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class _IndependentReluModule:
+        """Just a testcase to verify that non-view ops do not share storage."""
+
+        @R.function
+        def main(x: R.Tensor((4,), dtype="float32")) -> R.Tensor((4,), 
dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4,), dtype="float32") = R.nn.relu(x)
+                lv1: R.Tensor((4,), dtype="float32") = R.nn.relu(x)
+                gv: R.Tensor((4,), dtype="float32") = R.add(lv, lv1)
+                R.output(gv)
+            return gv
+
+    @classmethod
+    def _capture_op_tensors(cls, mod, input_nps, op_substr):
+        """Capture TVM tensors passed to VM calls whose name contains 
op_substr."""
+        captures = []
+
+        def instrument(func, name, before_run, ret_value, *args):
+            del func, ret_value
+            if not before_run:
+                return VMInstrumentReturnKind.NO_OP
+            if op_substr not in name.lower():
+                return VMInstrumentReturnKind.NO_OP
+            tensor_args = [arg for arg in args if isinstance(arg, 
tvm.runtime.Tensor)]
+            if not tensor_args:
+                return VMInstrumentReturnKind.NO_OP
+            captures.append({"call_name": name, "tensors": tensor_args})
+            return VMInstrumentReturnKind.NO_OP
+
+        if isinstance(input_nps, np.ndarray):
+            input_nps = [input_nps]
+
+        ex = relax.build(mod, tvm.target.Target("llvm"))
+        vm = relax.VirtualMachine(ex, tvm.cpu())
+        vm.set_instrument(instrument)
+        vm["main"](*(tvm.runtime.tensor(arr, tvm.cpu()) for arr in input_nps))
+        return captures
+
+    @pytest.mark.parametrize(
+        "mod,input_nps,op_substr,expect_same_storage",
+        [
+            pytest.param(
+                _SharedStorageExpandDimsModule,
+                [storage_ptr_x_1d],
+                "add",
+                True,
+                id="shared_storage_expand_dims",
+            ),
+            pytest.param(
+                _SharedStorageSqueezeModule,
+                [storage_ptr_x_squeeze],
+                "add",
+                True,
+                id="shared_storage_squeeze",
+            ),
+            pytest.param(
+                _SharedStorageReshapeModule,
+                [storage_ptr_x_1d],
+                "add",
+                True,
+                id="shared_storage_reshape",
+            ),
+            pytest.param(
+                _SharedStoragePermuteDimsModule,
+                [storage_ptr_x_2d],
+                "add",
+                True,
+                id="shared_storage_permute_dims",
+            ),
+            pytest.param(
+                _SharedStorageFlattenModule,
+                [storage_ptr_x_2d],
+                "add",
+                True,
+                id="shared_storage_flatten",
+            ),
+            pytest.param(
+                _SharedStorageBatchFlattenModule,
+                [storage_ptr_x_2d],
+                "add",
+                True,
+                id="shared_storage_batch_flatten",
+            ),
+            pytest.param(
+                _SharedStorageViewModule,
+                [storage_ptr_x_1d],
+                "add",
+                True,
+                id="shared_storage_memory_view",
+            ),
+            pytest.param(
+                _SharedStorageEnsureZeroOffsetModule,
+                [storage_ptr_x_ensure_zero_offset],
+                "add",
+                True,
+                id="shared_storage_ensure_zero_offset",
+            ),
+            pytest.param(
+                _IndependentReluModule,
+                [storage_ptr_x_1d],
+                "add",
+                False,
+                id="independent_storage_relu",
+            ),
+        ],
+    )
+    def test_tensor_storage_ptr_extraction(self, mod, input_nps, op_substr, 
expect_same_storage):
+        """Validate runtime storage overlap/sharing via VM instrumentation."""
+        storage_shared = 
tvm.get_global_func("runtime.TVMTensorIsStorageShared")
+        captures = self._capture_op_tensors(mod, input_nps, op_substr)
+        assert len(captures), f"VM instrumentation did not see a {op_substr} 
call."
+        assert len(captures) == 1, f"VM instrumentation should see exactly one 
{op_substr} call."
+        cap = captures[0]
+        assert len(cap["tensors"]) == 3, (
+            f"VM instrumentation should see three {op_substr} tensor operands."
+        )
+        tensor_a, tensor_b = cap["tensors"][0], cap["tensors"][1]
+        call_name = cap["call_name"]
+        if expect_same_storage:
+            assert storage_shared(tensor_a, tensor_b), (
+                f"{mod.__name__}: operands should share the same storage (call 
{call_name!r})"
+            )
+        else:
+            assert not storage_shared(tensor_a, tensor_b), (
+                f"{mod.__name__}: operands must not share storage (call 
{call_name!r})"
+            )
+
+    @staticmethod
+    def _emit_duplicate_view(op, x):
+        if op == "relax.expand_dims":
+            a = relax.op.expand_dims(x, axis=1)
+            b = relax.op.expand_dims(x, axis=1)
+        elif op == "relax.squeeze":
+            a = relax.op.squeeze(x, axis=[0])
+            b = relax.op.squeeze(x, axis=[0])
+        elif op == "relax.reshape":
+            a = relax.op.reshape(x, (4, 1))
+            b = relax.op.reshape(x, (4, 1))
+        elif op == "relax.permute_dims":
+            a = relax.op.permute_dims(x, axes=[1, 0])
+            b = relax.op.permute_dims(x, axes=[1, 0])
+        elif op == "relax.memory.view":
+            a = relax.op.memory.view(x, (4, 1))
+            b = relax.op.memory.view(x, (4, 1))
+        elif op == "relax.memory.ensure_zero_offset":
+            a = relax.op.memory.ensure_zero_offset(x)
+            b = relax.op.memory.ensure_zero_offset(x)
+        elif op == "relax.flatten":
+            a = relax.op.flatten(x)
+            b = relax.op.flatten(x)
+        elif op == "relax.nn.batch_flatten":
+            a = relax.op.nn.batch_flatten(x)
+            b = relax.op.nn.batch_flatten(x)
+        else:
+            raise ValueError(op)
+        return a, b
+
+    @staticmethod
+    def _concat_axis_for_view_op(op):
+        if op == "relax.flatten":
+            return 0
+        return 1
+
+    @classmethod
+    def _build_module(cls, op):
+        if op == "relax.expand_dims":
+            x_sinfo = relax.TensorStructInfo((4,), "float32")
+        elif op == "relax.squeeze":
+            x_sinfo = relax.TensorStructInfo((1, 4, 1), "float32")
+        elif op == "relax.reshape":
+            x_sinfo = relax.TensorStructInfo((4,), "float32")
+        elif op == "relax.permute_dims":
+            x_sinfo = relax.TensorStructInfo((1, 4), "float32")
+        elif op == "relax.memory.view":
+            x_sinfo = relax.TensorStructInfo((4,), "float32")
+        elif op == "relax.memory.ensure_zero_offset":
+            x_sinfo = relax.TensorStructInfo((4, 1), "float32")
+        elif op in ("relax.flatten", "relax.nn.batch_flatten"):
+            x_sinfo = relax.TensorStructInfo((1, 4), "float32")
+        else:
+            raise ValueError(op)
+
+        bb = relax.BlockBuilder()
+        x = relax.Var("x", x_sinfo)
+        concat_axis = cls._concat_axis_for_view_op(op)
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                a_expr, b_expr = cls._emit_duplicate_view(op, x)
+                a = bb.emit(a_expr)
+                b = bb.emit(b_expr)
+                prod = bb.emit(relax.op.multiply(a, b))
+                out = bb.emit(relax.op.concat([prod, b], axis=concat_axis))
+                gv = bb.emit_output(out)
+            bb.emit_func_output(gv)
+        return bb.finalize()
+
+    @classmethod
+    def _input_for_view_op(cls, op):
+        if op == "relax.squeeze":
+            return cls.storage_ptr_x_squeeze
+        if op == "relax.memory.ensure_zero_offset":
+            return cls.storage_ptr_x_ensure_zero_offset
+        if op in ("relax.permute_dims", "relax.flatten", 
"relax.nn.batch_flatten"):
+            return cls.storage_ptr_x_2d
+        return cls.storage_ptr_x_1d
+
+    @staticmethod
+    def _torch_duplicate_view(x, op):
+        if op == "relax.expand_dims":
+            return x.unsqueeze(1)
+        if op == "relax.squeeze":
+            return x.squeeze(0)
+        if op == "relax.reshape":
+            return x.reshape(4, 1)
+        if op == "relax.permute_dims":
+            return x.permute(1, 0)
+        if op == "relax.memory.view":
+            return x.reshape(4, 1)
+        if op == "relax.memory.ensure_zero_offset":
+            return x
+        if op == "relax.flatten":
+            return x.flatten()
+        if op == "relax.nn.batch_flatten":
+            # TVM: ndim==2 input keeps shape (1, 4).
+            return x
+        raise ValueError(op)
+
+    @classmethod
+    def _expected_for_view_op(cls, op):
+        x = torch.from_numpy(np.asarray(cls._input_for_view_op(op), 
dtype=np.float32))
+        a = cls._torch_duplicate_view(x, op)
+        b = cls._torch_duplicate_view(x, op)
+        prod = a * b
+        concat_axis = cls._concat_axis_for_view_op(op)
+        return torch.cat([prod, b], dim=concat_axis).numpy()
+
+    @pytest.mark.parametrize(
+        "view_op",
+        (
+            # Keep this list in sync with IsViewMemoryOp() in
+            # src/relax/transform/dataflow_inplace.cc
+            "relax.expand_dims",
+            "relax.squeeze",
+            "relax.reshape",
+            "relax.permute_dims",
+            "relax.flatten",
+            "relax.nn.batch_flatten",
+            "relax.memory.view",
+            "relax.memory.ensure_zero_offset",
+        ),
+    )
+    def test_no_inplace_when_view_ops_share_input(self, view_op):
+        mod = self._build_module(view_op)
+        func = mod["main"]
+        block = func.body.blocks[0]
+        params = list(func.params)
+
+        alias_sets, _ = dataflow_alias_analysis(block, params)
+        a_var = block.bindings[0].var
+        b_var = block.bindings[1].var
+        assert alias_sets[a_var] & alias_sets[b_var], (
+            f"{view_op}: duplicate views should share alias sets, but got "
+            f"{alias_sets[a_var]} and {alias_sets[b_var]}"
+        )
+
+        _, exact_match = dataflow_inplace_analysis(block, params, mod)
+        assert exact_match == [], f"{view_op}: expected no in-place 
opportunities"
+
+        x_np = self._input_for_view_op(view_op).copy()
+        mod_inplace = DataflowUseInplaceCalls()(mod)
+        tvm.ir.assert_structural_equal(mod_inplace, mod)
+
+        storage_shared = 
tvm.get_global_func("runtime.TVMTensorIsStorageShared")
+        captures = self._capture_op_tensors(mod_inplace, x_np, "multiply")
+        assert captures, f"{view_op}: VM instrumentation did not see a 
multiply call."
+        cap = next(c for c in captures if len(c["tensors"]) >= 2)
+        tensor_a, tensor_b = cap["tensors"][0], cap["tensors"][1]
+        assert storage_shared(tensor_a, tensor_b), (
+            f"{view_op}: multiply operands should share the same storage at 
runtime "
+            f"(call {cap['call_name']!r})"
+        )
+
+        ex = relax.build(mod_inplace, tvm.target.Target("llvm"))
+        vm = relax.VirtualMachine(ex, tvm.cpu())
+        out = vm["main"](tvm.runtime.tensor(x_np, tvm.cpu()))
+        np.testing.assert_allclose(out.numpy(), 
self._expected_for_view_op(view_op))
+
+
 if __name__ == "__main__":
     testing.main()

Reply via email to