This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 491a0f69aa [Relax] Require correct input/output shapes `R.call_tir`
(#17285)
491a0f69aa is described below
commit 491a0f69aabcf812cc552df7666038414ca79a8f
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Sep 6 08:32:31 2024 -0500
[Relax] Require correct input/output shapes `R.call_tir` (#17285)
Prior to this commit, the Relax well-formed checker validated
arguments provided to Relax functions, but did not validate arguments
provided to `R.call_tir`. As a result, incorrect arguments from Relax
to TIR would not be checked until runtime, if at all.
This commit updates the well-formed checker to verify that
`R.call_tir` has received the correct arguments, and has the correct
output shape specified in the `out_sinfo` parameter.
Initial implementation performed the validation as part of
`FNormalize`, to maximize coverage of this check. This increased
end-to-end compilation time by ~10%, and so the check was requested to
be restricted to the well-formed checker. Expensive operator-specific
validation is now performed in the new `FValidate` attribute.
---
include/tvm/relax/op_attr_types.h | 27 ++
src/relax/analysis/well_formed.cc | 11 +
src/relax/op/op.cc | 291 +++++++++++-
src/relax/transform/fuse_tir.cc | 3 +-
...est_distributed_transform_propagate_sharding.py | 8 -
tests/python/relax/test_analysis_well_formed.py | 514 ++++++++++++++++++++-
tests/python/relax/test_ast_printer.py | 9 +-
tests/python/relax/test_dataflow_inplace.py | 10 +-
tests/python/relax/test_dataflow_pattern.py | 2 +-
tests/python/relax/test_frontend_dynamo.py | 7 +-
tests/python/relax/test_frontend_nn_op.py | 18 +-
tests/python/relax/test_transform.py | 6 +-
.../relax/test_transform_dead_code_elimination.py | 30 +-
tests/python/relax/test_transform_fuse_ops.py | 8 +-
.../relax/test_transform_fuse_ops_by_pattern.py | 18 +-
.../relax/test_transform_lazy_transform_params.py | 20 +-
.../test_transform_rewrite_dataflow_reshape.py | 25 +-
tests/python/relax/test_tvmscript_parser.py | 15 +-
tests/python/relax/test_vm_build.py | 12 +-
19 files changed, 928 insertions(+), 106 deletions(-)
diff --git a/include/tvm/relax/op_attr_types.h
b/include/tvm/relax/op_attr_types.h
index 291bee597c..0ddc2baefb 100644
--- a/include/tvm/relax/op_attr_types.h
+++ b/include/tvm/relax/op_attr_types.h
@@ -56,6 +56,14 @@ using FCallPacked = String;
* expressed in multiple syntactically valid and semantically
* equivalent forms, to normalize to a single representation.
*
+ * Note: `FNormalize` is applied for each expression as part of the
+ * `relax::BlockBuilder`. While operator-specific validation may
+ * be performed within the `FNormalize` implementation, ensuring
+ * that errors are caught as early as possible, this should only be
+ * used when validation is fast to apply. If the validation logic
+ * may be slow, it should instead be implemented in `FValidate`,
+ * which is only run as part of the well-formed checker.
+ *
* \param bb The BlockBuilder context.
*
* \param call The call to be normalized. It is provided by-value, to
@@ -63,6 +71,25 @@ using FCallPacked = String;
*/
using FNormalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, Call
call)>;
+/*!
+ * \brief The function type of a validation function.
+ *
+ * A validation function is used to define constraints that should be
+ * verified for an operator as part of the well-formed checker.
+ *
+ * Note: `FValidate` is only applied as part of the well-formed
+ * checker. While this minimizes overhead while compiling Relax,
+ * this delay between generating an ill-formed `relax::Call` and
+ * identifying the ill-formed call may complicate debugging. If
+ * the validation logic is very fast to check, and doing so would
+ * not introduce a significant overhead, consider validating as part
+ * of `FNormalize`, which is applied by the block builder for each
+ * `relax::Call`.
+ *
+ * \param call The call to be validated.
+ */
+using FValidate = runtime::TypedPackedFunc<void(const Call& call)>;
+
/*! \brief The function type of a legalization function.
*
* A legalization function is used to replace a `relax::Call` with
diff --git a/src/relax/analysis/well_formed.cc
b/src/relax/analysis/well_formed.cc
index 626fadda27..235059ece2 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -352,6 +352,16 @@ class WellFormedChecker : public relax::ExprVisitor,
<< after_normalize);
}
}
+
+ if (auto func_validate = op_map_validate_.get(call->op, nullptr);
func_validate != nullptr) {
+ try {
+ func_validate(GetRef<Call>(call));
+ } catch (std::exception& err) {
+ Malformed(Diagnostic::Error(call) << "Operator-specific validation
(FValidate) for "
+ << call->op << " identified error:
\n"
+ << err.what());
+ }
+ }
}
void VisitExpr_(const IfNode* op) final {
@@ -574,6 +584,7 @@ class WellFormedChecker : public relax::ExprVisitor,
std::unordered_map<tir::Var, const FunctionNode*> symbolic_var_func_map_;
tvm::OpAttrMap<FNormalize> op_map_normalize_ =
Op::GetAttrMap<FNormalize>("FNormalize");
+ tvm::OpAttrMap<FValidate> op_map_validate_ =
Op::GetAttrMap<FValidate>("FValidate");
};
bool WellFormed(Variant<IRModule, Function> obj, bool check_struct_info) {
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 0a840248ff..3e0f0eba31 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -18,6 +18,7 @@
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/op.h>
+#include <tvm/relax/distributed/struct_info.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/utils.h>
#include <tvm/relay/op.h>
@@ -242,15 +243,195 @@
TVM_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInpla
// call_tir
+/* If possible, infer a legal value of `out_sinfo`
+ *
+ * The `R.call_tir` operator and its variants accept an `out_sinfo`
+ * parameter, which specifies the shape of the tensor or tensors
+ * returned by a PrimFunc. This output shape must be compatible with
+ * the shape defined by the PrimFunc's signature.
+ *
+ * For dynamic shapes, it is not always possible to infer the output
+ * of a TIR PrimFunc from its inputs. For example, a PrimFunc that
+ * accepts input buffer `T.Buffer([16], "float32")` and output buffer
+ * `T.Buffer([M, N], "float32")` infers the values of `M` and `N` from
+ * the shape of the provided output buffer.
+ *
+ * If the arguments provided are not compatible with the PrimFunc's
+ * signature, an error will be raised. If the arguments are
+ * compatible with the PrimFunc's signature, but are not sufficient to
+ * determine the output's StructInfo, then `NullOpt` will be returned.
+ *
+ * \param func_sinfo The StructInfo of the TIR callee.
+ * \param arg_sinfo The StructInfo of the argument tuple.
+ * \param packed_ints_sinfo The StructInfo of the ShapeTuple argument,
+ * if present.
+ * \param opt_inplace_indices For `R.call_tir_inplace`, an array of
+ * indices indicating which outputs are constructed from in-place
+ * mutation of the inputs. See
+ * `CallTIRInplaceAttrs::inplace_indices` for more details.
+ *
+ * \return The inferred `out_sinfo`, if it can be inferred from the arguments.
+ * Otherwise, NullOpt.
+ */
+static Optional<StructInfo> InferCallTIROutputStructInfoFromArguments(
+ StructInfo func_sinfo, StructInfo arg_sinfo, Optional<StructInfo>
packed_ints_sinfo,
+ Optional<Array<Integer>> opt_inplace_indices) {
+ auto opt_callee_sinfo = func_sinfo.as<FuncStructInfo>();
+ CHECK(opt_callee_sinfo) << "TypeError: "
+ << "The first argument to `R.call_tir` must be a
function, "
+ << "but instead received argument of type " <<
func_sinfo;
+ auto callee_sinfo = opt_callee_sinfo.value();
+
+ CHECK(callee_sinfo->params.defined())
+ << "ValueError: "
+ << "The first argument to `R.call_tir` must be a function "
+ << "with known argument types. "
+ << "However, the first argument was of type " << callee_sinfo;
+ auto callee_params = callee_sinfo->params.value();
+
+ const TupleStructInfoNode* args = arg_sinfo.as<TupleStructInfoNode>();
+ CHECK(args) << "TypeError: "
+ << "The second argument to `R.call_tir` must be a tuple, "
+ << "but instead received expression of type " << arg_sinfo;
+
+ // R.call_tir expects the PrimFunc to have three groups of arguments.
+ //
+ // 1. Input arguments that are explicitly provided as Relax arguments.
+ // 2. Output tensor arguments.
+ // 3. Shape arguments, represented as `T.int64` in the PrimFunc, and
+ // as an optional ShapeExpr argument in the `relax::Call` node.
+ //
+ // In order to determine the return type of `R.call_tir`, we must
+ // identify the PrimFunc arguments that will be in group (2).
+ size_t num_input_arguments = args->fields.size();
+ size_t num_trailing_int_arguments = 0;
+ const ShapeStructInfoNode* packed_tuple_sinfo = nullptr;
+ if (packed_ints_sinfo) {
+ auto packed_sinfo = packed_ints_sinfo.value();
+ packed_tuple_sinfo = packed_sinfo.as<ShapeStructInfoNode>();
+ CHECK(packed_tuple_sinfo && !packed_tuple_sinfo->IsUnknownNdim())
+ << "TypeError: "
+ << "The third argument to `R.call_tir`, if present, "
+ << "must be a ShapeTuple with known dimensionality. "
+ << "However, the argument received was of type " << packed_sinfo;
+ num_trailing_int_arguments = packed_tuple_sinfo->ndim;
+ } else {
+ num_trailing_int_arguments = 0;
+ }
+
+ CHECK_LE(num_input_arguments + num_trailing_int_arguments,
callee_params.size())
+ << "ValueError: "
+ << "R.call_tir attempted to call a function using " <<
num_input_arguments
+ << " input arguments and " << num_trailing_int_arguments << " trailing
integer arguments. "
+ << "However, the callee only accepts " << callee_params.size() << "
arguments in total.";
+
+ // While Relax can specify a distributed tensor, TIR cannot. The
+ // current implementation does not support determining the output
+ // shape for `R.dist.call_tir` calls, as it depends on the lowering
+ // of DistIR into regular Relax.
+ std::function<bool(StructInfo)> contains_dtensor =
[&contains_dtensor](StructInfo sinfo) -> bool {
+ if (sinfo.as<distributed::DTensorStructInfoNode>()) {
+ return true;
+ } else if (auto tuple = sinfo.as<TupleStructInfoNode>()) {
+ return std::any_of(tuple->fields.begin(), tuple->fields.end(),
contains_dtensor);
+ } else {
+ return false;
+ }
+ };
+ if (contains_dtensor(arg_sinfo)) {
+ return NullOpt;
+ }
+
+ // At this point, the return types are known. However, the shapes
+ // in `callee_params` may contain dynamic shape parameters that are
+ // not present in the caller's scope. The `DeriveCallRetStructInfo`
+ // utility can infer the value of dynamic parameters in
+ // `FuncStructInfoNode::ret` based on definitions in
+ // `FuncStructInfoNode::params`, inferring the correct values in the
+ // caller's scope.
+ //
+ // Since the callee of `R.call_tir` is provided with output
+ // arguments, where `DeriveCallRetStructInfo` requires a callee that
+ // produces its own outputs, a dummy function signature and
+ // arguments are used.
+
+ auto dummy_callee_sinfo = [&]() -> FuncStructInfo {
+ Array<StructInfo> dummy_params(callee_params.begin(),
+ callee_params.begin() +
num_input_arguments);
+
+ for (size_t i = callee_params.size() - num_trailing_int_arguments; i <
callee_params.size();
+ i++) {
+ dummy_params.push_back(callee_params[i]);
+ }
+
+ Array<StructInfo> dummy_ret(callee_params.begin() + num_input_arguments,
+ callee_params.end() -
num_trailing_int_arguments);
+
+ if (opt_inplace_indices) {
+ // For R.call_tir_inplace, the `inplace_indices` are used to
+ // indicate which elements of the `out_sinfo` will be generated
+ // as in-place mutation from an input. For any in-place
+ // mutation, the parameter's StructInfo must be inserted into
+ // `out_sinfo`.
+ auto inplace_indices = opt_inplace_indices.value();
+ for (size_t i = 0; i < inplace_indices.size(); i++) {
+ auto inplace_input_index = inplace_indices[i]->value;
+ if (inplace_input_index >= 0) {
+ dummy_ret.insert(dummy_ret.begin() + i,
callee_params[inplace_input_index]);
+ }
+ }
+ }
+
+ auto dummy_out_sinfo = [&]() -> StructInfo {
+ if (dummy_ret.size() == 1) {
+ return dummy_ret[0];
+ } else {
+ return TupleStructInfo(dummy_ret);
+ }
+ }();
+
+ return FuncStructInfo(dummy_params, dummy_out_sinfo);
+ }();
+
+ auto dummy_args = [&]() -> Array<Expr> {
+ Array<Expr> dummy_args = args->fields.Map(
+ [](const StructInfo& sinfo) -> Expr { return Var("dummy_leading_arg",
sinfo); });
+
+ for (size_t i = 0; i < num_trailing_int_arguments; i++) {
+ ICHECK(packed_tuple_sinfo);
+ PrimStructInfo dummy_arg_sinfo = [&]() {
+ if (packed_tuple_sinfo->values) {
+ return PrimStructInfo(packed_tuple_sinfo->values.value()[i]);
+ } else {
+ return PrimStructInfo(DataType::Int(64));
+ }
+ }();
+ dummy_args.push_back(Var("dummy_trailing_arg", dummy_arg_sinfo));
+ }
+
+ return dummy_args;
+ }();
+
+ auto derived_ret_sinfo = DeriveCallRetStructInfo(
+ dummy_callee_sinfo, Call(Var("dummy_callee", dummy_callee_sinfo),
dummy_args),
+ BlockBuilder::Create(NullOpt));
+
+ return derived_ret_sinfo;
+}
+
StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) {
if (call->sinfo_args.size() != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "sinfo_args should have exactly 1 output struct
info.");
}
CHECK(call->args[0]->IsInstance<GlobalVarNode>())
- << "call_tir expects the first argument to be a GlobalVar referring to a
TIR PrimFunc. "
- << "However, gets " << call->args[0];
- return call->sinfo_args[0];
+ << "R.call_tir expects the first argument to be a GlobalVar referring to
a TIR PrimFunc. "
+ << "However, the argument " << call->args[0] << " instead has type "
+ << call->args[0]->GetTypeKey();
+
+ StructInfo explicit_sinfo = call->sinfo_args[0];
+
+ return explicit_sinfo;
}
Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) {
@@ -264,23 +445,37 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call)
{
<< "or three arguments [callee, arg_tuple, tir_args], "
<< "but " << call << " has " << call->args.size() << " arguments.";
- Expr arg_expr = call->args[1];
+ auto callee = call->args[0];
+ CHECK(callee->struct_info_.as<FuncStructInfoNode>())
+ << "Operation " << call->op << " expects the first argument to be a TIR
callee. "
+ << "However, the first argument " << callee << " has struct info " <<
callee->struct_info_;
- CHECK(arg_expr->struct_info_.as<TupleStructInfoNode>())
- << "Operation " << call->op << " expects the second argument to be a
tuple of relax Expr. "
- << "However, the second argument " << arg_expr << " has struct info "
- << arg_expr->struct_info_ << ".";
+ Expr arg_tuple = call->args[1];
- if (arg_expr.as<TupleNode>()) {
- return std::move(call);
- }
+ CHECK(arg_tuple->struct_info_.as<TupleStructInfoNode>())
+ << "Operation " << call->op << " expects the second argument to be a
tuple of relax Expr. "
+ << "However, the second argument " << arg_tuple << " has struct info "
+ << arg_tuple->struct_info_ << ".";
- CHECK(arg_expr.as<VarNode>())
+ CHECK(arg_tuple.as<TupleNode>() || arg_tuple.as<VarNode>())
<< "Operation " << call->op << " must hold its arguments as an in-line
tuple. "
- << "However, " << call << " has arguments " << arg_expr
+ << "However, " << call << " has arguments " << arg_tuple
<< ", which is neither an in-line tuple, "
<< "nor a variable binding that may be normalized to an in-line tuple.";
+ if (call->args.size() > 2) {
+ Expr packed_ints = call->args[2];
+ CHECK(packed_ints->struct_info_.as<ShapeStructInfoNode>())
+ << "Operation " << call->op << " expects the optional third argument, "
+ << "if present, to be a ShapeTuple. "
+ << "However, the third argument " << packed_ints << " has struct info "
+ << packed_ints->struct_info_;
+ }
+
+ CHECK_EQ(call->sinfo_args.size(), 1)
+ << "R.call_tir should have exactly one `sinfo_args` parameter, "
+ << "which defines the output of the PrimFunc.";
+
auto unwrap_binding = [&ctx](Expr expr) -> Optional<Expr> {
if (auto var = expr.as<Var>()) {
if (auto bound_value = ctx->LookupBinding(var.value())) {
@@ -290,14 +485,21 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call)
{
return NullOpt;
};
- while (auto unwrapped = unwrap_binding(arg_expr)) {
- arg_expr = unwrapped.value();
- }
+ Tuple new_arg_tuple = [&]() {
+ // No replacement required. The argument tuple is already
+ // provided as an in-line tuple.
+ if (auto opt = arg_tuple.as<Tuple>()) {
+ return opt.value();
+ }
+
+ Expr unwrapped_tuple = arg_tuple;
+ while (auto unwrapped = unwrap_binding(unwrapped_tuple)) {
+ unwrapped_tuple = unwrapped.value();
+ }
- Tuple new_arg_expr = [&]() {
// Preferred replacement. The argument tuple is provided as a
// variable, but we know the value bound to that variable.
- if (auto opt = arg_expr.as<Tuple>()) {
+ if (auto opt = unwrapped_tuple.as<Tuple>()) {
return opt.value();
}
@@ -306,20 +508,60 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call)
{
// example, if a relax function accepted a tuple as an parameter,
// then provided that same tuple as an argument to call_tir.
Array<Expr> tuple_elements;
- size_t num_fields =
Downcast<TupleStructInfo>(arg_expr->struct_info_)->fields.size();
+ size_t num_fields =
Downcast<TupleStructInfo>(arg_tuple->struct_info_)->fields.size();
for (size_t i = 0; i < num_fields; i++) {
- tuple_elements.push_back(TupleGetItem(arg_expr, i));
+ tuple_elements.push_back(TupleGetItem(arg_tuple, i));
}
return Tuple(tuple_elements);
}();
- auto new_args = call->args;
- new_args.Set(1, new_arg_expr);
- call.CopyOnWrite()->args = new_args;
+ if (!new_arg_tuple.same_as(arg_tuple)) {
+ auto new_args = call->args;
+ new_args.Set(1, new_arg_tuple);
+ call.CopyOnWrite()->args = new_args;
+ }
return std::move(call);
}
+void ValidateCallTIR(Call call) {
+ // This function is used for validation of `relax.call_tir`,
+ // along with the variants `relax.call_tir_with_grad` and
+ // `relax.call_tir_inplace`. Therefore, all error messages should
+ // be written in terms of `call->op`, and should not explicitly
+ * reference the `relax.call_tir` operator.
+
+ auto callee = call->args[0];
+ Expr arg_tuple = call->args[1];
+
+ auto packed_int_sinfo = [&]() -> Optional<StructInfo> {
+ if (call->args.size() <= 2) {
+ return NullOpt;
+ } else {
+ return GetStructInfo(call->args[2]);
+ }
+ }();
+
+ auto opt_inplace_indices = [&]() -> Optional<Array<Integer>> {
+ if (const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>()) {
+ return attrs->inplace_indices;
+ } else {
+ return NullOpt;
+ }
+ }();
+
+ StructInfo explicit_sinfo = call->sinfo_args[0];
+ auto inferred_sinfo = InferCallTIROutputStructInfoFromArguments(
+ GetStructInfo(callee), GetStructInfo(arg_tuple), packed_int_sinfo,
opt_inplace_indices);
+ if (inferred_sinfo.defined()) {
+ CHECK(IsBaseOf(inferred_sinfo.value(), explicit_sinfo))
+ << "TypeError: "
+ << "The `out_sinfo` argument for R.call_tir must be compatible with
the PrimFunc. "
+ << "However, the PrimFunc's signature implies that the output should
be " << inferred_sinfo
+ << ", but the `out_sinfo` argument was " << explicit_sinfo;
+ }
+}
+
RELAY_REGISTER_OP("relax.call_tir")
.set_num_inputs(3)
.add_argument("func", "Expr", "The destination-passing-style function.")
@@ -329,6 +571,7 @@ RELAY_REGISTER_OP("relax.call_tir")
"args if unused")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
.set_attr<FNormalize>("FNormalize", NormalizeCallTIR)
+ .set_attr<FValidate>("FValidate", ValidateCallTIR)
.set_attr<Bool>("FPurity", Bool(true));
Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
@@ -374,6 +617,7 @@ RELAY_REGISTER_OP("relax.call_tir_with_grad")
"args if unused")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
.set_attr<FNormalize>("FNormalize", NormalizeCallTIR)
+ .set_attr<FValidate>("FValidate", ValidateCallTIR)
.set_attr<Bool>("FPurity", Bool(true));
Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array<TensorStructInfo>
out_sinfo_list,
@@ -514,6 +758,7 @@ RELAY_REGISTER_OP("relax.call_tir_inplace")
"args if unused")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
.set_attr<FNormalize>("FNormalize", NormalizeCallTIRInPlace)
+ .set_attr<FValidate>("FValidate", ValidateCallTIR)
// Warning: considered pure, but it has the potential to create visible
effects!
// This should only be used if it has been *checked* that it is safe (no
aliases, in-place
// arguments will no longer be live)
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index b203b322ab..612e1459c8 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -1088,8 +1088,7 @@ class TIRFuseMutator : public ExprMutator {
const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod,
old_gvar);
GlobalVar new_gvar(old_gvar->name_hint);
- UpdateStructInfo(new_gvar,
-
FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type)));
+ UpdateStructInfo(new_gvar, GetStructInfo(prim_func));
mod->Remove(old_gvar);
updates->Add(new_gvar, prim_func);
diff --git
a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py
b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py
index e1f45d278d..865051b0b4 100644
---
a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py
+++
b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py
@@ -512,13 +512,11 @@ def test_decoder_layer():
cls.rotary_embedding,
(lv9, cos_cached, sin_cached),
out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"),
- tir_vars=R.shape([256]),
)
lv17 = R.call_tir(
cls.rotary_embedding,
(lv12, cos_cached, sin_cached),
out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"),
- tir_vars=R.shape([256]),
)
lv18: R.Tensor((256, 32, 128), dtype="float16") = R.reshape(
lv17, R.shape([256, 32, 128])
@@ -712,13 +710,11 @@ def test_decoder_layer():
cls.rotary_embedding,
(lv9, cos_cached, sin_cached),
out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]",
"S[2]"),
- tir_vars=R.shape([256]),
)
lv17 = R.dist.call_tir(
cls.rotary_embedding,
(lv12, cos_cached, sin_cached),
out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]",
"S[2]"),
- tir_vars=R.shape([256]),
)
lv18: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") =
R.reshape(
lv17, R.shape([256, 32, 128])
@@ -1278,13 +1274,11 @@ def test_decoder_layer_tir():
cls.rotary_embedding,
(lv9, cos_cached, sin_cached),
out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"),
- tir_vars=R.shape([256]),
)
lv17 = R.call_tir(
cls.rotary_embedding,
(lv12, cos_cached, sin_cached),
out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"),
- tir_vars=R.shape([256]),
)
lv18 = R.call_tir(
cls.reshape1, (lv17,), out_sinfo=R.Tensor((256, 32, 128),
dtype="float16")
@@ -1449,13 +1443,11 @@ def test_decoder_layer_tir():
LlamaAttentionLayerTIR.get_global_var("rotary_embedding"),
(lv9, cos_cached, sin_cached),
out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]",
"S[2]"),
- tir_vars=R.shape([256]),
)
lv17 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("rotary_embedding"),
(lv12, cos_cached, sin_cached),
out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]",
"S[2]"),
- tir_vars=R.shape([256]),
)
lv18 = R.dist.call_tir(
LlamaAttentionLayerTIR.get_global_var("reshape1"),
diff --git a/tests/python/relax/test_analysis_well_formed.py
b/tests/python/relax/test_analysis_well_formed.py
index 7deddfd28e..c0b962c3f3 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -14,15 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
import pytest
+
import tvm
import tvm.testing
+
from tvm import relax as rx
from tvm import tir
-from tvm.script import relax as R
-from tvm.script import ir as I
-from tvm.script import tir as T
-from tvm.script import ir as I
+from tvm.script import ir as I, relax as R, tir as T
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
@@ -702,5 +702,511 @@ def test_pass_dltensor_arg_to_tir():
assert rx.analysis.well_formed(Module)
+def test_call_tir_with_matching_arguments():
+ """R.call_tir is well-formed when called with matching arguments"""
+
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16], "float16")):
+ B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16],
"float16"))
+ return B
+
+ @T.prim_func
+ def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
+ for i in range(16):
+ with T.block("compute"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = A[vi] + T.float16(1.0)
+
+ assert rx.analysis.well_formed(Module)
+
+
+def test_call_tir_input_ndim():
+ """Arguments to R.call_tir must have the correct dimensionality
+
+ Here, the `add_one` function expects a 1-d input tensor, but is
+ called with a 2-d tensor.
+
+ """
+
+ @I.ir_module(check_well_formed=False)
+ class Module:
+ @R.function
+ def main(A: R.Tensor([4, 4], "float16")):
+ B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16],
"float16"))
+ return B
+
+ @T.prim_func
+ def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
+ for i in range(16):
+ with T.block("compute"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = A[vi] + T.float16(1.0)
+
+ assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_output_ndim():
+ """Output shape R.call_tir must have the correct dimensionality
+
+ Here, the `add_one` function requires a 1-d output tensor, but is
+ provided with a 2-d tensor.
+ """
+
+ @I.ir_module(check_well_formed=False)
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16], "float16")):
+ B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([4, 4],
"float16"))
+ return B
+
+ @T.prim_func
+ def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
+ for i in range(16):
+ with T.block("compute"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = A[vi] + T.float16(1.0)
+
+ assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_input_shape():
+ """Arguments to R.call_tir must have the correct shape
+
+ Here, the `add_one` function expects an input tensor with 16
+ elements, but is called with an input tensor with 32 elements.
+
+ """
+
+ @I.ir_module(check_well_formed=False)
+ class Module:
+ @R.function
+ def main(A: R.Tensor([32], "float16")):
+ B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16],
"float16"))
+ return B
+
+ @T.prim_func
+ def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
+ for i in range(16):
+ with T.block("compute"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = A[vi] + T.float16(1.0)
+
+ assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_output_shape():
+ """Output shape R.call_tir must have the correct shape
+
+ Here, the `add_one` function requires an output tensor with 16
+ elements, but is provided an output tensor with 32 elements.
+ """
+
+ @I.ir_module(check_well_formed=False)
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16], "float16")):
+ B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([32],
"float16"))
+ return B
+
+ @T.prim_func
+ def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
+ for i in range(16):
+ with T.block("compute"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = A[vi] + T.float16(1.0)
+
+ assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_input_dtype():
+ """Arguments to R.call_tir must have the correct dtype
+
+ Here, the `add_one` function expects an input tensor containing
+ float16 value, but is called with an input tensor containing
+ float32 values.
+
+ """
+
+ @I.ir_module(check_well_formed=False)
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16], "float32")):
+ B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16],
"float16"))
+ return B
+
+ @T.prim_func
+ def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
+ for i in range(16):
+ with T.block("compute"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = A[vi] + T.float16(1.0)
+
+ assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_output_dtype():
+ """Output shape R.call_tir must have the correct shape
+
+ Here, the `add_one` function requires an output tensor that may be
+ populated with float16 values, but is provided an output tensor
+ that may be populated with float32 elements.
+
+ """
+
+ @I.ir_module(check_well_formed=False)
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16], "float16")):
+ B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16],
"float32"))
+ return B
+
+ @T.prim_func
+ def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")):
+ for i in range(16):
+ with T.block("compute"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = A[vi] + T.float16(1.0)
+
+ assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_with_correct_dynamic_output_shape():
+ """Output shape R.call_tir may not be verifiable
+
+ Here, the input arguments to the `reshape` function are not
+ sufficient to infer the shape of the outputs. This is legal,
+ since the output shape is determined by the `out_sinfo` parameter.
+
+ Inability to verify the output shape does not mean that the output
+ shape is invalid.
+
+ """
+
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16], "float16")):
+ B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 8],
"float16"))
+ return B
+
+ @T.prim_func
+ def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle):
+ M = T.int64()
+ N = T.int64()
+ B = T.match_buffer(B_handle, [M, N], dtype="float16")
+
+ for i, j in T.grid(M, N):
+ with T.block("compute"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi * N + vj]
+
+ assert rx.analysis.well_formed(Module)
+
+
[email protected](reason="Not supported")
+def test_call_tir_with_incorrect_dynamic_output_shape():
+ """Output shape R.call_tir may not be verifiable
+
+ Here, the input arguments to the `reshape` function are not
+ sufficient to infer the shape of the outputs. Even though the
+ IRModule will not provide well-defined output due to the
+ out-of-bounds read from buffer A, catching this error is beyond
+ the current scope of the Relax well-formed checker.
+
+ """
+
+ @I.ir_module(check_well_formed=False)
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16], "float16")):
+ B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([16, 16],
"float16"))
+ return B
+
+ @T.prim_func
+ def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle):
+ M = T.int64()
+ N = T.int64()
+ B = T.match_buffer(B_handle, [M, N], dtype="float16")
+
+ for i, j in T.grid(M, N):
+ with T.block("compute"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi * N + vj]
+
+ assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_incorrect_dimensionality_of_output_shape():
+ """Dimensionality may be verified
+
+ Here, the input arguments to the `reshape` function are not
+ sufficient to infer the shape of the outputs.
+
+ Even though the output shape may not be inferred from the input
+ arguments, the output dimensionality can still be inferred from
+ the PrimFunc signature. The IRModule below is ill-formed, because
+ the PrimFunc requires a 2-d output argument, but is provided with
+ a 3-d output argument.
+
+ """
+
+ @I.ir_module(check_well_formed=False)
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16], "float16")):
+ B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 4, 2],
"float16"))
+ return B
+
+ @T.prim_func
+ def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle):
+ M = T.int64()
+ N = T.int64()
+ B = T.match_buffer(B_handle, [M, N], dtype="float16")
+
+ for i, j in T.grid(M, N):
+ with T.block("compute"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi * N + vj]
+
+ assert not rx.analysis.well_formed(Module)
+
+
[email protected](reason="Not yet supported")
+def test_call_tir_output_shape_with_mixed_static_and_dynamic():
+ """Some dimensions of the R.call_tir output shape may be verifiable
+
+ Here, the input arguments to the `reshape` function are not
+ sufficient to infer the shape of the outputs. This is legal,
+ since the output shape is taken from the `out_sinfo` parameter.
+
+ Identifying this failure mode is not yet supported in the current
+ implementation. This is because the output is inferred as
+ `R.Tensor(ndim=3, dtype="float16")`, and the explicit `out_sinfo`
+ is a 3-d tensor. The mismatch in the first dimension is not yet
+ counted, because the entire tensor shape is removed by
+ `EraseToWellDefined`.
+
+ """
+
+ @I.ir_module(check_well_formed=False)
+ class Module:
+ @R.function
+ def main(A: R.Tensor([256], "float16")):
+ B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([8, 16, 2], "float16"))
+ return B
+
+ @T.prim_func
+ def reshape(A: T.Buffer(256, "float16"), B_handle: T.handle):
+ M = T.int64()
+ N = T.int64()
+ B = T.match_buffer(B_handle, [16, M, N], dtype="float16")
+
+ for i, j, k in T.grid(16, M, N):
+ with T.block("compute"):
+ vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+ B[vi, vj, vk] = A[vi * N * M + vj * N + vk]
+
+ assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_with_correct_inferred_dynamic_output_shape():
+ """Some dynamic output shapes of R.call_tir may be inferred
+
+ Here, the `flatten` function is dynamic, and will flatten any 2-d
+ TIR buffer. Even though it is dynamic, the input shapes are
+ sufficient to infer that `M==8` and `N==4`. As a result, the
+ output shape of `[M*N]` can be inferred to be `[32]`, and the
+ shape specified in `out_sinfo` can be validated.
+
+ """
+
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(A: R.Tensor([8, 4], "float16")):
+ B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([32], "float16"))
+ return B
+
+ @T.prim_func
+ def flatten(A_handle: T.handle, B_handle: T.handle):
+ M = T.int64()
+ N = T.int64()
+ A = T.match_buffer(A_handle, [M, N], dtype="float16")
+ B = T.match_buffer(B_handle, [M * N], dtype="float16")
+
+ for i in T.grid(M * N):
+ with T.block("compute"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = A[vi // N, vi % N]
+
+ assert rx.analysis.well_formed(Module)
+
+
+def test_call_tir_with_incorrect_inferred_dynamic_output_shape():
+ """Some dynamic output shapes of R.call_tir may be inferred
+
+ Here, the `flatten` function is dynamic, and will flatten any 2-d
+ TIR buffer. Even though it is dynamic, the input shapes are
+ sufficient to infer that `M==8` and `N==4`. As a result, the
+ output shape of `[M*N]` can be inferred to be `[32]`, and the
+ shape specified in `out_sinfo` can be validated.
+
+ This unit test is identical to the above test
+ `test_call_tir_with_correct_inferred_dynamic_output_shape`, except
+ that the output shape is explicitly specified as `[64]`, which is
+ caught as a mismatch from the expected output shape.
+
+ """
+
+ @I.ir_module(check_well_formed=False)
+ class Module:
+ @R.function
+ def main(A: R.Tensor([8, 4], "float16")):
+ B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([64], "float16"))
+ return B
+
+ @T.prim_func
+ def flatten(A_handle: T.handle, B_handle: T.handle):
+ M = T.int64()
+ N = T.int64()
+ A = T.match_buffer(A_handle, [M, N], dtype="float16")
+ B = T.match_buffer(B_handle, [M * N], dtype="float16")
+
+ for i in T.grid(M * N):
+ with T.block("compute"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = A[vi // N, vi % N]
+
+ assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_with_dtensor_arguments():
+ """R.call_tir and R.dist.call_tir share the same operation
+
+ Both `R.call_tir` and `R.dist.call_tir` produce the same
+ "relax.call_tir" operation, differing only in the StructInfo of
+ their arguments. Normalization of "relax.call_tir" must handle
+ `R.DTensor` arguments.
+
+ """
+
+ # from tvm.script.parser import relax as R
+
+ @I.ir_module
+ class Module:
+ I.module_attrs({"device_num": 4})
+ I.module_global_infos({"mesh": [R.dist.device_mesh([4], I.Range(0, 4))]})
+
+ @R.function
+ def main(A: R.dist.DTensor([8, 4], "float16", "mesh[0]", "S[0]")):
+ B = R.dist.call_tir(
+ Module.flatten, A, out_sinfo=R.dist.DTensor([64], "float16", "mesh[0]", "S[0]")
+ )
+ return B
+
+ @T.prim_func
+ def flatten(A_handle: T.handle, B_handle: T.handle):
+ M = T.int64()
+ N = T.int64()
+ A = T.match_buffer(A_handle, [M, N], dtype="float16")
+ B = T.match_buffer(B_handle, [M * N], dtype="float16")
+
+ for i in T.grid(M * N):
+ with T.block("compute"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = A[vi // N, vi % N]
+
+ assert rx.analysis.well_formed(Module)
+
+
+def test_call_tir_inplace_with_correct_shapes():
+ """R.call_tir_inplace is well-formed when called with matching arguments"""
+
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16], "float16")):
+ B = R.call_tir_inplace(
+ Module.add_one,
+ A,
+ inplace_indices=[0],
+ out_sinfo=R.Tensor([16], "float16"),
+ )
+ return B
+
+ @T.prim_func
+ def add_one(A: T.Buffer(16, "float16")):
+ for i in range(16):
+ with T.block("compute"):
+ vi = T.axis.remap("S", [i])
+ A[vi] = A[vi] + T.float16(1.0)
+
+ assert rx.analysis.well_formed(Module)
+
+
+def test_call_tir_inplace_with_incorrect_shapes():
+ """R.call_tir_inplace is ill-formed when output shape does not match input"""
+
+ @I.ir_module(check_well_formed=False)
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16], "float16")):
+ B = R.call_tir_inplace(
+ Module.add_one,
+ A,
+ inplace_indices=[0],
+ out_sinfo=R.Tensor([32], "float16"),
+ )
+ return B
+
+ @T.prim_func
+ def add_one(A: T.Buffer(16, "float16")):
+ for i in range(16):
+ with T.block("compute"):
+ vi = T.axis.remap("S", [i])
+ A[vi] = A[vi] + T.float16(1.0)
+
+ assert not rx.analysis.well_formed(Module)
+
+
+def test_call_tir_inplace_with_some_allocated_outputs():
+ """R.call_tir_inplace may contain some non-inplace outputs"""
+
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(A: R.Tensor([16], "float16"), B: R.Tensor([32], "float16")):
+ out = R.call_tir_inplace(
+ Module.add_one,
+ (A, B),
+ inplace_indices=[-1, 1],
+ out_sinfo=[
+ R.Tensor([16], "float16"),
+ R.Tensor([32], "float16"),
+ ],
+ )
+ return out
+
+ @T.prim_func
+ def add_one(
+ A: T.Buffer(16, "float16"),
+ B: T.Buffer(32, "float16"),
+ C: T.Buffer(16, "float16"),
+ ):
+ for i in range(32):
+ with T.block("inplace_B"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = B[vi] + T.float16(1.0)
+
+ for i in range(16):
+ with T.block("output_C"):
+ vi = T.axis.remap("S", [i])
+ C[vi] = A[vi] + T.float16(1.0)
+
+ assert rx.analysis.well_formed(Module)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_ast_printer.py
b/tests/python/relax/test_ast_printer.py
index 64d5c73811..6005ecb0fa 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -43,6 +43,7 @@ def normalize(func: rx.Function) -> rx.Function:
"""
Normalize the expr to fill in the checked_type_ and struct_info fields
everywhere
"""
+
# using a default mutator to use the BlockBuilder's normalizer,
# which oddly differs from the Normalize pass
@rx.expr_functor.mutator
@@ -435,9 +436,13 @@ def test_call_tir():
@tvm.script.ir_module
class TestCallTIR:
@T.prim_func
- def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16),
"int32")) -> None:
+ def addone(A_handle: T.handle, B_handle: T.handle) -> None:
+ m = T.int64()
+ n = T.int64()
+ A = T.match_buffer(A_handle, (m, n), "float32")
+ B = T.match_buffer(B_handle, (m, n), "float32")
T.func_attr(({"global_symbol": "addone"}))
- for i, j in T.grid(16, 16):
+ for i, j in T.grid(m, n):
with T.block("addone"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] + T.int32(1)
diff --git a/tests/python/relax/test_dataflow_inplace.py
b/tests/python/relax/test_dataflow_inplace.py
index 8d5eb07c78..cd6e285de4 100644
--- a/tests/python/relax/test_dataflow_inplace.py
+++ b/tests/python/relax/test_dataflow_inplace.py
@@ -172,8 +172,8 @@ def test_alias_call_tir():
T.func_attr({"global_symbol": "tir_id"})
m = T.int32()
n = T.int32()
- A = T.match_buffer(x, (m, n))
- B = T.match_buffer(y, (m, n))
+ A = T.match_buffer(x, (m, n), "int32")
+ B = T.match_buffer(y, (m, n), "int32")
for i, j in T.grid(m, n):
with T.block("id"):
@@ -185,9 +185,9 @@ def test_alias_call_tir():
T.func_attr({"global_symbol": "tir_id"})
m = T.int32()
n = T.int32()
- A = T.match_buffer(x, (m, n))
- B = T.match_buffer(y, (m, n))
- C = T.match_buffer(z, (m, n))
+ A = T.match_buffer(x, (m, n), "int32")
+ B = T.match_buffer(y, (m, n), "int32")
+ C = T.match_buffer(z, (m, n), "int32")
for i, j in T.grid(m, n):
with T.block("id"):
diff --git a/tests/python/relax/test_dataflow_pattern.py
b/tests/python/relax/test_dataflow_pattern.py
index 03a3beb2f2..7a3b65cea1 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -72,7 +72,7 @@ class Module:
lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32),
dtype="float32"))
lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32),
dtype="float32"))
lv2 = R.call_tir(
- cls.tir_zeros, (lv1), R.Tensor((32,), dtype="float32"),
tir_vars=R.ShapeExpr([32])
+ cls.tir_zeros, [], R.Tensor((32,), dtype="float32"),
tir_vars=R.ShapeExpr([32])
)
gv = (lv1, lv2)
R.output(gv)
diff --git a/tests/python/relax/test_frontend_dynamo.py
b/tests/python/relax/test_frontend_dynamo.py
index d83f83f4e1..21e1d82d28 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -114,9 +114,10 @@ def test_relax_dynamo():
with db:
opt_model = torch.compile(model, backend=relax_dynamo())
inp = torch.randn(10, 100)
- tvm.testing.assert_allclose(
- opt_model(inp).detach().numpy(), model(inp).detach().numpy(),
rtol=1e-5, atol=1e-5
- )
+
+ default_output = model(inp).detach().numpy()
+ optimized_output = opt_model(inp).detach().numpy()
+ tvm.testing.assert_allclose(optimized_output, default_output, rtol=1e-5,
atol=1e-5)
def test_relax_dynamo_dynamic():
diff --git a/tests/python/relax/test_frontend_nn_op.py
b/tests/python/relax/test_frontend_nn_op.py
index 40624790cb..6a337b34c1 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -570,10 +570,18 @@ def test_tensor_ir_op():
@T.prim_func(private=True)
def fused_rope( # pylint: disable=too-many-locals
var_qkv: T.handle,
- offset: T.int64,
var_q: T.handle,
var_k: T.handle,
var_v: T.handle,
+ # Scalar arguments must be specified after tensor arguments,
+ # including the output tensor arguments
+ #
+ # TODO(Lunderberg): Update
+ # `tvm.relax.frontend.nn.op.tensor_ir_op` to use `PrimValue`
+ # instead of `tir_vars`, so that the order can be consistent
+ # between the function definition and the arguments in
+ # `op.tensor_ir_op`.
+ offset: T.int64,
):
batch_size = T.int64()
seq_len = T.int64()
@@ -601,7 +609,7 @@ def test_tensor_ir_op():
@I.ir_module
class Expected:
@T.prim_func(private=True)
- def llama_fused_rope(var_qkv: T.handle, offset: T.int64, var_q:
T.handle, var_k: T.handle, var_v: T.handle):
+ def llama_fused_rope(var_qkv: T.handle, var_q: T.handle, var_k:
T.handle, var_v: T.handle, offset: T.int64):
batch_size, seq_len = T.int64(), T.int64()
qkv = T.match_buffer(var_qkv, (batch_size, seq_len, 24, 16),
"float16")
q = T.match_buffer(var_q, (batch_size, seq_len, 8, 16), "float16")
@@ -669,10 +677,11 @@ def test_tensor_ir_inplace_op():
def test(
self, embedding_table: Tensor, input_ids: Tensor, embedding_dst:
Tensor, offset: int
):
- tensor_expr_op_out = op.tensor_ir_op(
+ tensor_expr_op_out = op.tensor_ir_inplace_op(
inplace_take,
"inplace_take",
args=[embedding_table, input_ids, embedding_dst, offset],
+ inplace_indices=[2],
out=Tensor.placeholder(embedding_dst.shape,
embedding_dst.dtype),
)
return tensor_expr_op_out
@@ -719,10 +728,11 @@ def test_tensor_ir_inplace_op():
R.func_attr({"num_input": 4})
cls = Expected
with R.dataflow():
- lv1 = R.call_tir(
+ lv1 = R.call_tir_inplace(
cls.inplace_take,
(embedding_table, input_ids, embedding_dst),
out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype),
+ inplace_indices=[2],
tir_vars=R.shape([offset_1]),
)
gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1
diff --git a/tests/python/relax/test_transform.py
b/tests/python/relax/test_transform.py
index ee2df866fb..e3274aea88 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -86,7 +86,11 @@ def test_call_tir_rewrite():
@tvm.script.ir_module
class TestCallTIRRewrite:
@T.prim_func
- def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3),
"float32")):
+ def exp(A_handle: T.handle, B_handle: T.handle):
+ m = T.int64()
+ n = T.int64()
+ A = T.match_buffer(A_handle, (m, n), "float32")
+ B = T.match_buffer(B_handle, (m, n), "float32")
T.evaluate(0)
@R.function
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py
b/tests/python/relax/test_transform_dead_code_elimination.py
index 65970d6455..0ddf985ec4 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -277,18 +277,26 @@ def
test_tracking_through_externally_exposed_func(provide_entry_func_name):
def test_unused_relax_func_symbolic_shape():
# Test with relax function w/ symbolic shape.
- @tvm.script.ir_module
+ @tvm.script.ir_module(check_well_formed=False)
class InputModule:
@T.prim_func
- def tir_add(
- x: T.Buffer((16, 16), "float32"),
- y: T.Buffer((16, 16), "float32"),
- z: T.Buffer((16, 16), "float32"),
+ def tir_matmul(
+ x_handle: T.handle,
+ y_handle: T.handle,
+ z_handle: T.handle,
) -> None:
- for i, j in T.grid(16, 16):
- with T.block("add"):
- vi, vj = T.axis.remap("SS", [i, j])
- z[vi, vj] = x[vi, vj] + y[vi, vj]
+ m = T.int64()
+ n = T.int64()
+ k = T.int64()
+ x = T.match_buffer(x_handle, (m, n), "float32")
+ y = T.match_buffer(y_handle, (n, k), "float32")
+ z = T.match_buffer(z_handle, (m, k), "float32")
+ for i, j, k in T.grid(m, k, n):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ z[vi, vj] = 0.0
+ z[vi, vj] = z[vi, vj] + x[vi, vk] * y[vk, vj]
@R.function(private=True)
def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n",
"k"), "float32")):
@@ -298,7 +306,7 @@ def test_unused_relax_func_symbolic_shape():
@R.function
def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"),
"float32")):
m, k = T.int64(), T.int64()
- gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((m + 1, k),
dtype="float32"))
+ gv0 = R.call_tir(InputModule.tir_matmul, (x, w), R.Tensor((m, k),
dtype="float32"))
return gv0
mod = InputModule
@@ -306,7 +314,7 @@ def test_unused_relax_func_symbolic_shape():
new_mod = DeadCodeElimination()(mod)
assert check_if_func_exists(new_mod, "main")
- assert check_if_func_exists(new_mod, "tir_add")
+ assert check_if_func_exists(new_mod, "tir_matmul")
assert not check_if_func_exists(new_mod, "unused_func")
diff --git a/tests/python/relax/test_transform_fuse_ops.py
b/tests/python/relax/test_transform_fuse_ops.py
index 17bf586132..9ad66bec01 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -875,7 +875,7 @@ def test_layer_norm_silu():
def main(x: R.Tensor((1, 512, 64, 64), "float32"), mean: R.Tensor((64,
64), "float32"), var: R.Tensor((64, 64), "float32")):
cls = Module
with R.dataflow():
- gv0 = R.call_tir(cls.layer_norm, (x, mean, var),
out_sinfo=R.Tensor((1, 512, 64, 64)))
+ gv0 = R.call_tir(cls.layer_norm, (x, mean, var),
out_sinfo=R.Tensor((1, 512, 64, 64), 'float32'))
gv1 = R.call_tir(cls.relu, gv0, out_sinfo=R.Tensor((1, 512,
64, 64), "float32"))
R.output(gv1)
return gv1
@@ -955,7 +955,7 @@ def test_layer_norm_silu():
R.func_attr({"Primitive": 1})
cls = Expected
with R.dataflow():
- gv0 = R.call_tir(cls.layer_norm, (x, mean, var),
out_sinfo=R.Tensor((1, 512, 64, 64)))
+ gv0 = R.call_tir(cls.layer_norm, (x, mean, var),
out_sinfo=R.Tensor((1, 512, 64, 64), 'float32'))
gv = R.call_tir(cls.relu, (gv0,), out_sinfo=R.Tensor((1, 512,
64, 64), dtype="float32"))
R.output(gv)
return gv
@@ -1452,7 +1452,7 @@ def test_partially_used_tuple_param():
R.Tensor((2,), "float32"),
R.Tensor((2,), "float32"),
R.Tensor((2,), "float32"),
- )
+ ),
):
with R.dataflow():
x0 = x[0]
@@ -1486,7 +1486,7 @@ def test_partially_used_tuple_param():
R.Tensor((2,), dtype="float32"),
R.Tensor((2,), dtype="float32"),
R.Tensor((2,), dtype="float32"),
- )
+ ),
) -> R.Tensor((2,), dtype="float32"):
cls = Expected
with R.dataflow():
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 1582526042..a07875fcda 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -696,10 +696,10 @@ def test_ignore_call_tir():
class Conv2dReLUCallTIR:
@T.prim_func
def relu(
- data: T.Buffer((64, 64, 56, 56), "float32"),
- out: T.Buffer((64, 64, 56, 56), "float32"),
+ data: T.Buffer((1, 64, 56, 56), "float32"),
+ out: T.Buffer((1, 64, 56, 56), "float32"),
):
- for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56):
+ for ax0, ax1, ax2, ax3 in T.grid(1, 64, 56, 56):
with T.block("root"):
i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
out[i, j, k, l] = T.max(data[i, j, k, l], 0.0)
@@ -714,7 +714,7 @@ def test_ignore_call_tir():
relu1 = R.call_tir(
Conv2dReLUCallTIR.relu,
(conv1,),
- R.Tensor((64, 64, 56, 56), "float32"),
+ R.Tensor((1, 64, 56, 56), "float32"),
)
R.output(relu1)
@@ -724,11 +724,11 @@ def test_ignore_call_tir():
class Conv2dReLUCallTIR_partitioned:
@T.prim_func
def relu(
- data: T.Buffer((64, 64, 56, 56), "float32"),
- out: T.Buffer((64, 64, 56, 56), "float32"),
+ data: T.Buffer((1, 64, 56, 56), "float32"),
+ out: T.Buffer((1, 64, 56, 56), "float32"),
):
# with T.block("root"):
- for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56):
+ for ax0, ax1, ax2, ax3 in T.grid(1, 64, 56, 56):
with T.block("root"):
i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(data[i, j, k, l])
@@ -754,7 +754,7 @@ def test_ignore_call_tir():
def main(
data: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
- ) -> R.Tensor((64, 64, 56, 56), dtype="float32"):
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
cls = Conv2dReLUCallTIR_partitioned
with R.dataflow():
lv: R.Tensor((1, 64, 56, 56), dtype="float32") =
cls.fused_relax_nn_conv2d(
@@ -763,7 +763,7 @@ def test_ignore_call_tir():
relu1 = R.call_tir(
cls.relu,
(lv,),
- out_sinfo=R.Tensor((64, 64, 56, 56), dtype="float32"),
+ out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"),
)
R.output(relu1)
return relu1
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py
b/tests/python/relax/test_transform_lazy_transform_params.py
index 278ac825f7..87a5698f1b 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -43,7 +43,7 @@ def test_lazy_transform_params():
def main_transform_params(
params: R.Tuple(
R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3,
3), dtype="float32")
- )
+ ),
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3),
dtype="float32")
):
@@ -124,7 +124,7 @@ def test_get_item_only():
def main_transform_params(
params: R.Tuple(
R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3,
3), dtype="float32")
- )
+ ),
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3),
dtype="float32")
):
@@ -209,7 +209,7 @@ def test_extra_get_item_params():
def main_transform_params(
params: R.Tuple(
R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3,
3), dtype="float32")
- )
+ ),
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3),
dtype="float32")
):
@@ -298,7 +298,7 @@ def test_extra_set_item_params():
def main_transform_params(
params: R.Tuple(
R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3,
3), dtype="float32")
- )
+ ),
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3),
dtype="float32")
):
@@ -441,8 +441,8 @@ def test_lazy_transform_params_with_symbolic_vars():
@T.prim_func(private=True)
def slice_buffer(
Input: T.Buffer((16, 16), "float32"),
- slice_index: T.int64,
Output: T.Buffer(16, "float32"),
+ slice_index: T.int64,
):
for i in T.grid(16):
with T.block("slice_buffer"):
@@ -479,8 +479,8 @@ def test_lazy_transform_params_with_symbolic_vars():
@T.prim_func(private=True)
def slice_buffer(
Input: T.Buffer((16, 16), "float32"),
- slice_index: T.int64,
Output: T.Buffer(16, "float32"),
+ slice_index: T.int64,
):
for i in T.grid(16):
with T.block("slice_buffer"):
@@ -511,7 +511,7 @@ def test_param_shape_symbolic():
params: R.Tuple(
R.Tensor((3, "ic", 3, 3), dtype="float32"),
R.Tensor((16, 16, 3, 3), dtype="float32"),
- )
+ ),
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3,
3), dtype="float32")
):
@@ -637,7 +637,7 @@ def test_output():
params: R.Tuple(
R.Tensor((3, "ic", 3, 3), dtype="float32"),
R.Tensor((16, 16, 3, 3), dtype="float32"),
- )
+ ),
) -> R.Tuple(
R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3,
3), dtype="float32")
):
@@ -691,7 +691,7 @@ def test_duplicate_outputs():
class Before:
@R.function
def main_transform_params(
- params: R.Tuple(R.Tensor([16], dtype="int32"), R.Tensor([16],
dtype="int32"))
+ params: R.Tuple(R.Tensor([16], dtype="int32"), R.Tensor([16],
dtype="int32")),
):
R.func_attr({"relax.force_pure": True})
param0 = params[0]
@@ -966,7 +966,7 @@ def test_get_item_callback_dynamic_shape():
class Expected:
@R.function
def transform_params(
- fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object)
+ fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object),
) -> R.Tuple(R.Tensor(ndim=2, dtype="float32"), R.Tensor(ndim=2,
dtype="float32")):
R.func_attr({"num_input": 1})
m = T.int64()
diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
index f7befd3b88..5a7d76d8fe 100644
--- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
+++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -252,11 +252,15 @@ def test_reshape_dynamic_shape():
]
@R.function
- def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3),
dtype="float32"):
+ def main(
+ x: R.Tensor((8, 16, 128), dtype="float16")
+ ) -> R.Tensor((1, 8, 16, 128), dtype="float16"):
cls = Module
with R.dataflow():
- y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4,
3), dtype="float32"))
- z = R.add(y, R.const(1, "float32"))
+ y = R.call_tir(
+ cls.reshape, (x,), out_sinfo=R.Tensor((1, 8, 16, 128),
dtype="float16")
+ )
+ z = R.add(y, R.const(1, "float16"))
R.output(z)
return z
@@ -290,10 +294,14 @@ def test_reshape_dynamic_shape():
]
@R.function
- def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3),
dtype="float32"):
+ def main(
+ x: R.Tensor((8, 16, 128), dtype="float16")
+ ) -> R.Tensor((1, 8, 16, 128), dtype="float16"):
with R.dataflow():
- y: R.Tensor((2, 4, 3), dtype="float32") = R.reshape(x,
R.shape([2, 4, 3]))
- z: R.Tensor((2, 4, 3), dtype="float32") = R.add(y, R.const(1,
"float32"))
+ y: R.Tensor((1, 8, 16, 128), dtype="float16") = R.reshape(
+ x, R.shape([1, 8, 16, 128])
+ )
+ z: R.Tensor((1, 8, 16, 128), dtype="float16") = R.add(y,
R.const(1, "float16"))
R.output(z)
return z
@@ -383,7 +391,7 @@ def test_tuple_get_reshape():
R.Tensor((2, 4096, 320), dtype="float16"),
R.Tensor((2, 4096, 320), dtype="float16"),
R.Tensor((2, 4096, 320), dtype="float16"),
- )
+ ),
) -> R.Tensor((2, 4096, 8, 40), dtype="float16"):
cls = Module
with R.dataflow():
@@ -444,7 +452,7 @@ def test_tuple_get_reshape():
R.Tensor((2, 4096, 320), dtype="float16"),
R.Tensor((2, 4096, 320), dtype="float16"),
R.Tensor((2, 4096, 320), dtype="float16"),
- )
+ ),
) -> R.Tensor((2, 4096, 8, 40), dtype="float16"):
with R.dataflow():
lv: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[0]
@@ -735,7 +743,6 @@ def test_rewrite_dynamic_reshape():
z_handle: T.handle,
N: T.int64,
):
-
y1 = T.match_buffer(y1_handle, [N * 4, T.int64(4)], "float32")
y2 = T.match_buffer(y2_handle, [N * 4, T.int64(4)], "float32")
z = T.match_buffer(z_handle, [N * 4, T.int64(4)], "float32")
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index ea99d49270..64f2efd4af 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -77,7 +77,7 @@ def test_mismatch_cast_dims_and_ndim():
@R.function
def f(
- x: R.Tensor((2, 3), "float32", ndim=3)
+ x: R.Tensor((2, 3), "float32", ndim=3),
): # error: ndim and the shape dims are mismatch
return x
@@ -961,11 +961,11 @@ def test_call_tir_with_tir_var():
class Module:
@R.function
def main(
- dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2",
"float32"))
+ dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2",),
"float32")
) -> R.Tensor(("n * 2",), "float32"):
n = T.int64()
cls = Module
- y = R.call_tir(cls.copy, (x,), R.Tensor(((n * 2,)),
dtype="float32"), tir_vars=(n,))
+ y = R.call_tir(cls.copy, x, R.Tensor((n * 2,), dtype="float32"),
tir_vars=(n,))
return y
@T.prim_func
@@ -2171,7 +2171,9 @@ def test_macro_hygienic():
@R.function(private=True)
def expect(z: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]):
alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor(
- R.shape([4, 4]), R.dtype("float32"), R.prim_value(2) # Make sure
prim_value is 2
+ R.shape([4, 4]),
+ R.dtype("float32"),
+ R.prim_value(2), # Make sure prim_value is 2
)
shape: R.Shape([4, 4]) = R.shape_of(alloc)
shape_1: R.Shape([4, 4]) = shape
@@ -2203,7 +2205,9 @@ def test_macro_non_hygienic():
@R.function(private=True)
def expect(z: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]):
alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor(
- R.shape([4, 4]), R.dtype("float32"), R.prim_value(1) # Make sure
prim_value is 1
+ R.shape([4, 4]),
+ R.dtype("float32"),
+ R.prim_value(1), # Make sure prim_value is 1
)
shape: R.Shape([4, 4]) = R.shape_of(alloc)
shape_1: R.Shape([4, 4]) = shape
@@ -2372,7 +2376,6 @@ def
test_conditional_may_use_symbolic_variables_from_function_scope():
B: R.Tensor(["N"], "float32"),
cond: R.Prim("bool"),
) -> R.Tensor(["N"], "float32"):
-
N = T.int64()
if cond:
diff --git a/tests/python/relax/test_vm_build.py
b/tests/python/relax/test_vm_build.py
index 30fd06d4f1..ecf33aa9da 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -988,8 +988,10 @@ def test_multi_systemlib(exec_mode):
I.module_attrs({"system_lib_prefix": "libA_"})
@T.prim_func
- def tir_init(x: T.Buffer((2), "float32")) -> None:
- for i in range(2):
+ def tir_init(x_handle: T.handle):
+ N = T.int64()
+ x = T.match_buffer(x_handle, [N], "float32")
+ for i in range(N):
x[i] = T.float32(0)
@R.function
@@ -1003,8 +1005,10 @@ def test_multi_systemlib(exec_mode):
I.module_attrs({"system_lib_prefix": "libB_"})
@T.prim_func
- def tir_init(x: T.Buffer((2), "float32")) -> None:
- for i in range(2):
+ def tir_init(x_handle: T.handle):
+ N = T.int64()
+ x = T.match_buffer(x_handle, [N], "float32")
+ for i in range(N):
x[i] = T.float32(1)
@R.function