This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 4a8a7b9c63 [Unity] Implement LowerAllocTensor to remove R.builtin.alloc_tensor (#15809)
4a8a7b9c63 is described below

commit 4a8a7b9c63e3dd09eb88929f425f52d074a3b05b
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Sep 29 14:23:02 2023 -0500

    [Unity] Implement LowerAllocTensor to remove R.builtin.alloc_tensor (#15809)
    
    * [Unity] Implement LowerAllocTensor to remove R.builtin.alloc_tensor
    
    The `StaticPlanBlockMemory` transform is provided a module that
    expresses all allocations with `R.builtin.alloc_tensor`, and produces
    a module that uses `R.memory.alloc_storage` and
    `R.memory.alloc_tensor` to express static allocations, while dynamic
    allocations continue to use `R.builtin.alloc_tensor`.
    
    Prior to this commit, this mixed output was handled as part of
    `VMBuiltinLower`. This commit extracts the lowering of
    `R.builtin.alloc_tensor` to a new pass, `LowerAllocTensor`.  This pass
    runs after `StaticPlanBlockMemory`, and replaces any remaining
    `R.builtin.alloc_tensor` with calls to `R.memory.alloc_storage`
    and `R.memory.alloc_tensor`.
    
    * Updated unit tests
    
    * Correct order of LowerAllocTensor and KillAfterLastUse
    
    The `R.memory.alloc_storage` produced by `LowerAllocTensor` must be
    present in order to be appropriately deleted by `KillAfterLastUse`.
---
 python/tvm/relax/transform/transform.py            |  20 ++++
 python/tvm/relax/vm_build.py                       |   4 +-
 src/relax/backend/vm/vm_builtin_lower.cc           |  37 +------
 src/relax/transform/lower_alloc_tensor.cc          | 106 +++++++++++++++++++++
 tests/python/relax/test_lower_alloc_tensor.py      |  47 +++++++++
 tests/python/relax/test_transform.py               |  39 +-------
 .../test_transform_static_plan_block_memory.py     |   1 +
 tests/python/relax/test_vm_builtin_lower.py        |  86 +++++++++++++++++
 8 files changed, 268 insertions(+), 72 deletions(-)

diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 72a9966a4b..2b7a788e32 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -368,6 +368,26 @@ def StaticPlanBlockMemory() -> tvm.ir.transform.Pass:
     return _ffi_api.StaticPlanBlockMemory()  # type: ignore
 
 
+def LowerAllocTensor() -> tvm.ir.transform.Pass:
+    """Lower remaining instances of R.builtin.alloc_tensor
+
+    The static memory planner removes static instances of
+    `R.builtin.alloc_tensor`, replacing with `R.memory.alloc_storage`
+    and `R.memory.alloc_tensor`.  However, `R.builtin.alloc_tensor`
+    still remains for any dynamic allocations.
+
+    This transform replaces any remaining `R.builtin.alloc_tensor`
+    instances with `R.memory.alloc_storage` and
+    `R.memory.alloc_tensor`.  If no `R.builtin.alloc_tensor` are
+    present, this pass has no effect.
+
+    Returns
+    -------
+    ret : tvm.ir.transform.Pass
+    """
+    return _ffi_api.LowerAllocTensor()  # type: ignore
+
+
 def KillAfterLastUse() -> tvm.ir.transform.Pass:
     """Drop all tensor/storage objects after last use
 
diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index 142da5c451..d5edeeec69 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -310,11 +310,13 @@ def build(
     passes.append(relax.transform.RemovePurityChecking())
     passes.append(relax.transform.CallTIRRewrite())
     passes.append(relax.transform.StaticPlanBlockMemory())
-    passes.append(relax.transform.KillAfterLastUse())
 
     if tvm.transform.PassContext.current().config.get("relax.backend.use_cuda_graph", False):
         passes.append(relax.transform.RewriteCUDAGraph())
 
+    passes.append(relax.transform.LowerAllocTensor())
+    passes.append(relax.transform.KillAfterLastUse())
+
     passes.append(relax.transform.VMBuiltinLower())
     passes.append(relax.transform.VMShapeLower())
     passes.append(relax.transform.AttachGlobalSymbol())
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/vm_builtin_lower.cc
index 784b3c9fd5..887998d004 100644
--- a/src/relax/backend/vm/vm_builtin_lower.cc
+++ b/src/relax/backend/vm/vm_builtin_lower.cc
@@ -54,7 +54,10 @@ class VMBuiltinLowerMutator : public ExprMutator {
     } else if (call->op == invoke_closure_op_) {
       return InvokeClosure(call);
     } else if (call->op == alloc_tensor_op_) {
-      return MakeAllocTensor(call);
+      LOG(FATAL) << "VMBuiltinLower encountered " << call->op << " in expression "
+                 << GetRef<Call>(call_node) << ".  "
+                 << "This operation should have been lowered earlier "
+                 << "using the 'relax.transform.LowerAllocTensor' pass.";
     } else if (call->op == mem_alloc_storage_op_) {
       return MakeMemAllocStorage(call);
     } else if (call->op == mem_alloc_tensor_op_) {
@@ -66,38 +69,6 @@ class VMBuiltinLowerMutator : public ExprMutator {
     }
   }
 
-  Expr ComputeStorageSize(const Expr& shape, const DataType& dtype) const {
-    // Question: what if the dtype of tensor_type is unknown?
-    // Symbolic/static shape case
-    if (auto* shape_expr = shape.as<ShapeExprNode>()) {
-      int64_t elem_bytes = runtime::GetVectorBytes(dtype);
-      PrimExpr ret = IntImm(DataType::Int(64), elem_bytes);
-      for (PrimExpr dim : shape_expr->values) {
-        ret = ret * dim;
-      }
-      return ShapeExpr({ret});
-    } else {
-      return Call(builtin_compute_alloc_shape_, {shape, DataTypeImm(dtype)}, Attrs(),
-                  {GetStructInfo(shape)});
-    }
-  }
-
-  Expr MakeAllocTensor(const Call& call) {
-    ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]);
-    DataTypeImm output_dtype = Downcast<DataTypeImm>(call->args[1]);
-    DataType dtype = output_dtype->value;
-    Expr storage_size = ComputeStorageSize(output_shape, dtype);
-    PrimValue runtime_device_index = Downcast<PrimValue>(call->args[2]);
-    Var storage = builder_->Emit(Call(vm_alloc_storage_op_,
-                                      {storage_size, runtime_device_index,
-                                       DataTypeImm(DataType::UInt(8)), StringImm("global")},
-                                      Attrs()),
-                                 "storage");
-    Expr shape = call->args[0];
-    PrimValue offset = PrimValue::Int64(0);
-    return Call(vm_alloc_tensor_op_, {storage, offset, shape, DataTypeImm(dtype)}, Attrs());
-  }
-
   Expr MakeMemAllocStorage(const Call& call) {
     PrimValue runtime_device_index = Downcast<PrimValue>(call->args[1]);
     StringImm storage_scope = Downcast<StringImm>(call->args[2]);
diff --git a/src/relax/transform/lower_alloc_tensor.cc b/src/relax/transform/lower_alloc_tensor.cc
new file mode 100644
index 0000000000..f0db2447d9
--- /dev/null
+++ b/src/relax/transform/lower_alloc_tensor.cc
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/lower_alloc_tensor.cc
+ * \brief Lower any relax.builtin.alloc_tensor remaining after static planning
+ */
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+
+namespace {
+class Mutator : public ExprMutator {
+  using ExprMutator::VisitExpr_;
+  Expr VisitExpr_(const CallNode* op) override {
+    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+    static const Op& mem_alloc_storage_op = Op::Get("relax.memory.alloc_storage");
+    static const Op& mem_alloc_tensor_op = Op::Get("relax.memory.alloc_tensor");
+
+    if (op->op.same_as(alloc_tensor_op)) {
+      CHECK_EQ(op->args.size(), 3) << "Op " << op->op << " should have three arguments, "
+                                   << "[shape, dtype, runtime_device_index].  "
+                                   << "However, received " << GetRef<Call>(op);
+
+      auto shape_arg = op->args[0];
+      auto dtype = Downcast<DataTypeImm>(op->args[1]);
+      PrimValue runtime_device_index = Downcast<PrimValue>(op->args[2]);
+      std::string storage_scope = "global";
+
+      auto shape = [&]() -> Array<PrimExpr> {
+        if (auto ptr = shape_arg.as<ShapeExprNode>()) {
+          return ptr->values;
+        }
+
+        auto sinfo = GetStructInfo(shape_arg);
+        if (auto ptr = sinfo.as<ShapeStructInfoNode>()) {
+          if (ptr->values) {
+            return ptr->values.value();
+          }
+        }
+
+        LOG(FATAL) << "Shape argument for " << alloc_tensor_op << " should be a ShapeExpr, "
+                   << "or a variable that holds a ShapeExpr.  "
+                   << "However, received argument " << shape_arg << " with struct info " << sinfo;
+      }();
+
+      PrimExpr nbytes = [&]() -> PrimExpr {
+        PrimExpr nbytes = tir::make_const(DataType::Int(64), dtype->value.bytes());
+        for (const auto& dim : shape) {
+          nbytes *= dim;
+        }
+        return nbytes;
+      }();
+
+      auto offset = PrimValue::Int64(0);
+
+      Expr storage = relax::Call(mem_alloc_storage_op,
+                                 {ShapeExpr({nbytes}), runtime_device_index,
+                                  StringImm(storage_scope), DataTypeImm(DataType::UInt(8))});
+      storage = builder_->Emit(storage, "storage");
+      Expr tensor = relax::Call(mem_alloc_tensor_op, {storage, offset, shape_arg, dtype});
+      return tensor;
+    } else {
+      return ExprMutator::VisitExpr_(op);
+    }
+  }
+};
+}  // namespace
+
+Expr LowerAllocTensor(Expr expr) {
+  Mutator mutator;
+  return mutator(expr);
+}
+
+namespace transform {
+
+Pass LowerAllocTensor() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function func, IRModule m, PassContext pc) {
+        return Downcast<Function>(relax::LowerAllocTensor(std::move(func)));
+      };
+  return CreateFunctionPass(pass_func, /*opt_level=*/0, "LowerAllocTensor", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.LowerAllocTensor").set_body_typed(LowerAllocTensor);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_lower_alloc_tensor.py b/tests/python/relax/test_lower_alloc_tensor.py
new file mode 100644
index 0000000000..3d1415d245
--- /dev/null
+++ b/tests/python/relax/test_lower_alloc_tensor.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+
+from tvm.script import ir as I, relax as R
+
+from tvm.relax.transform import LowerAllocTensor
+
+
+def test_basic():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main():
+            x = R.builtin.alloc_tensor(R.shape([16, 32]), "float32", 0)
+            return x
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main():
+            storage = R.memory.alloc_storage(R.shape([2048]), 0, "global", "uint8")
+            x = R.memory.alloc_tensor(storage, 0, R.shape([16, 32]), "float32")
+            return x
+
+    After = LowerAllocTensor()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
index d690cc55e3..9ab2ffc605 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -19,7 +19,6 @@ import pytest
 import tvm
 from tvm import relax
 from tvm.ir import structural_equal
-from tvm.ir.base import assert_structural_equal
 
 import tvm.script
 from tvm.script import tir as T, relax as R
@@ -485,41 +484,5 @@ def test_call_tir_inplace_all_new():
             return gv0
 
 
-def test_vm_builtin_lower():
-    @tvm.script.ir_module
-    class TestVMBuiltinLower:
-        @R.function
-        def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
-            # we expected RemovePurityChecking to have been called first
-            R.func_attr({"relax.force_pure": True})
-            m, n = T.int64(), T.int64()
-            alloc = R.builtin.alloc_tensor(R.shape([m, n]), runtime_device_index=0, dtype="float32")
-            _ = R.call_packed(
-                "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))
-            )
-            gv0 = alloc
-            return gv0
-
-    mod = TestVMBuiltinLower
-
-    # after vm builtin lowering
-    new_mod = relax.transform.VMBuiltinLower()(mod)
-    func = new_mod["foo"]
-
-    assert isinstance(new_mod, tvm.IRModule)
-    assert isinstance(func, tvm.relax.expr.Function)
-
-    block = func.body.blocks[0]
-    s1 = block.bindings[0].value
-    assert isinstance(s1, relax.Call)
-    assert s1.op.name == "relax.vm.alloc_storage"
-    s2 = block.bindings[1].value
-    assert isinstance(s2, relax.Call)
-    s3 = block.bindings[2].value
-    assert isinstance(s3, relax.Call)
-    assert isinstance(s3.op, relax.ExternFunc)
-    assert s3.op.global_symbol == "test.op.identity"
-
 
 if __name__ == "__main__":
-    pytest.main([__file__])
+    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 147fc049c9..451fe4fbe2 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -190,6 +190,7 @@ def test_basic():
 
     mod = relax.transform.StaticPlanBlockMemory()(Module)
     tvm.ir.assert_structural_equal(mod, Expected)
+    mod = relax.transform.LowerAllocTensor()(mod)
     mod = relax.transform.VMBuiltinLower()(mod)
     tvm.ir.assert_structural_equal(mod, ExpectedLowered)
 
diff --git a/tests/python/relax/test_vm_builtin_lower.py b/tests/python/relax/test_vm_builtin_lower.py
new file mode 100644
index 0000000000..df28db4d46
--- /dev/null
+++ b/tests/python/relax/test_vm_builtin_lower.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm
+from tvm import relax
+
+import tvm.script
+from tvm.script import ir as I, relax as R, tir as T
+
+
+def test_vm_builtin_lower_mem_alloc_storage():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
+            R.func_attr({"relax.force_pure": True})
+            m, n = T.int64(), T.int64()
+
+            storage = R.memory.alloc_storage(R.shape([m * n * 4]), 0, "global", "uint8")
+            alloc = R.memory.alloc_tensor(storage, 0, R.shape([m, n]), "float32")
+            _ = R.call_packed(
+                "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))
+            )
+            gv0 = alloc
+            return gv0
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
+            # we expected RemovePurityChecking to have been called first
+            R.func_attr({"relax.force_pure": True})
+            m, n = T.int64(), T.int64()
+
+            storage = R.vm.alloc_storage(R.shape([m * n * 4]), R.prim_value(0), "uint8", "global")
+            alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape([m, n]), "float32")
+
+            _ = R.call_packed(
+                "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))
+            )
+            gv0 = alloc
+            return gv0
+
+    After = relax.transform.VMBuiltinLower()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
+def test_vm_builtin_alloc_tensor_raises_error():
+    """R.builtin.alloc_tensor should be handled earlier"""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
+            R.func_attr({"relax.force_pure": True})
+            m, n = T.int64(), T.int64()
+
+            alloc = R.builtin.alloc_tensor(R.shape([m, n]), runtime_device_index=0, dtype="float32")
+            _ = R.call_packed(
+                "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))
+            )
+            gv0 = alloc
+            return gv0
+
+    with pytest.raises(tvm.TVMError):
+        relax.transform.VMBuiltinLower()(Before)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to