This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 4a8a7b9c63 [Unity] Implement LowerAllocTensor to remove
R.builtin.alloc_tensor (#15809)
4a8a7b9c63 is described below
commit 4a8a7b9c63e3dd09eb88929f425f52d074a3b05b
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Sep 29 14:23:02 2023 -0500
[Unity] Implement LowerAllocTensor to remove R.builtin.alloc_tensor (#15809)
* [Unity] Implement LowerAllocTensor to remove R.builtin.alloc_tensor
The `StaticPlanBlockMemory` transform is provided a module that
expresses all allocations with `R.builtin.alloc_tensor`, and produces
a module that uses `R.memory.alloc_storage` and
`R.memory.alloc_tensor` to express static allocations, while dynamic
allocations continue to use `R.builtin.alloc_tensor`.
Prior to this commit, this mixed output was handled as part of
`VMBuiltinLower`. This commit extracts the lowering of
`R.builtin.alloc_tensor` to a new pass, `LowerAllocTensor`. This pass
runs after `StaticPlanBlockMemory`, and replaces any remaining
`R.builtin.alloc_tensor` with calls to `R.memory.alloc_storage`
and `R.memory.alloc_tensor`.
* Updated unit tests
* Correct order of LowerAllocTensor and KillAfterLastUse
The `R.memory.alloc_storage` produced by `LowerAllocTensor` must be
present in order to be appropriately deleted by `KillAfterLastUse`.
---
python/tvm/relax/transform/transform.py | 20 ++++
python/tvm/relax/vm_build.py | 4 +-
src/relax/backend/vm/vm_builtin_lower.cc | 37 +------
src/relax/transform/lower_alloc_tensor.cc | 106 +++++++++++++++++++++
tests/python/relax/test_lower_alloc_tensor.py | 47 +++++++++
tests/python/relax/test_transform.py | 39 +-------
.../test_transform_static_plan_block_memory.py | 1 +
tests/python/relax/test_vm_builtin_lower.py | 86 +++++++++++++++++
8 files changed, 268 insertions(+), 72 deletions(-)
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 72a9966a4b..2b7a788e32 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -368,6 +368,26 @@ def StaticPlanBlockMemory() -> tvm.ir.transform.Pass:
return _ffi_api.StaticPlanBlockMemory() # type: ignore
+def LowerAllocTensor() -> tvm.ir.transform.Pass:
+ """Lower remaining instances of R.builtin.alloc_tensor
+
+ The static memory planner removes static instances of
+ `R.builtin.alloc_tensor`, replacing them with `R.memory.alloc_storage`
+ and `R.memory.alloc_tensor`. However, `R.builtin.alloc_tensor`
+ still remains for any dynamic allocations.
+
+ This transform replaces any remaining `R.builtin.alloc_tensor`
+ instances with `R.memory.alloc_storage` and
+ `R.memory.alloc_tensor`. If no `R.builtin.alloc_tensor` instances are
+ present, this pass has no effect.
+
+ Returns
+ -------
+ ret : tvm.ir.transform.Pass
+ """
+ return _ffi_api.LowerAllocTensor() # type: ignore
+
+
def KillAfterLastUse() -> tvm.ir.transform.Pass:
"""Drop all tensor/storage objects after last use
diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index 142da5c451..d5edeeec69 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -310,11 +310,13 @@ def build(
passes.append(relax.transform.RemovePurityChecking())
passes.append(relax.transform.CallTIRRewrite())
passes.append(relax.transform.StaticPlanBlockMemory())
- passes.append(relax.transform.KillAfterLastUse())
if
tvm.transform.PassContext.current().config.get("relax.backend.use_cuda_graph",
False):
passes.append(relax.transform.RewriteCUDAGraph())
+ passes.append(relax.transform.LowerAllocTensor())
+ passes.append(relax.transform.KillAfterLastUse())
+
passes.append(relax.transform.VMBuiltinLower())
passes.append(relax.transform.VMShapeLower())
passes.append(relax.transform.AttachGlobalSymbol())
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc
b/src/relax/backend/vm/vm_builtin_lower.cc
index 784b3c9fd5..887998d004 100644
--- a/src/relax/backend/vm/vm_builtin_lower.cc
+++ b/src/relax/backend/vm/vm_builtin_lower.cc
@@ -54,7 +54,10 @@ class VMBuiltinLowerMutator : public ExprMutator {
} else if (call->op == invoke_closure_op_) {
return InvokeClosure(call);
} else if (call->op == alloc_tensor_op_) {
- return MakeAllocTensor(call);
+ LOG(FATAL) << "VMBuiltinLower encountered " << call->op << " in
expression "
+ << GetRef<Call>(call_node) << ". "
+ << "This operation should have been lowered earlier "
+ << "using the 'relax.transform.LowerAllocTensor' pass.";
} else if (call->op == mem_alloc_storage_op_) {
return MakeMemAllocStorage(call);
} else if (call->op == mem_alloc_tensor_op_) {
@@ -66,38 +69,6 @@ class VMBuiltinLowerMutator : public ExprMutator {
}
}
- Expr ComputeStorageSize(const Expr& shape, const DataType& dtype) const {
- // Question: what if the dtype of tensor_type is unknown?
- // Symbolic/static shape case
- if (auto* shape_expr = shape.as<ShapeExprNode>()) {
- int64_t elem_bytes = runtime::GetVectorBytes(dtype);
- PrimExpr ret = IntImm(DataType::Int(64), elem_bytes);
- for (PrimExpr dim : shape_expr->values) {
- ret = ret * dim;
- }
- return ShapeExpr({ret});
- } else {
- return Call(builtin_compute_alloc_shape_, {shape, DataTypeImm(dtype)},
Attrs(),
- {GetStructInfo(shape)});
- }
- }
-
- Expr MakeAllocTensor(const Call& call) {
- ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]);
- DataTypeImm output_dtype = Downcast<DataTypeImm>(call->args[1]);
- DataType dtype = output_dtype->value;
- Expr storage_size = ComputeStorageSize(output_shape, dtype);
- PrimValue runtime_device_index = Downcast<PrimValue>(call->args[2]);
- Var storage = builder_->Emit(Call(vm_alloc_storage_op_,
- {storage_size, runtime_device_index,
- DataTypeImm(DataType::UInt(8)),
StringImm("global")},
- Attrs()),
- "storage");
- Expr shape = call->args[0];
- PrimValue offset = PrimValue::Int64(0);
- return Call(vm_alloc_tensor_op_, {storage, offset, shape,
DataTypeImm(dtype)}, Attrs());
- }
-
Expr MakeMemAllocStorage(const Call& call) {
PrimValue runtime_device_index = Downcast<PrimValue>(call->args[1]);
StringImm storage_scope = Downcast<StringImm>(call->args[2]);
diff --git a/src/relax/transform/lower_alloc_tensor.cc
b/src/relax/transform/lower_alloc_tensor.cc
new file mode 100644
index 0000000000..f0db2447d9
--- /dev/null
+++ b/src/relax/transform/lower_alloc_tensor.cc
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/lower_alloc_tensor.cc
+ * \brief Lower any relax.builtin.alloc_tensor remaining after static planning
+ */
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+
+namespace {
+class Mutator : public ExprMutator {
+ using ExprMutator::VisitExpr_;
+ Expr VisitExpr_(const CallNode* op) override {
+ static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+ static const Op& mem_alloc_storage_op =
Op::Get("relax.memory.alloc_storage");
+ static const Op& mem_alloc_tensor_op =
Op::Get("relax.memory.alloc_tensor");
+
+ if (op->op.same_as(alloc_tensor_op)) {
+ CHECK_EQ(op->args.size(), 3) << "Op " << op->op << " should have three
arguments, "
+ << "[shape, dtype, runtime_device_index]. "
+ << "However, received " << GetRef<Call>(op);
+
+ auto shape_arg = op->args[0];
+ auto dtype = Downcast<DataTypeImm>(op->args[1]);
+ PrimValue runtime_device_index = Downcast<PrimValue>(op->args[2]);
+ std::string storage_scope = "global";
+
+ auto shape = [&]() -> Array<PrimExpr> {
+ if (auto ptr = shape_arg.as<ShapeExprNode>()) {
+ return ptr->values;
+ }
+
+ auto sinfo = GetStructInfo(shape_arg);
+ if (auto ptr = sinfo.as<ShapeStructInfoNode>()) {
+ if (ptr->values) {
+ return ptr->values.value();
+ }
+ }
+
+ LOG(FATAL) << "Shape argument for " << alloc_tensor_op << " should be
a ShapeExpr, "
+ << "or a variable that holds a ShapeExpr. "
+ << "However, received argument " << shape_arg << " with
struct info " << sinfo;
+ }();
+
+ PrimExpr nbytes = [&]() -> PrimExpr {
+ PrimExpr nbytes = tir::make_const(DataType::Int(64),
dtype->value.bytes());
+ for (const auto& dim : shape) {
+ nbytes *= dim;
+ }
+ return nbytes;
+ }();
+
+ auto offset = PrimValue::Int64(0);
+
+ Expr storage = relax::Call(mem_alloc_storage_op,
+ {ShapeExpr({nbytes}), runtime_device_index,
+ StringImm(storage_scope),
DataTypeImm(DataType::UInt(8))});
+ storage = builder_->Emit(storage, "storage");
+ Expr tensor = relax::Call(mem_alloc_tensor_op, {storage, offset,
shape_arg, dtype});
+ return tensor;
+ } else {
+ return ExprMutator::VisitExpr_(op);
+ }
+ }
+};
+} // namespace
+
+Expr LowerAllocTensor(Expr expr) {
+ Mutator mutator;
+ return mutator(expr);
+}
+
+namespace transform {
+
+Pass LowerAllocTensor() {
+ runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func =
+ [=](Function func, IRModule m, PassContext pc) {
+ return Downcast<Function>(relax::LowerAllocTensor(std::move(func)));
+ };
+ return CreateFunctionPass(pass_func, /*opt_level=*/0, "LowerAllocTensor",
{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.LowerAllocTensor").set_body_typed(LowerAllocTensor);
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_lower_alloc_tensor.py
b/tests/python/relax/test_lower_alloc_tensor.py
new file mode 100644
index 0000000000..3d1415d245
--- /dev/null
+++ b/tests/python/relax/test_lower_alloc_tensor.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+
+from tvm.script import ir as I, relax as R
+
+from tvm.relax.transform import LowerAllocTensor
+
+
+def test_basic():
+ @I.ir_module
+ class Before:
+ @R.function
+ def main():
+ x = R.builtin.alloc_tensor(R.shape([16, 32]), "float32", 0)
+ return x
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main():
+ storage = R.memory.alloc_storage(R.shape([2048]), 0, "global",
"uint8")
+ x = R.memory.alloc_tensor(storage, 0, R.shape([16, 32]), "float32")
+ return x
+
+ After = LowerAllocTensor()(Before)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_transform.py
b/tests/python/relax/test_transform.py
index d690cc55e3..9ab2ffc605 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -19,7 +19,6 @@ import pytest
import tvm
from tvm import relax
from tvm.ir import structural_equal
-from tvm.ir.base import assert_structural_equal
import tvm.script
from tvm.script import tir as T, relax as R
@@ -485,41 +484,5 @@ def test_call_tir_inplace_all_new():
return gv0
-def test_vm_builtin_lower():
- @tvm.script.ir_module
- class TestVMBuiltinLower:
- @R.function
- def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
- # we expected RemovePurityChecking to have been called first
- R.func_attr({"relax.force_pure": True})
- m, n = T.int64(), T.int64()
- alloc = R.builtin.alloc_tensor(R.shape([m, n]),
runtime_device_index=0, dtype="float32")
- _ = R.call_packed(
- "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2,
dtype="float32"))
- )
- gv0 = alloc
- return gv0
-
- mod = TestVMBuiltinLower
-
- # after vm builtin lowering
- new_mod = relax.transform.VMBuiltinLower()(mod)
- func = new_mod["foo"]
-
- assert isinstance(new_mod, tvm.IRModule)
- assert isinstance(func, tvm.relax.expr.Function)
-
- block = func.body.blocks[0]
- s1 = block.bindings[0].value
- assert isinstance(s1, relax.Call)
- assert s1.op.name == "relax.vm.alloc_storage"
- s2 = block.bindings[1].value
- assert isinstance(s2, relax.Call)
- s3 = block.bindings[2].value
- assert isinstance(s3, relax.Call)
- assert isinstance(s3.op, relax.ExternFunc)
- assert s3.op.global_symbol == "test.op.identity"
-
-
if __name__ == "__main__":
- pytest.main([__file__])
+ tvm.testing.main()
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py
b/tests/python/relax/test_transform_static_plan_block_memory.py
index 147fc049c9..451fe4fbe2 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -190,6 +190,7 @@ def test_basic():
mod = relax.transform.StaticPlanBlockMemory()(Module)
tvm.ir.assert_structural_equal(mod, Expected)
+ mod = relax.transform.LowerAllocTensor()(mod)
mod = relax.transform.VMBuiltinLower()(mod)
tvm.ir.assert_structural_equal(mod, ExpectedLowered)
diff --git a/tests/python/relax/test_vm_builtin_lower.py
b/tests/python/relax/test_vm_builtin_lower.py
new file mode 100644
index 0000000000..df28db4d46
--- /dev/null
+++ b/tests/python/relax/test_vm_builtin_lower.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm
+from tvm import relax
+
+import tvm.script
+from tvm.script import ir as I, relax as R, tir as T
+
+
+def test_vm_builtin_lower_mem_alloc_storage():
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
+ R.func_attr({"relax.force_pure": True})
+ m, n = T.int64(), T.int64()
+
+ storage = R.memory.alloc_storage(R.shape([m * n * 4]), 0,
"global", "uint8")
+ alloc = R.memory.alloc_tensor(storage, 0, R.shape([m, n]),
"float32")
+ _ = R.call_packed(
+ "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2,
dtype="float32"))
+ )
+ gv0 = alloc
+ return gv0
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
+ # we expected RemovePurityChecking to have been called first
+ R.func_attr({"relax.force_pure": True})
+ m, n = T.int64(), T.int64()
+
+ storage = R.vm.alloc_storage(R.shape([m * n * 4]),
R.prim_value(0), "uint8", "global")
+ alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape([m,
n]), "float32")
+
+ _ = R.call_packed(
+ "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2,
dtype="float32"))
+ )
+ gv0 = alloc
+ return gv0
+
+ After = relax.transform.VMBuiltinLower()(Before)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
+def test_vm_builtin_alloc_tensor_raises_error():
+ """R.builtin.alloc_tensor should be handled earlier"""
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
+ R.func_attr({"relax.force_pure": True})
+ m, n = T.int64(), T.int64()
+
+ alloc = R.builtin.alloc_tensor(R.shape([m, n]),
runtime_device_index=0, dtype="float32")
+ _ = R.call_packed(
+ "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2,
dtype="float32"))
+ )
+ gv0 = alloc
+ return gv0
+
+ with pytest.raises(tvm.TVMError):
+ relax.transform.VMBuiltinLower()(Before)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()