This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 7ef36ebb5d [Unity] Support symbolic PrimValue arguments (#15980)
7ef36ebb5d is described below
commit 7ef36ebb5d056320676faede712f2052d92f7a5d
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Oct 25 15:50:35 2023 -0500
[Unity] Support symbolic PrimValue arguments (#15980)
Prior to this commit, all symbolic variables needed to be defined
either by tensor shapes, or by an explicit `tvm.runtime.ShapeTuple`
argument. This commit allows arguments `arg: R.Prim(value="n")` to
serve as a source of definition for symbolic variables.
---
src/relax/backend/vm/codegen_vm.cc | 7 +-
src/relax/backend/vm/vm_shape_lower.cc | 158 +++++++++++++++++++++++----------
src/runtime/ndarray.cc | 4 +-
src/runtime/relax_vm/builtin.cc | 88 ++++++++++++++++++
tests/python/relax/test_vm_build.py | 120 +++++++++++++++++++++++++
5 files changed, 325 insertions(+), 52 deletions(-)
diff --git a/src/relax/backend/vm/codegen_vm.cc
b/src/relax/backend/vm/codegen_vm.cc
index caee0a0c13..64b87c6c12 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -246,10 +246,11 @@ class CodeGenVM : public
ExprFunctor<Instruction::Arg(const Expr&)> {
Instruction::Arg VisitExpr_(const PrimValueNode* op) final {
+ // Only IntImm/FloatImm PrimValues are emitted as VM constants here;
+ // any other PrimValue is expected to have been rewritten by VMShapeLower.
if (auto* int_imm = op->value.as<IntImmNode>()) {
return builder_->ConvertConstant(int_imm->value);
- } else {
- auto* float_imm = op->value.as<FloatImmNode>();
- ICHECK(float_imm) << "PrimValue can only be IntImm/FloatImm for now";
+ } else if (auto* float_imm = op->value.as<FloatImmNode>()) {
return builder_->ConvertConstant(float_imm->value);
+ } else {
+ LOG(FATAL) << "PrimValue should only contain constant after
VMShapeLower, "
+ << "but received " << GetRef<Expr>(op) << " with type " <<
op->value->GetTypeKey();
}
}
diff --git a/src/relax/backend/vm/vm_shape_lower.cc
b/src/relax/backend/vm/vm_shape_lower.cc
index 8b8eb33f5b..41b27ea625 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -347,6 +347,41 @@ class VMShapeLowerMutator
return GetRef<Expr>(op);
}
+ /*!
+ * \brief Build the (code, value) argument pair that encodes one
+ * PrimExpr for make_shape / make_prim_value.
+ *
+ * \param expr The PrimExpr to encode.
+ *
+ * \return For an IntImm, (kUseImm, literal value). Otherwise,
+ * (kLoadShape, heap slot index) of the slot registered for `expr`
+ * in slot_map_. It is an internal error (ICHECK) if no slot exists
+ * or if the slot's value has not been computed yet.
+ */
+ std::pair<Expr, Expr> MakeSymbolicShapeArg(const PrimExpr& expr) {
+ using runtime::relax_vm::MakeShapeCode;
+
+ if (auto* int_expr = expr.as<IntImmNode>()) {
+ return {PrimValue::Int64(static_cast<int>(MakeShapeCode::kUseImm)),
+ PrimValue::Int64(int_expr->value)};
+ } else {
+ auto it = slot_map_.find(expr);
+ ICHECK(it != slot_map_.end());
+ auto* slot = it->second;
+ ICHECK(slot->value_computed) << "PrimExpr " << expr << " has not been
computed";
+ return {PrimValue::Int64(static_cast<int>(MakeShapeCode::kLoadShape)),
+ PrimValue::Int64(slot->index)};
+ }
+ }
+
+ Expr VisitExpr_(const PrimValueNode* op) final {
+ using runtime::relax_vm::MakeShapeCode;
+ // Constant PrimValues (IntImm/FloatImm) can be preserved as-is.
+ bool is_const_value =
+ op->value->IsInstance<IntImmNode>() ||
op->value->IsInstance<FloatImmNode>();
+ if (is_const_value) {
+ return GetRef<Expr>(op);
+ }
+
+ // Symbolic value: look up its shape-heap slot and emit a runtime
+ // make_prim_value call that produces the value.
+ Array<Expr> args = {shape_heap_};
+ auto [code, value_or_index] = MakeSymbolicShapeArg(op->value);
+ args.push_back(code);
+ args.push_back(value_or_index);
+
+ // make_prim_value(heap, code, value_or_index)
+ Call call(builtin_make_prim_value_, args, Attrs(),
{Downcast<StructInfo>(op->struct_info_)});
+ return call;
+ }
+
Expr VisitExpr_(const ShapeExprNode* op) final {
using runtime::relax_vm::MakeShapeCode;
// Constant shape can be preserved.
@@ -359,17 +394,9 @@ class VMShapeLowerMutator
Array<Expr> args = {shape_heap_,
PrimValue::Int64(static_cast<int64_t>(op->values.size()))};
for (PrimExpr expr : op->values) {
- if (auto* int_expr = expr.as<IntImmNode>()) {
-
args.push_back(PrimValue::Int64(static_cast<int>(MakeShapeCode::kUseImm)));
- args.push_back(PrimValue::Int64(int_expr->value));
- } else {
- auto it = slot_map_.find(expr);
- ICHECK(it != slot_map_.end());
- auto* slot = it->second;
- ICHECK(slot->value_computed) << "PrimExpr " << expr << " has not been
computed";
-
args.push_back(PrimValue::Int64(static_cast<int>(MakeShapeCode::kLoadShape)));
- args.push_back(PrimValue::Int64(slot->index));
- }
+ auto [code, value_or_index] = MakeSymbolicShapeArg(expr);
+ args.push_back(code);
+ args.push_back(value_or_index);
}
// make_shape(heap, n, c[0], r[0], c[1], r[1] ..., c[n], r[n])
@@ -402,6 +429,45 @@ class VMShapeLowerMutator
// Place this pass as last pass before codegen.
+ // Struct info fields are returned unchanged rather than rewritten.
StructInfo VisitExprDepStructInfoField(const StructInfo& sinfo) final {
return sinfo; }
+ /*! \brief Internal utility function used for RunMatch()
+ *
+ * \param expr The expression to be matched
+ *
+ * \param require_value_computed Whether we require all expr to be computed.
+ *
+ * \return The MatchShapeCode, and a relax expression specifying the
+ * argument used by that MatchShapeCode:
+ * - IntImm: kAssertEqualToImm with the literal value.
+ * - already-computed slot: kAssertEqualToLoad with the slot index.
+ * - uncomputed tir::Var: kStoreToHeap with the slot index.
+ * - any other uncomputed expression: kNoOp with 0.
+ *
+ * \note For an uncomputed tir::Var this function has a side effect: it
+ * marks the slot as computed and appends it to ready_vars_.
+ */
+ std::pair<runtime::relax_vm::MatchShapeCode, Expr> MakeMatchArgs(const
PrimExpr& expr,
+ bool
require_value_computed) {
+ using runtime::relax_vm::MatchShapeCode;
+
+ if (auto* int_expr = expr.as<IntImmNode>()) {
+ return {MatchShapeCode::kAssertEqualToImm,
PrimValue::Int64(int_expr->value)};
+ }
+
+ auto it = slot_map_.find(expr);
+ ICHECK(it != slot_map_.end());
+ auto* slot = it->second;
+ if (slot->value_computed) {
+ return {MatchShapeCode::kAssertEqualToLoad,
PrimValue::Int64(slot->index)};
+ }
+
+ // the value is not yet computed
+ ICHECK(!require_value_computed) << "PrimExpr " << expr << " is not
computed";
+ if (expr.as<tir::VarNode>()) {
+ // It is a var; we will populate it in this round.
+
+ slot->value_computed = true;
+ ready_vars_.push_back(slot);
+
+ return {MatchShapeCode::kStoreToHeap, PrimValue::Int64(slot->index)};
+ }
+
+ // otherwise, skip it; the caller (RunMatch) records the item as
+ // outstanding when any pattern element yields kNoOp.
+ return {MatchShapeCode::kNoOp, PrimValue::Int64(0)};
+ }
+
//-------------------------------------------------------
// Shape computations.
//-------------------------------------------------------
@@ -426,52 +492,33 @@ class VMShapeLowerMutator
using runtime::relax_vm::MatchShapeCode;
for (const MatchShapeTodoItem& item : match_todos) {
- int64_t shape_len = static_cast<int64_t>(item.pattern.size());
bool all_nop = true;
- int num_outstanding_exprs = 0;
+ bool any_nop = false;
- Array<Expr> args = {item.input, shape_heap_,
PrimValue::Int64(shape_len)};
+ Array<Expr> args = {item.input, shape_heap_};
+
+ Expr match_op;
+ if (item.input->struct_info_.as<PrimStructInfoNode>()) {
+ match_op = builtin_match_prim_value_;
+ ICHECK_EQ(item.pattern.size(), 1);
+ } else {
+ match_op = builtin_match_shape_;
+ args.push_back(PrimValue::Int64(item.pattern.size()));
+ }
for (PrimExpr expr : item.pattern) {
- MatchShapeCode code = MatchShapeCode::kNoOp;
- int64_t rvalue = 0;
- if (auto* int_expr = expr.as<IntImmNode>()) {
- code = MatchShapeCode::kAssertEqualToImm;
- rvalue = int_expr->value;
- } else {
- auto it = slot_map_.find(expr);
- ICHECK(it != slot_map_.end());
- auto* slot = it->second;
- if (slot->value_computed) {
- code = MatchShapeCode::kAssertEqualToLoad;
- rvalue = slot->index;
- } else {
- // the value is not yet computed
- ICHECK(!require_value_computed) << "PrimExpr " << expr << " is not
computed";
- if (expr.as<tir::VarNode>()) {
- // if it is a var, we will populate it in this round.
- // otherwise, we skip and mark it as outstanding
- code = MatchShapeCode::kStoreToHeap;
- rvalue = slot->index;
- slot->value_computed = true;
- ready_vars_.push_back(slot);
- } else {
- code = MatchShapeCode::kNoOp;
- rvalue = 0;
- ++num_outstanding_exprs;
- }
- }
- }
+ auto [code, rvalue] = MakeMatchArgs(expr, require_value_computed);
all_nop = all_nop && code == MatchShapeCode::kNoOp;
+ any_nop = any_nop || code == MatchShapeCode::kNoOp;
args.push_back(PrimValue::Int64(static_cast<int>(code)));
- args.push_back(PrimValue::Int64(rvalue));
+ args.push_back(rvalue);
}
- if (num_outstanding_exprs != 0) {
+ if (any_nop) {
outstanding_todos.push_back(item);
}
args.push_back(GetErrContext(item.err_ctx));
if (!all_nop) {
- Call call(builtin_match_shape_, args, Attrs(), {void_sinfo_});
+ Call call(match_op, args, Attrs(), {void_sinfo_});
builder_->Emit(call, "_");
}
}
@@ -592,8 +639,20 @@ class VMShapeLowerMutator
void VisitStructInfo_(const PrimStructInfoNode* op, Expr value, bool
always_check,
bool dynamic_only, const String& err_ctx,
std::vector<MatchShapeTodoItem>* match_todos) final {
- // TODO(relax-team) add PrimValue checks later.
- LOG(FATAL) << "MatchCast of PrimValue is not yet supported";
+ // emit runtime check that `value` is a PrimValue of the expected dtype
+ if (always_check || !IsBaseOf(PrimStructInfo(op->dtype),
GetStructInfo(value))) {
+ // check_prim_value_info(value, dtype, err_ctx)
+ Call call(builtin_check_prim_value_info_,
+ {value, DataTypeImm(op->dtype), GetErrContext(err_ctx)},
Attrs(), {void_sinfo_});
+ builder_->Emit(call, "_");
+ }
+ // If the struct info carries a known value (e.g. R.Prim(value="n")),
+ // queue a single-element match so RunMatch can define or check it.
+ if (op->value.defined()) {
+ MatchShapeTodoItem item;
+ item.input = value;
+ item.pattern = {op->value.value()};
+ item.err_ctx = err_ctx;
+ match_todos->push_back(item);
+ }
}
void VisitStructInfo_(const ShapeStructInfoNode* op, Expr value, bool
always_check,
@@ -729,6 +788,9 @@ class VMShapeLowerMutator
const ExternFunc builtin_match_shape_{"vm.builtin.match_shape"};
const ExternFunc builtin_make_shape_{"vm.builtin.make_shape"};
const ExternFunc builtin_check_shape_info_{"vm.builtin.check_shape_info"};
+ const ExternFunc builtin_match_prim_value_{"vm.builtin.match_prim_value"};
+ const ExternFunc builtin_make_prim_value_{"vm.builtin.make_prim_value"};
+ const ExternFunc
builtin_check_prim_value_info_{"vm.builtin.check_prim_value_info"};
const ExternFunc builtin_check_tensor_info_{"vm.builtin.check_tensor_info"};
const ExternFunc builtin_check_tuple_info_{"vm.builtin.check_tuple_info"};
const ExternFunc builtin_check_func_info_{"vm.builtin.check_func_info"};
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index b7153ab50f..e47a399ae5 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -305,7 +305,9 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor*
to, TVMStreamHandle str
DeviceAPI::Get(dev)->CopyDataFromTo(const_cast<DLTensor*>(from), to, stream);
}
-ShapeTuple NDArray::Shape() const { return get_mutable()->shape_; }
+// Read shape_ through a const Container pointer rather than get_mutable(),
+// since this accessor only reads the shape.
+ShapeTuple NDArray::Shape() const {
+ return static_cast<const NDArray::Container*>(data_.get())->shape_;
+}
runtime::DataType NDArray::DataType() const {
return runtime::DataType(get_mutable()->dl_tensor.dtype);
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 8b27bb2d9e..a764c34cfa 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -66,6 +66,46 @@ NDArray AllocShapeHeap(void* ctx_ptr, int64_t size) {
TVM_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap").set_body_typed(AllocShapeHeap);
+/*!
+ * \brief Builtin match R.Prim function.
+ *
+ * \param input_value The runtime value provided by the user. It is
+ * received as int64_t, so only integer PrimValues are matched here.
+ *
+ * \param heap The VM storage for symbolic shapes. May be nullptr, in
+ * which case kStoreToHeap and kAssertEqualToLoad would dereference a
+ * null pointer; presumably the compiler only emits those codes when a
+ * shape heap was allocated (NOTE(review): confirm).
+ *
+ * \param code_value The op code, defined in MatchShapeCode,
+ * indicating how this value should be interpreted.
+ *
+ * \param reg The register, if using kStoreToHeap or
+ * kAssertEqualToLoad, or a literal value if using kAssertEqualToImm
+ *
+ * \param err_ctx An optional string used in error messages, providing
+ * additional context
+ *
+ * \sa MatchShape
+ */
+void MatchPrimValue(int64_t input_value, DLTensor* heap, int code_value,
int64_t reg,
+ Optional<String> err_ctx) {
+ int64_t* heap_data = heap == nullptr ? nullptr :
static_cast<int64_t*>(heap->data);
+ MatchShapeCode code = static_cast<MatchShapeCode>(code_value);
+
+ if (code == MatchShapeCode::kAssertEqualToImm) {
+ CHECK_EQ(input_value, reg) << "RuntimeError: " << err_ctx.value_or("") <<
" match_cast error, "
+ << " PrimValue mismatch to specified constant.";
+ } else if (code == MatchShapeCode::kStoreToHeap) {
+ heap_data[reg] = input_value;
+ } else if (code == MatchShapeCode::kNoOp) {
+ // Nothing to check or store; the value is matched in a later round.
+ } else if (code == MatchShapeCode::kAssertEqualToLoad) {
+ CHECK_EQ(input_value, heap_data[reg])
+ << "RuntimeError: " << err_ctx.value_or("") << " match_cast error, "
+ << " PrimValue mismatch to a previous populated value.";
+ } else {
+ LOG(FATAL) << "Unknown match shape code: " << static_cast<int>(code);
+ }
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.match_prim_value").set_body_typed(MatchPrimValue);
+
/*!
* \brief Builtin match shape function.
* \param args The packed function arguments.
@@ -117,6 +157,30 @@ void MatchShape(TVMArgs args, TVMRetValue* rv) {
TVM_REGISTER_GLOBAL("vm.builtin.match_shape").set_body(MatchShape);
+/*!
+ * \brief Builtin make prim value function.
+ * \param heap The shape heap to use. May be nullptr, in which case
+ * kLoadShape would dereference a null pointer.
+ * \param shape_code The MakeShapeCode, indicating how `reg` should be
+ * interpreted.
+ * \param reg A literal value if using kUseImm, or an index into the
+ * shape heap if using kLoadShape.
+ * \return The resulting int64 value.
+ *
+ * \sa MakeShape
+ */
+int64_t MakePrimValue(DLTensor* heap, int shape_code, int64_t reg) {
+ // NOTE: heap can be nullptr
+ int64_t* heap_data = heap == nullptr ? nullptr :
static_cast<int64_t*>(heap->data);
+
+ MakeShapeCode code = static_cast<MakeShapeCode>(shape_code);
+ if (code == MakeShapeCode::kUseImm) {
+ return reg;
+ } else if (code == MakeShapeCode::kLoadShape) {
+ return heap_data[reg];
+ } else {
+ LOG(FATAL) << "Invalid shape code: " << shape_code;
+ }
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.make_prim_value").set_body_typed(MakePrimValue);
+
/*!
* \brief Builtin make shape function.
* \param args The packed function arguments.
@@ -208,6 +272,30 @@ void CheckShapeInfo(ObjectRef arg, int ndim,
Optional<String> err_ctx) {
TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo);
+/*!
+ * \brief Builtin function to check if arg is PrimValue(dtype)
+ *
+ * The check converts `arg` to the C++ type matching `dtype` and discards
+ * the result; presumably the TVMArgValue conversion operators raise on a
+ * type mismatch (verify against the TVMArgValue implementation).
+ *
+ * \param arg The input argument.
+ * \param dtype Expected dtype of the PrimValue.
+ * NOTE(review): DataType::Void() (unknown dtype) has no dedicated
+ * branch below; confirm whether it is accepted by one of the checks
+ * or reaches the "unsupported dtype" error.
+ * \param err_ctx Additional context if error occurs.
+ */
+void CheckPrimValueInfo(TVMArgValue arg, DataType dtype, Optional<String>
err_ctx) {
+ if (dtype.is_bool()) {
+ arg.operator bool();
+ } else if (dtype.is_int()) {
+ arg.operator int64_t();
+ } else if (dtype.is_uint()) {
+ arg.operator uint64_t();
+ } else if (dtype.is_float()) {
+ arg.operator double();
+ } else if (dtype.is_handle()) {
+ arg.operator void*();
+ } else {
+ LOG(FATAL) << "TypeError: " << err_ctx.value_or("") << ", unsupported
dtype " << dtype;
+ }
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.check_prim_value_info").set_body_typed(CheckPrimValueInfo);
+
/*!
* \brief Builtin function to check if arg is Tuple with size elements.
* \param arg The input argument.
diff --git a/tests/python/relax/test_vm_build.py
b/tests/python/relax/test_vm_build.py
index 82a6d6a2a4..b4816fd096 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -29,6 +29,7 @@ from tvm.contrib import utils, cc, popen_pool
from tvm.relax.testing import nn
from tvm.script import relax as R, tir as T, ir as I
from tvm.relax.testing.vm import check_saved_func
+from tvm.runtime import ShapeTuple
EXEC_MODE = ["bytecode", "compiled"]
@@ -515,6 +516,125 @@ def test_vm_relax_symbolic_shape(exec_mode):
tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7,
atol=1e-7)
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_relax_symbolic_shape_tuple(exec_mode):
+    """Symbolic variables may be defined by a ShapeTuple argument"""
+
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(shape: R.Shape(["m", "n"])):
+            m = T.int64()
+            n = T.int64()
+            return R.shape([2 * m, 3 * n])
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    func = vm["main"]
+
+    assert func(ShapeTuple([2, 3])) == [4, 9]
+
+    # Shape with the wrong number of dimensions
+    with pytest.raises(ValueError):
+        func(ShapeTuple([2, 3, 4]))
+
+    # Argument that is not a shape at all
+    with pytest.raises(TypeError):
+        func(R.prim_value(2))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_relax_symbolic_prim_value(exec_mode):
+    """Symbolic variables may be defined by an R.Prim argument"""
+
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(shape: R.Prim(value="n")):
+            n = T.int64()
+            return R.prim_value(n * n)
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    func = vm["main"]
+
+    assert func(2) == 4
+
+    with pytest.raises(tvm.TVMError):
+        func(ShapeTuple([2]))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_relax_multiple_symbolic_prim_value(exec_mode):
+    """Like test_vm_relax_symbolic_prim_value, but with multiple variables"""
+
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(
+            # Provides definition of "n"
+            _n: R.Prim(value="n"),
+            # Requires definitions of both "n" and "m", but cannot
+            # provide either.
+            _shape: R.Shape(["n*2", "m*2"]),
+            # Provides definition of "m"
+            _m: R.Prim(value="m"),
+        ):
+            n = T.int64()
+            m = T.int64()
+            return R.shape([n * n, m + 1])
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    func = vm["main"]
+
+    assert func(2, ShapeTuple([4, 12]), 6) == [4, 7]
+
+    # "m*2" (== 2) does not match the shape's second dimension (12)
+    with pytest.raises(RuntimeError):
+        func(2, ShapeTuple([4, 12]), 1)
+
+    with pytest.raises(tvm.TVMError):
+        func(ShapeTuple([2]))
+
+
+@pytest.mark.xfail(reason="Current support for R.Prim with known value is primarily for int64")
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_relax_prim_value_fp32(exec_mode):
+    """A PrimValue may be R.Prim('float32')
+
+    Unlike shape tuples, which must contain int64, a PrimValue may be
+    any type that can be represented as a single primitive value.
+    """
+
+    @I.ir_module
+    class mod:
+        @R.function
+        def main(
+            # First failure occurs during parsing.  The syntactic
+            # sugar for symbolic variables assumes that all symbolic
+            # variables are int64, rather than using the type that is
+            # later declared.
+            _x: R.Prim(value="half_fill_value"),
+        ):
+            half_fill_value = T.float32()
+            # Second failure occurs when calling `relax.op.full`.  The
+            # `fill_value` is expected to be a scalar constant
+            # (R.Tensor with 0-dim shape), not a primitive value, even
+            # though these are semantically the same.
+            return R.full(shape=[16, 16], fill_value=R.prim_value(2 * half_fill_value))
+
+    target = tvm.target.Target("llvm", host="llvm")
+    # Third failure occurs here.  The current codegen assumes that all
+    # symbolic variables are int64.
+    ex = relax.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    func = vm["main"]
+
+    res = func(16.0).numpy()
+    assert np.all(res == 32.0)
+
+
@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_vm_relax_dyn_tir_shape(exec_mode):
# case where TIR variables are unbound in generated PrimFunc