This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new eb5458e0e9 [Relax] Allow R.Prim('bool') in relax::If and assert_op
(#16642)
eb5458e0e9 is described below
commit eb5458e0e9b93001bb7e4a69d7d4e393cd55c933
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Mar 28 16:53:47 2024 -0500
[Relax] Allow R.Prim('bool') in relax::If and assert_op (#16642)
* [TIR][Analysis] Implemented tir.analysis.is_pure_function
This commit introduces two related utilities,
`tir.analysis.is_pure_function` and `tir.analysis.assert_pure_function`.
In contrast to the existing `tvm::tir::SideEffect`, which checks for
side effects of a `PrimExpr`, `is_pure_function` checks for side
effects of the function as a whole.
* [Transform] Implement relax.transform.ComputePrimValue
Prior to this commit, while expressions of type `DataType::Int(64)`
could be computed in the `relax.transform.VMShapeLower`, expressions
of any other type could not. This commit introduces
`relax.transform.ComputePrimValue`, which produces `PrimFunc`
subroutines to compute `PrimExpr` values of any dtype.
This functionality will allow boolean values to be computed based on
the symbolic values known at runtime.
* [Relax] Allow R.Prim('bool') in relax::If and assert_op
Prior to this commit, the condition used for the `relax::If` node and the
`"relax.assert_op"` operator was required to be a scalar tensor. This
made it difficult to alter behavior based on a runtime shape
parameter. For example, delegating to a vectorized implementation
based on whether a tensor shape is divisible by the vector size.
This commit adds support for expressions of type `R.Prim('bool')` as
the conditional for `relax::If` and `"relax.assert_op"`, to allow
these use cases.
* Lint fix
---
include/tvm/tir/analysis.h | 15 +-
python/tvm/error.py | 1 +
python/tvm/relax/op/base.py | 44 +++--
python/tvm/relax/pipeline.py | 1 +
python/tvm/relax/transform/__init__.py | 1 +
python/tvm/relax/transform/transform.py | 19 ++
python/tvm/script/ir_builder/relax/ir.py | 15 +-
python/tvm/script/parser/tir/parser.py | 33 +++-
python/tvm/tir/analysis/analysis.py | 10 ++
src/relax/analysis/struct_info_analysis.cc | 6 +-
src/relax/backend/vm/vm_shape_lower.cc | 1 +
src/relax/op/tensor/inspect.cc | 4 +-
src/relax/transform/compute_prim_value.cc | 94 ++++++++++
src/relax/transform/dataflow_inplace.cc | 45 ++---
src/relax/utils.cc | 17 +-
src/tir/analysis/is_pure_function.cc | 97 ++++++++++
src/tir/ir/function.cc | 43 +++++
src/tir/ir/specialize.cc | 10 +-
src/tir/transforms/renew_defs.cc | 6 +-
tests/python/relax/test_analysis_well_formed.py | 46 +++++
.../relax/test_backend_transform_shape_lower.py | 84 +++++++++
tests/python/relax/test_relax_operators.py | 195 ++++++++++++++++-----
tests/python/relax/test_transform.py | 12 +-
.../relax/test_transform_compute_prim_value.py | 104 +++++++++++
tests/python/relax/test_tvmscript_parser.py | 147 +++++++++++++++-
tests/python/relax/test_vm_codegen_tir.py | 2 +-
.../test_tir_analysis_is_pure_function.py | 104 +++++++++++
tests/python/tir-base/test_tir_specialize.py | 27 ++-
.../python/tvmscript/test_tvmscript_parser_tir.py | 109 ++++++++++++
29 files changed, 1154 insertions(+), 138 deletions(-)
diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index c4ae5d573b..96459f25ec 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -117,13 +117,26 @@ TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr, const Array<Var>& defs);
/*!
- * \brief Analyze the side effect
+ * \brief Analyze the side effect of an expression
* \param expr The expression to be checked.
*
* \return CallEffectKind, can be kPure, kReadState or kUpdateState
*/
TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
+/*!
+ * \brief Analyze the side effect of a function
+ *
+ * \param func The expression to be checked.
+ *
+ * \param assert_on_error If true, an error will be thrown for an
+ * impure function. If false (default), the purity of the PrimFunc
+ * will be returned.
+ *
+ * \return The purity of the function
+ */
+TVM_DLL bool IsPureFunction(const PrimFunc& func, bool assert_on_error =
false);
+
/*!
* \brief Whether the given Stmt uses any var in the given variable set.
* \param stmt The Stmt to be checked.
diff --git a/python/tvm/error.py b/python/tvm/error.py
index cc0180593d..6bf9b16850 100644
--- a/python/tvm/error.py
+++ b/python/tvm/error.py
@@ -54,6 +54,7 @@ register_error("TypeError", TypeError)
register_error("AttributeError", AttributeError)
register_error("KeyError", KeyError)
register_error("IndexError", IndexError)
+register_error("AssertionError", AssertionError)
@register_error
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index 3effec242d..756d250c16 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -503,19 +503,26 @@ def relax_assert_op(condition: tvm.Object, format_str:
str, *format_args: tvm.Ob
f"The format string argument to assert must be a string, given
{type(format_str)})"
)
- # should be guaranteed by the type system
- if not isinstance(condition, tvm.nd.NDArray):
- raise ValueError(f"The condition must be an NDArray, but given a
{type(condition)}.")
-
- # may happen if the original program had unknown shape or dtype for the
tensor's type
- dtype = condition.dtype
- if dtype != "bool":
- raise ValueError(f"The condition must be a bool scalar, but given a
{dtype} tensor")
- shape = condition.shape
- if len(shape) != 0:
- raise ValueError(f"The condition must be a scalar, but it has a shape
of {shape}")
-
- val = condition.numpy()
+ if isinstance(condition, (bool, int)):
+ val = condition
+ elif isinstance(condition, tvm.nd.NDArray):
+ # may happen if the original program had unknown shape or dtype for
the tensor's type
+ dtype = condition.dtype
+ if dtype != "bool":
+ raise ValueError(f"The condition must be a bool scalar, but given
a {dtype} tensor")
+ shape = condition.shape
+ if len(shape) != 0:
+ raise ValueError(f"The condition must be a scalar, but it has a
shape of {shape}")
+
+ val = condition.numpy()
+
+ else:
+ # should be guaranteed by the type system
+ raise ValueError(
+ f"The condition for relax assert must be a bool, int, or NDArray, "
+ f"but received a {type(condition)}."
+ )
+
if not val:
error_message = "Assertion Failed"
if format_args or format_str != "":
@@ -528,7 +535,7 @@ def relax_assert_op(condition: tvm.Object, format_str: str,
*format_args: tvm.Ob
def assert_op(
- condition: Expr,
+ condition: Union[Expr, PrimExpr],
format_args: Optional[Union[Expr, List[Expr]]] = None,
format: Union[str, Expr] = "",
) -> Expr:
@@ -538,7 +545,7 @@ def assert_op(
Parameters
----------
- condition: Expr
+ condition: Union[Expr, PrimExpr]
The assertion condition.
format_args: Optional[Union[Expr, List[Expr]]]
@@ -552,12 +559,17 @@ def assert_op(
result : Expr
A Call to the Relax assert operation.
"""
+ if not isinstance(condition, Expr):
+ condition = tvm.relax.PrimValue(condition)
+
if format_args is None:
format_args = []
- if isinstance(format_args, Expr): # type: ignore
+ elif isinstance(format_args, Expr):
format_args = [format_args]
+
if isinstance(format, str):
format = StringImm(format)
+
return _ffi_api.assert_op(condition, format_args, format) # type: ignore
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index 474833bdfd..36ba46a1a5 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -92,6 +92,7 @@ def default_build_pipeline():
transform.LowerAllocTensor(),
transform.KillAfterLastUse(),
transform.VMBuiltinLower(),
+ transform.ComputePrimValue(),
transform.VMShapeLower(),
transform.AttachGlobalSymbol(),
],
diff --git a/python/tvm/relax/transform/__init__.py
b/python/tvm/relax/transform/__init__.py
index 5f10c39d82..11e301c26c 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -28,6 +28,7 @@ from .transform import (
CallTIRRewrite,
CanonicalizeBindings,
CombineParallelMatmul,
+ ComputePrimValue,
ConvertLayout,
ConvertToDataflow,
DataflowBlockPass,
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index ef10f5791d..dbc35d48d3 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -486,6 +486,25 @@ def KillAfterLastUse() -> tvm.ir.transform.Pass:
return _ffi_api.KillAfterLastUse() # type: ignore
+def ComputePrimValue() -> tvm.ir.transform.Pass:
+ """Compute all R.prim_value instances
+
+ While high-level relax can include expressions in terms of its
+ symbolic variables, these expressions cannot natively be computed
+ within relax. In order to provide values for symbolic expressions
+ (e.g. `R.prim_value(N*N)`, where `N` is a symbolic variable), this
+ pass generates a PrimFunc in which the expression can be computed.
+ The relax graph is then updated to include a call to that
+ PrimFunc, in place of the original `R.prim_value(expr)`.
+
+ Returns
+ -------
+ ret : tvm.ir.transform.Pass
+
+ """
+ return _ffi_api.ComputePrimValue() # type: ignore
+
+
def VMBuiltinLower() -> tvm.ir.transform.Pass:
"""Lowering generic intrinsic to VM intrinsics.
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 3e1927290d..6dbf5c5dfd 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -511,18 +511,25 @@ def SeqExpr() -> frame.SeqExprFrame: # pylint:
disable=invalid-name
############################# If Then Else #############################
-def If(condition: Expr) -> frame.IfFrame: # pylint: disable=invalid-name
+def If(condition: Union[Expr, PrimExpr]) -> frame.IfFrame: # pylint:
disable=invalid-name
"""Create an if frame.
+
Parameters
----------
- condition : Expr
- The condition of if statement, executes the true branch if the
condition is true,
- otherwise jump into the false branch.
+ condition : Union[Expr, PrimExpr]
+
+ The condition of if statement, executes the true branch if the
+ condition is true, otherwise jump into the false branch.
+
Returns
-------
res : frame.IfFrame
The result IfFrame.
+
"""
+ if not isinstance(condition, Expr):
+ condition = relax.PrimValue(condition)
+
return _ffi_api.If(condition) # type: ignore[attr-defined] # pylint:
disable=no-member
diff --git a/python/tvm/script/parser/tir/parser.py
b/python/tvm/script/parser/tir/parser.py
index 0f3f3de60f..679ae4e8ad 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -537,12 +537,31 @@ def visit_tvm_declare_function(self: Parser, node:
doc.FunctionDef) -> GlobalVar
The doc AST return node.
"""
- ret_type = None
- if node.returns is not None:
- ret_type = self.eval_expr(node.returns)
- if callable(ret_type):
- ret_type = PrimType(ret_type().dtype)
+ supplied_annotation = self.function_annotations
+ func_annotation = supplied_annotation.get(node.name, {})
- # Only ret_type is needed for func_signature.
- func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type)
+ ret_type = None
+ with self.var_table.with_frame():
+ if node.returns is not None:
+ ret_type = self.eval_expr(node.returns)
+ if callable(ret_type):
+ ret_type = PrimType(ret_type().dtype)
+
+ arg_annotations = []
+ for arg in node.args.args:
+ if arg.annotation is None:
+ self.report_error(arg, "Type annotation required for function
parameters.")
+ try:
+ ann = self.eval_expr(arg.annotation)
+ if callable(ann):
+ ann = ann()
+ except Exception: # pylint: disable=broad-except
+ ann = func_annotation.get(arg.arg, None)
+ if ann is None:
+ raise
+
+ IRBuilder.name(arg.arg, ann)
+ arg_annotations.append(ann)
+
+ func_signature = tvm.tir.PrimFunc(arg_annotations, None, ret_type=ret_type)
return I.decl_function(node.name, func_signature)
diff --git a/python/tvm/tir/analysis/analysis.py
b/python/tvm/tir/analysis/analysis.py
index 8d7e81d7d0..67eb7471d2 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -417,3 +417,13 @@ def get_vtcm_compaction_passes() ->
List[tvm.transform.Pass]:
returns list of passes
"""
return _ffi_api.get_vtcm_compaction_passes() # type: ignore # pylint:
disable=no-member
+
+
+def is_pure_function(func: PrimFunc) -> bool:
+ """Checks if the function is a pure function"""
+ return _ffi_api.is_pure_function(func, False) # type: ignore # pylint:
disable=no-member
+
+
+def assert_pure_function(func: PrimFunc) -> bool:
+ """Asserts that the function is a pure function"""
+ return _ffi_api.is_pure_function(func, True) # type: ignore # pylint:
disable=no-member
diff --git a/src/relax/analysis/struct_info_analysis.cc
b/src/relax/analysis/struct_info_analysis.cc
index b1932f9b5d..08e2acfbd0 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -840,8 +840,10 @@ class CallRetStructInfoDeriver : public
StructInfoBaseChecker {
auto params = finfo->params.value();
if (params.size() != call->args.size()) {
ctx->ReportFatal(Diagnostic::Error(call->span)
- << "number of arguments and parameters mismatch:"
- << " expected " << params.size() << ", given " <<
call->args.size());
+ << "Number of arguments and parameters mismatch:"
+ << " Function " << call->op << " has struct info " <<
finfo
+ << " and accepts " << params.size() << " parameters,
but was called with "
+ << call->args.size() << " arguments (" << call->args <<
")");
}
// Visit each param arg pair, check and populate the var map
for (size_t i = 0; i < params.size(); ++i) {
diff --git a/src/relax/backend/vm/vm_shape_lower.cc
b/src/relax/backend/vm/vm_shape_lower.cc
index 06c2e31767..8dca06c840 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -85,6 +85,7 @@ class PrimExprSlotCollector : public ExprVisitor, public
StructInfoVisitor {
collector.VisitExpr(param);
}
collector.VisitExpr(func->body);
+ collector.VisitStructInfo(func->ret_struct_info);
}
private:
diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc
index 186fc9fa86..3772e530ed 100644
--- a/src/relax/op/tensor/inspect.cc
+++ b/src/relax/op/tensor/inspect.cc
@@ -107,7 +107,7 @@ tir::PrimFunc
GetDLTensorField(tir::builtin::TVMStructFieldKind field, DataType
FuncStructInfo sinfo({TensorStructInfo(DataType::Void(), kUnknownNDim)},
PrimStructInfo(field_dtype));
- UpdateStructInfo(func, sinfo);
+ func->struct_info_ = sinfo;
return func;
}
@@ -338,7 +338,7 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const
Call& call) {
FuncStructInfo sinfo(
{TensorStructInfo(DataType::Void(), kUnknownNDim),
PrimStructInfo(axis->dtype)},
PrimStructInfo(field_dtype));
- UpdateStructInfo(func, sinfo);
+ func->struct_info_ = sinfo;
return func;
}();
diff --git a/src/relax/transform/compute_prim_value.cc
b/src/relax/transform/compute_prim_value.cc
new file mode 100644
index 0000000000..9fe2a3a06f
--- /dev/null
+++ b/src/relax/transform/compute_prim_value.cc
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace relax {
+
+namespace {
+
+class PrimValueComputeInjector : public ExprMutator {
+ public:
+ IRModule Finalize() const { return builder_->Finalize(); }
+
+ using ExprMutator::VisitExpr_;
+
+ Expr VisitExpr_(const PrimValueNode* op) override {
+ auto node = Downcast<PrimValue>(ExprMutator::VisitExpr_(op));
+
+ if (node->value->IsInstance<tir::IntImmNode>() ||
node->value->IsInstance<tir::VarNode>()) {
+ return node;
+ }
+
+ auto ret_dtype = node->value->dtype;
+ auto param_vars = tir::UndefinedVars(node->value);
+ tir::Stmt body = tir::Evaluate(tir::Call(ret_dtype, tir::builtin::ret(),
{node->value}));
+
+ tir::PrimFunc func(param_vars, body, PrimType(ret_dtype));
+ func = tir::RenewDefs(func);
+
+ auto callee = builder_->AddFunction(func, "compute_symbolic_expr");
+
+ return relax::Call(callee, param_vars.Map([](const tir::Var& tir_var) ->
relax::Expr {
+ return relax::PrimValue(tir_var);
+ }));
+ }
+};
+
+} // namespace
+
+namespace transform {
+
+Pass ComputePrimValue() {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+ [=](IRModule mod, PassContext pc) -> IRModule {
+ PrimValueComputeInjector mutator;
+
+ IRModule updates;
+ for (const auto& [gvar, base_func] : mod->functions) {
+ if (auto func = base_func.as<Function>()) {
+ auto updated = Downcast<Function>(mutator(func.value()));
+ if (!updates.same_as(base_func)) {
+ updates->Add(gvar, updated);
+ }
+ }
+ }
+
+ if (updates->functions.size()) {
+ auto write_ptr = mod.CopyOnWrite();
+ write_ptr->Update(updates);
+ write_ptr->Update(mutator.Finalize());
+ }
+
+ return mod;
+ };
+ return CreateModulePass(pass_func, 0, "ComputePrimValue", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ComputePrimValue").set_body_typed(ComputePrimValue);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/transform/dataflow_inplace.cc
b/src/relax/transform/dataflow_inplace.cc
index 755c5dbab4..0912981775 100644
--- a/src/relax/transform/dataflow_inplace.cc
+++ b/src/relax/transform/dataflow_inplace.cc
@@ -877,10 +877,12 @@ class ModuleInplaceTransformer : public ExprMutator {
auto inline_legal_op_name = legal_op->name_hint + "_inplace";
auto mod = builder_->GetContextIRModule();
- auto legal_primfunc = Downcast<tir::PrimFunc>(mod->Lookup(legal_op));
- auto* legal_primfunc_cow = legal_primfunc.CopyOnWrite();
+ auto old_primfunc = Downcast<tir::PrimFunc>(mod->Lookup(legal_op));
+
+ tir::Stmt new_body = old_primfunc->body;
+
size_t num_outs = inplace_indices.size();
- size_t num_params = legal_primfunc->params.size();
+ size_t num_params = old_primfunc->params.size();
// the replacement we must make:
// 1. For each output var, replace its corresponding buffers with the
corresponding inplace
@@ -893,42 +895,43 @@ class ModuleInplaceTransformer : public ExprMutator {
Map<tir::Var, tir::Var> var_subst_map;
for (size_t i = 0; i < num_outs; i++) {
// we will substitute output i with the corresponding param indicated by
inplace indices
- auto output_var = legal_primfunc->params[num_params - num_outs + i];
- auto inplace_var = legal_primfunc->params[inplace_indices[i].IntValue()];
+ auto output_var = old_primfunc->params[num_params - num_outs + i];
+ auto inplace_var = old_primfunc->params[inplace_indices[i].IntValue()];
var_subst_map.Set(output_var, inplace_var);
// also do the same with the buffer vars
- auto output_buffer = legal_primfunc->buffer_map.at(output_var);
- auto inplace_buffer = legal_primfunc->buffer_map.at(inplace_var);
+ auto output_buffer = old_primfunc->buffer_map.at(output_var);
+ auto inplace_buffer = old_primfunc->buffer_map.at(inplace_var);
var_subst_map.Set(output_buffer->data, inplace_buffer->data);
buffer_subst_map.Set(output_buffer, inplace_buffer);
}
// apply substitutions
- legal_primfunc_cow->body = RemapBuffers(legal_primfunc->body,
buffer_subst_map);
- legal_primfunc_cow->body = tir::Substitute(
- legal_primfunc->body, [&var_subst_map](const tir::Var& v) ->
Optional<PrimExpr> {
- if (var_subst_map.count(v)) {
- return var_subst_map.at(v);
- }
- return Optional<PrimExpr>();
- });
+ new_body = RemapBuffers(new_body, buffer_subst_map);
+ new_body = tir::Substitute(new_body, [&var_subst_map](const tir::Var& v)
-> Optional<PrimExpr> {
+ if (var_subst_map.count(v)) {
+ return var_subst_map.at(v);
+ }
+ return Optional<PrimExpr>();
+ });
// remove the now-unused outputs from the buffer map
- auto buffer_map = legal_primfunc->buffer_map;
+ auto new_buffer_map = old_primfunc->buffer_map;
for (size_t i = 0; i < num_outs; i++) {
- buffer_map.erase(legal_primfunc->params[num_params - num_outs + i]);
+ new_buffer_map.erase(old_primfunc->params[num_params - num_outs + i]);
}
- legal_primfunc_cow->buffer_map = buffer_map;
// now get rid of the last num_outputs arguments
// (couldn't do earlier or else it would have thrown off the indexing)
- legal_primfunc_cow->params = Array<tir::Var>(
- legal_primfunc->params.begin(), legal_primfunc->params.begin() +
(num_params - num_outs));
+ Array<tir::Var> new_params(old_primfunc->params.begin(),
+ old_primfunc->params.begin() + (num_params -
num_outs));
+
+ tir::PrimFunc new_primfunc(new_params, new_body, old_primfunc->ret_type,
new_buffer_map,
+ old_primfunc->attrs, old_primfunc->span);
// note: this might be a good time to get rid of the old legalized
function, but we don't do it
// now because later ops might need the same one. Instead, we will clean
up at the end
- auto new_gv = builder_->AddFunction(legal_primfunc, inline_legal_op_name);
+ auto new_gv = builder_->AddFunction(new_primfunc, inline_legal_op_name);
// update the call (change the op, update the argument, change the attrs)
legalized_call_cow->op = call_tir_inplace_op;
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index efb2d02204..a15ee79fac 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -220,12 +220,21 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank,
bool permit_unknown_dtype) {
- const TensorStructInfoNode* tt = sinfo.as<TensorStructInfoNode>();
- if (!tt) {
+ DataType dtype;
+ int ndim;
+
+ if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
+ dtype = tensor->dtype;
+ ndim = tensor->ndim;
+ } else if (const auto* prim = sinfo.as<PrimStructInfoNode>()) {
+ dtype = prim->dtype;
+ ndim = 0;
+ } else {
return false;
}
- bool correct_dtype = tt->dtype.is_bool() || (permit_unknown_dtype &&
tt->dtype.is_void());
- bool correct_rank = tt->ndim == 0 || (permit_unknown_rank && tt->ndim == -1);
+
+ bool correct_dtype = dtype.is_bool() || (permit_unknown_dtype &&
dtype.is_void());
+ bool correct_rank = ndim == 0 || (permit_unknown_rank && ndim == -1);
return correct_dtype && correct_rank;
}
diff --git a/src/tir/analysis/is_pure_function.cc
b/src/tir/analysis/is_pure_function.cc
new file mode 100644
index 0000000000..c9934c4bcf
--- /dev/null
+++ b/src/tir/analysis/is_pure_function.cc
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file is_pure_function.cc
+ * \brief PrimFunc purity analysis
+ */
+#include <tvm/ir/op.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../ir/tir_visitor_with_path.h"
+
+namespace tvm {
+namespace tir {
+
+namespace {
+class PurityChecker : TIRVisitorWithPath {
+ public:
+ static bool Check(const PrimFunc& func, bool assert_on_error) {
+ PurityChecker visitor(assert_on_error);
+ visitor(func);
+ return visitor.is_pure_;
+ }
+
+ private:
+ explicit PurityChecker(bool assert_on_error) :
assert_on_error_(assert_on_error) {}
+
+ void VisitStmt_(const AllocateNode* op, ObjectPath path) override {
+ internal_allocations_.insert(op->buffer_var);
+ TIRVisitorWithPath::VisitStmt_(op, path);
+ }
+
+ void VisitStmt_(const BufferStoreNode* op, ObjectPath path) override {
+ TIRVisitorWithPath::VisitStmt_(op, path);
+
+ if (!internal_allocations_.count(op->buffer->data)) {
+ is_pure_ = false;
+ LOG_IF(FATAL, assert_on_error_) << "AssertionError: "
+ << "Pure functions must not write to
buffers, "
+ << ", but function contains store to "
<< op->buffer
+ << op->indices << " of value " <<
op->value;
+ }
+ }
+
+ void VisitExpr_(const CallNode* call, ObjectPath path) override {
+ TIRVisitorWithPath::VisitExpr_(call, path);
+
+ static auto op_call_effect =
Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
+ CallEffectKind effect = [&]() {
+ if (auto opt = call->op.as<Op>()) {
+ return static_cast<CallEffectKind>(op_call_effect[opt.value()]->value);
+ } else {
+ return CallEffectKind::kOpaque;
+ }
+ }();
+
+ if (effect == CallEffectKind::kUpdateState || effect ==
CallEffectKind::kOpaque) {
+ is_pure_ = false;
+ LOG_IF(FATAL, assert_on_error_)
+ << "AssertionError: "
+ << "Pure functions must not contain calls to impure operators, "
+ << "but " << GetRef<PrimExpr>(call) << " calls operator " << call->op
+ << ", which has side effect " << effect;
+ }
+ }
+
+ bool assert_on_error_{false};
+ bool is_pure_{true};
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> internal_allocations_;
+};
+} // namespace
+
+bool IsPureFunction(const PrimFunc& func, bool assert_on_error) {
+ return PurityChecker::Check(func, assert_on_error);
+}
+
+TVM_REGISTER_GLOBAL("tir.analysis.is_pure_function").set_body_typed(IsPureFunction);
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc
index 5067d90838..8a3d2d6947 100644
--- a/src/tir/ir/function.cc
+++ b/src/tir/ir/function.cc
@@ -21,12 +21,52 @@
* \file src/tir/ir/function.cc
* \brief The function data structure.
*/
+#include <tvm/relax/struct_info.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
namespace tvm {
namespace tir {
+namespace {
+relax::StructInfo InferStructInfo(const PrimFunc& prim_func) {
+ Array<relax::StructInfo> params;
+ for (const auto& param : prim_func->params) {
+ relax::StructInfo param_sinfo = [&]() -> relax::StructInfo {
+ if (auto opt_buf = prim_func->buffer_map.Get(param)) {
+ auto buf = opt_buf.value();
+ relax::ShapeExpr shape(
+ buf->shape.Map([](PrimExpr dim) { return cast(DataType::Int(64),
dim); }));
+ return relax::TensorStructInfo(shape, buf->dtype);
+ }
+
+ if (auto prim_type = param->type_annotation.as<PrimTypeNode>();
+ prim_type && prim_type->dtype.is_handle()) {
+ return relax::ObjectStructInfo();
+ }
+
+ return relax::PrimStructInfo(param->dtype);
+ }();
+ params.push_back(param_sinfo);
+ }
+
+ relax::StructInfo ret = [&]() -> relax::StructInfo {
+ if (const auto* prim = prim_func->ret_type.as<PrimTypeNode>()) {
+ return relax::PrimStructInfo(prim->dtype);
+ } else if (IsVoidType(prim_func->ret_type)) {
+ return relax::TupleStructInfo(Array<relax::StructInfo>{});
+ } else {
+ return relax::ObjectStructInfo();
+ }
+ }();
+
+ bool purity = prim_func->body.defined() ? IsPureFunction(prim_func) : false;
+
+ return relax::FuncStructInfo(params, ret, purity);
+}
+} // namespace
+
// Get the function type of a PrimFunc
PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span
span) {
@@ -42,8 +82,11 @@ PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type
ret_type,
n->buffer_map = std::move(buffer_map);
n->attrs = std::move(attrs);
n->checked_type_ = n->func_type_annotation();
+ n->struct_info_ = relax::FuncStructInfo::OpaqueFunc();
n->span = std::move(span);
data_ = std::move(n);
+
+ (*this)->struct_info_ = InferStructInfo(*this);
}
FuncType PrimFuncNode::func_type_annotation() const {
diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc
index 8095b3141f..924ef9a0cd 100644
--- a/src/tir/ir/specialize.cc
+++ b/src/tir/ir/specialize.cc
@@ -105,14 +105,10 @@ class PrimFuncSpecializer : public StmtExprMutator {
Stmt body = specializer(f->body);
if (param_updated || buffer_map_updated || !f->body.same_as(body)) {
- PrimFuncNode* f_ptr = f.CopyOnWrite();
- f_ptr->params = std::move(params);
- f_ptr->buffer_map = std::move(buffer_map);
- f_ptr->body = std::move(body);
- f_ptr->struct_info_ = NullOpt;
- f_ptr->checked_type_ = Type(nullptr);
+ return PrimFunc(params, body, f->ret_type, buffer_map, f->attrs,
f->span);
+ } else {
+ return f;
}
- return f;
}
private:
diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc
index 8a122f8922..28d1100f6b 100644
--- a/src/tir/transforms/renew_defs.cc
+++ b/src/tir/transforms/renew_defs.cc
@@ -76,11 +76,7 @@ class RenewDefMutator : public StmtExprMutator {
// Visit body
Stmt body = generator(func->body);
// Recreate function
- auto n = make_object<PrimFuncNode>(*func.get());
- n->params = std::move(params);
- n->buffer_map = std::move(buffer_map);
- n->body = std::move(body);
- return PrimFunc(n);
+ return PrimFunc(params, body, func->ret_type, buffer_map, func->attrs,
func->span);
}
private:
diff --git a/tests/python/relax/test_analysis_well_formed.py
b/tests/python/relax/test_analysis_well_formed.py
index b76b95646a..7deddfd28e 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -22,6 +22,7 @@ from tvm import tir
from tvm.script import relax as R
from tvm.script import ir as I
from tvm.script import tir as T
+from tvm.script import ir as I
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
@@ -656,5 +657,50 @@ def test_well_formed_function_referencing_global_var():
assert rx.analysis.well_formed(Module["subroutine"])
+def test_pass_dltensor_arg_to_tir():
+ """Relax may pass R.Tensor as DLTensor
+
+ In TIR, a `DLTensor*` argument with unknown shape and dtype is
+ represented as a `tir.Var` with
+ `tvm::PrimType(DataType::Handle())`, and with no entry in the
+ `PrimFuncNode::buffer_map`. In Relax, this is represented as
+ `R.Tensor`. Calls from Relax to TIR that pass a tensor of unknown
+ rank/shape are well-formed.
+
+ In the test case below, a TIR function accepts an arbitrary
+ `R.Tensor`, and returns a boolean value based on inspection of the
+ runtime datatype.
+ """
+
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(A: R.Tensor) -> R.Prim("bool"):
+ return Module.is_bfloat16_dtype(A)
+
+ @T.prim_func(private=True)
+ def is_bfloat16_dtype(tensor: T.handle) -> T.bool:
+ T.func_attr({"tir.is_scheduled": True, "tir.is_host_func": True})
+
+ # From #include <tvm/tir/builtin.h>
+ kArrTypeCode = T.meta_var(5)
+ kArrTypeBits = T.meta_var(6)
+ kArrTypeLanes = T.meta_var(7)
+
+ # From #include <dlpack/dlpack.h>
+ kDLBfloat = T.meta_var(4)
+
+ type_code = T.tvm_struct_get(tensor, 0, kArrTypeCode,
dtype="uint8")
+ type_bits = T.tvm_struct_get(tensor, 0, kArrTypeBits,
dtype="uint8")
+ type_lanes = T.tvm_struct_get(tensor, 0, kArrTypeLanes,
dtype="uint16")
+
+ is_bfloat16: T.bool = (
+ (type_code == kDLBfloat) and (type_bits == 16) and (type_lanes
== 1)
+ )
+ return is_bfloat16
+
+ assert rx.analysis.well_formed(Module)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py
b/tests/python/relax/test_backend_transform_shape_lower.py
index 31eb4b26be..fccf3a5f8a 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -452,6 +452,90 @@ def test_return_match_check():
assert_structural_equal(after, expected)
+def test_return_match_check_with_new_expr():
+ """Like test_return_match_check, but requires a computation
+
+ When the return body is not the same as ret_struct_info, a runtime match
+ check is required. This match check may require a symbolic
+ expression to be computed.
+ """
+ MS = MatchShapeCode
+
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor(["n", "n"], "float32")) -> R.Tensor(["n * n"],
"float32"):
+ R.func_attr({"relax.force_pure": True})
+ out = R.call_packed("flatten_matrix", x, sinfo_args=R.Object)
+ return out
+
+ # slot assignment:
+ sindex = {
+ "n": 0,
+ "n * n": 1,
+ }
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor(["n", "n"], "float32")) -> R.Tensor(["n * n"],
"float32"):
+ R.func_attr({"relax.force_pure": True})
+ shape_heap = R.call_builtin_with_ctx(
+ "vm.builtin.alloc_shape_heap",
+ [R.prim_value(2)],
+ sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
+ )
+ _ = R.call_packed(
+ "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "",
sinfo_args=[R.Tuple()]
+ )
+ _ = R.call_packed(
+ "vm.builtin.match_shape",
+ x,
+ shape_heap,
+ 2,
+ MS.STORE_TO_HEAP,
+ sindex["n"],
+ MS.ASSERT_EQUAL_TO_LOAD,
+ sindex["n"],
+ "",
+ sinfo_args=[R.Tuple()],
+ )
+
+ _ = Expected.shape_func(shape_heap)
+
+ out = R.call_packed("flatten_matrix", x, sinfo_args=R.Object)
+ _ = R.call_packed(
+ "vm.builtin.check_tensor_info",
+ out,
+ 1,
+ R.dtype("float32"),
+ "",
+ sinfo_args=[R.Tuple()],
+ )
+ _ = R.call_packed(
+ "vm.builtin.match_shape",
+ out,
+ shape_heap,
+ 1,
+ MS.ASSERT_EQUAL_TO_LOAD,
+ sindex["n * n"],
+ "",
+ sinfo_args=[R.Tuple()],
+ )
+ return out
+
+ @T.prim_func(private=True)
+ def shape_func(H: T.Buffer(T.int64(2), "int64")):
+ # generated compute function
+ T.func_attr({"tir.is_host_func": 1})
+ H[T.int64(sindex["n * n"])] = H[T.int64(sindex["n"])] *
H[T.int64(sindex["n"])]
+
+ before = Before
+ expected = Expected
+ after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+ assert_structural_equal(after, expected)
+
+
def test_symbolic_shape_multiple_function():
MS = MatchShapeCode
MK = MakeShapeCode
diff --git a/tests/python/relax/test_relax_operators.py
b/tests/python/relax/test_relax_operators.py
index a278b09167..41618a32cb 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -19,6 +19,8 @@ import sys
import tempfile
import numpy as np
+import pytest
+
import tvm
import tvm.testing
from tvm import relax
@@ -35,13 +37,18 @@ class InputModule:
return y, y_sorted
-def run_cpu(mod, func_name, *input):
+def run_cpu(mod, func_name, *args):
+ if isinstance(mod, relax.Function):
+ func = mod
+ args = [func_name, *args]
+ func_name = func.attrs["global_symbol"]
+ mod = tvm.IRModule.from_expr(func)
+
target = tvm.target.Target("llvm")
ex = relax.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
- vm.set_input(func_name, *input)
- vm.invoke_stateful(func_name)
- return vm.get_outputs(func_name)
+
+ return vm[func_name](*args)
def test_unique():
@@ -88,67 +95,108 @@ def test_print():
sys.stdout = stdout
[email protected]_module
-class AssertOpTest:
+def test_assert_passes():
@R.function(pure=False)
- def passes(x: R.Tensor((), "int32")):
- p1 = R.assert_op(relax.const(True))
+ def func(x: R.Tensor((), "int32")):
+ _ = R.assert_op(relax.const(True))
return x
+ run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+
+
+def test_assert_passes_with_format_args():
@R.function(pure=False)
- def pass_with_args(x: R.Tensor((), "int32")):
- p1 = R.assert_op(relax.const(True), x, format="You won't see me")
+ def func(x: R.Tensor((), "int32")):
+ _ = R.assert_op(relax.const(True), x, format="You won't see me")
return x
+ run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+
+
+def test_assert_fails():
+ @R.function(pure=False)
+ def func(x: R.Tensor((), "int32")):
+ _ = R.assert_op(relax.const(False))
+ return x
+
+ with pytest.raises(AssertionError, match="Assertion Failed"):
+ run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+
+
+def test_assert_fails_with_message():
@R.function(pure=False)
- def simple_fail(x: R.Tensor((), "int32")):
- p1 = R.assert_op(relax.const(False))
+ def func(x: R.Tensor((), "int32")):
+ _ = R.assert_op(relax.const(False), format="I failed...")
return x
+ with pytest.raises(AssertionError, match="I failed..."):
+ run_cpu(func, tvm.nd.array(np.array(1).astype("int32")))
+
+
+def test_assert_fails_with_args():
@R.function(pure=False)
- def fail_with_message(x: R.Tensor((), "int32")):
- p1 = R.assert_op(relax.const(False), format="I failed...")
+ def func(x: R.Tensor((), "int32")):
+ _ = R.assert_op(relax.const(False), [x, x])
return x
+ with pytest.raises(AssertionError, match="5, 5"):
+ run_cpu(func, tvm.nd.array(np.array(5).astype("int32")))
+
+
+def test_assert_fails_with_formatted_args():
@R.function(pure=False)
- def fail_with_args(x: R.Tensor((), "int32")):
- # no format
- p1 = R.assert_op(relax.const(False), [x, x])
+ def func(x: R.Tensor((), "int32")):
+ _ = R.assert_op(relax.const(False), x, format="Number: {}")
return x
+ with pytest.raises(AssertionError, match="Number: 6"):
+ run_cpu(func, tvm.nd.array(np.array(6).astype("int32")))
+
+
+def test_assert_on_argument_passes():
@R.function(pure=False)
- def fail_with_formatted_message(x: R.Tensor((), "int32")):
- p1 = R.assert_op(relax.const(False), x, format="Number: {}")
+ def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")):
+ _ = R.assert_op(condition)
return x
+ condition = tvm.nd.array(np.array(True))
+ x = tvm.nd.array(np.array(5).astype("int32"))
+ run_cpu(func, condition, x)
-def test_assert_op():
- def check_assertion_error(func_name, func_arg, expected_message):
- passed = False
- try:
- run_cpu(AssertOpTest, func_name, func_arg)
- passed = True
- except TVMError as e:
- # TVM will print out a TVMError that will contain the
- # generated error at the bottom of a stack trace
- assert "AssertionError" in e.args[0]
- assert expected_message in e.args[0]
- except AssertionError:
- return
- assert not passed
-
- run_cpu(AssertOpTest, "passes", tvm.nd.array(np.array(1).astype("int32")))
- run_cpu(AssertOpTest, "pass_with_args",
tvm.nd.array(np.array(2).astype("int32")))
- check_assertion_error(
- "simple_fail", tvm.nd.array(np.array(3).astype("int32")), "Assertion
Failed"
- )
- check_assertion_error(
- "fail_with_message", tvm.nd.array(np.array(4).astype("int32")), "I
failed..."
- )
- check_assertion_error("fail_with_args",
tvm.nd.array(np.array(5).astype("int32")), "5, 5")
- check_assertion_error(
- "fail_with_formatted_message",
tvm.nd.array(np.array(6).astype("int32")), "Number: 6"
- )
+
+def test_assert_on_argument_fails():
+ @R.function(pure=False)
+ def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")):
+ _ = R.assert_op(condition)
+ return x
+
+ condition = tvm.nd.array(np.array(False))
+ x = tvm.nd.array(np.array(5).astype("int32"))
+ with pytest.raises(AssertionError):
+ run_cpu(func, condition, x)
+
+
+def test_assert_on_symbolic_var_passes():
+ @R.function(pure=False)
+ def func(x: R.Tensor(["N"], "int32")):
+ N = T.int64()
+ _ = R.assert_op(R.prim_value(N % 8 == 0))
+ return x
+
+ x = tvm.nd.array(np.arange(8, dtype="int32"))
+ run_cpu(func, x)
+
+
+def test_assert_on_symbolic_var_fails():
+ @R.function(pure=False)
+ def func(x: R.Tensor(["N"], "int32")):
+ N = T.int64()
+ _ = R.assert_op(R.prim_value(N % 8 == 0))
+ return x
+
+ x = tvm.nd.array(np.arange(10, dtype="int32"))
+ with pytest.raises(AssertionError):
+ run_cpu(func, x)
@tvm.script.ir_module
@@ -370,5 +418,60 @@ def test_op_to_vdevice():
assert (copy_found.numpy() == arr).all()
+def test_scalar_tensor_as_branch_condition():
+ """The condition of a branch may be a scalar tensor"""
+
+ @R.function
+ def func(condition: R.Tensor((), "bool")):
+ if condition:
+ out = R.prim_value(5)
+ else:
+ out = R.prim_value(10)
+ return out
+
+ res = run_cpu(func, tvm.nd.array(np.array(True)))
+ assert res == 5
+
+ res = run_cpu(func, tvm.nd.array(np.array(False)))
+ assert res == 10
+
+
+def test_prim_value_as_branch_condition():
+ """The condition may be a PrimValue"""
+
+ @R.function
+ def func(condition: R.Prim("bool")):
+ if condition:
+ out = R.prim_value(5)
+ else:
+ out = R.prim_value(10)
+ return out
+
+ res = run_cpu(func, True)
+ assert res == 5
+
+ res = run_cpu(func, False)
+ assert res == 10
+
+
+def test_computed_prim_value_as_branch_condition():
+ """The R.Prim condition may be computed within the function"""
+
+ @R.function
+ def func(x: R.Tensor(["N"], "int64")):
+ N = T.int64()
+ if R.prim_value(N % 16 == 0):
+ out = R.prim_value(5)
+ else:
+ out = R.prim_value(10)
+ return out
+
+ res = run_cpu(func, tvm.nd.array(np.arange(16)))
+ assert res == 5
+
+ res = run_cpu(func, tvm.nd.array(np.arange(20)))
+ assert res == 10
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform.py
b/tests/python/relax/test_transform.py
index 9ab2ffc605..7fbf9a2da1 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -343,14 +343,18 @@ def test_call_tir_inplace_multiple_args():
@tvm.script.ir_module
class Expected:
@T.prim_func
- def copy(A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32")):
+ def copy(
+ A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32"), C:
T.Buffer((2, 3), "int32")
+ ):
+ # copies the contents of C into A and B
T.func_attr({"tir.noalias": True})
for i0, i1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_zeros"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
- T.reads(B[ax0, ax1])
- T.writes(A[ax0, ax1])
- A[ax0, ax1] = B[ax0, ax1]
+ T.reads(C[ax0, ax1])
+ T.writes(A[ax0, ax1], B[ax0, ax1])
+ A[ax0, ax1] = C[ax0, ax1]
+ B[ax0, ax1] = C[ax0, ax1]
@R.function
def foo(
diff --git a/tests/python/relax/test_transform_compute_prim_value.py
b/tests/python/relax/test_transform_compute_prim_value.py
new file mode 100644
index 0000000000..9fee35414d
--- /dev/null
+++ b/tests/python/relax/test_transform_compute_prim_value.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import ir as I, relax as R, tir as T
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+ transform = tvm.relax.transform.ComputePrimValue()
+
+
+class TestPrimValueInAssertCondition(BaseCompare):
+ @I.ir_module
+ class Before:
+ @R.function(pure=False)
+ def main(A: R.Tensor(["N"])):
+ N = T.int64()
+ _ = R.assert_op(N % 16 == 0)
+ return A
+
+ @I.ir_module
+ class Expected:
+ @R.function(pure=False)
+ def main(A: R.Tensor(["N"])):
+ N = T.int64()
+ condition: R.Prim("bool") =
Expected.compute_symbolic_expr(R.prim_value(N))
+ _ = R.assert_op(condition)
+ return A
+
+ @T.prim_func(private=True)
+ def compute_symbolic_expr(N: T.int64) -> T.bool:
+ T.ret(N % 16 == 0)
+
+
+class TestPrimValueInBranchCondition(BaseCompare):
+ @I.ir_module
+ class Before:
+ @R.function(pure=False)
+ def main(A: R.Tensor(["N"])):
+ N = T.int64()
+ if R.prim_value(N % 16 == 0):
+ out = R.call_packed("fast_vectorized_impl", A,
sinfo_args=[A.struct_info])
+ else:
+ out = R.call_packed("slow_non_vectorized_impl", A,
sinfo_args=[A.struct_info])
+ return out
+
+ @I.ir_module
+ class Expected:
+ @R.function(pure=False)
+ def main(A: R.Tensor(["N"])):
+ N = T.int64()
+ condition: R.Prim("bool") =
Expected.compute_symbolic_expr(R.prim_value(N))
+ if condition:
+ out = R.call_packed("fast_vectorized_impl", A,
sinfo_args=[A.struct_info])
+ else:
+ out = R.call_packed("slow_non_vectorized_impl", A,
sinfo_args=[A.struct_info])
+ return out
+
+ @T.prim_func(private=True)
+ def compute_symbolic_expr(N: T.int64) -> T.bool:
+ T.ret(N % 16 == 0)
+
+
+class TestPrimValueInPureFunction(BaseCompare):
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(_N: R.Prim(value="N"), _M: R.Prim(value="M")) ->
R.Prim(value="N*M"):
+ N = T.int64()
+ M = T.int64()
+ out = R.prim_value(N * M)
+ return out
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(_N: R.Prim(value="N"), _M: R.Prim(value="M")) ->
R.Prim(value="N*M"):
+ N = T.int64()
+ M = T.int64()
+ out = Expected.compute_symbolic_expr(R.prim_value(N),
R.prim_value(M))
+ return out
+
+ @T.prim_func(private=True)
+ def compute_symbolic_expr(N: T.int64, M: T.int64) -> T.int64:
+ T.ret(N * M)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index 2221cb89eb..c8db26c81b 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1261,6 +1261,149 @@ def test_if_branch_var_scope():
return w
+def test_scalar_tensor_as_branch_condition():
+ """Branch condition can be 0-d tensor"""
+
+ @R.function
+ def func(cond: R.Tensor([], "bool"), x: R.Tensor((1,), "float32")):
+ if cond:
+ out = R.add(x, x)
+ else:
+ out = R.multiply(x, x)
+ return out
+
+ if_else = func.body.blocks[0].bindings[0].value
+ assert isinstance(if_else.cond, relax.Var)
+ tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Tensor([],
"bool"))
+
+
+def test_prim_value_as_branch_condition():
+ """In addition to scalar tensor, can use R.Prim condition"""
+
+ @R.function
+ def func(cond: R.Prim("bool"), x: R.Tensor((1,), "float32")):
+ if cond:
+ out = R.add(x, x)
+ else:
+ out = R.multiply(x, x)
+ return out
+
+ if_else = func.body.blocks[0].bindings[0].value
+ assert isinstance(if_else.cond, relax.Var)
+ tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Prim("bool"))
+
+
+def test_computed_prim_value_as_branch_condition():
+ """The R.Prim condition may be computed within the function"""
+
+ @R.function
+ def func(x: R.Tensor(["N"], "float32")):
+ N = T.int64()
+ if R.prim_value(N % 16 == 0):
+ out = R.call_pure_packed("fast_vectorized_impl", x,
sinfo_args=[x.struct_info])
+ else:
+ out = R.call_pure_packed("slow_non_vectorized_impl", x,
sinfo_args=[x.struct_info])
+ return out
+
+ N = func.params[0].struct_info.shape[0]
+ if_else = func.body.blocks[0].bindings[0].value
+ assert isinstance(if_else.cond, relax.PrimValue)
+ tvm.ir.assert_structural_equal(N % 16 == 0, if_else.cond.value)
+ tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Prim(value=N %
16 == 0))
+
+
+def test_tir_expr_as_branch_condition():
+ """Syntactic sugar, wrap PrimExpr as PrimValue"""
+
+ @R.function(private=True)
+ def sugared(x: R.Tensor(["N"], "float32")):
+ N = T.int64()
+ if N % 16 == 0:
+ out = R.call_pure_packed("fast_vectorized_impl", x,
sinfo_args=[x.struct_info])
+ else:
+ out = R.call_pure_packed("slow_non_vectorized_impl", x,
sinfo_args=[x.struct_info])
+ return out
+
+ @R.function(private=True)
+ def unsugared(x: R.Tensor(["N"], "float32")):
+ N = T.int64()
+ if R.prim_value(N % 16 == 0):
+ out = R.call_pure_packed("fast_vectorized_impl", x,
sinfo_args=[x.struct_info])
+ else:
+ out = R.call_pure_packed("slow_non_vectorized_impl", x,
sinfo_args=[x.struct_info])
+ return out
+
+ tvm.ir.assert_structural_equal(unsugared, sugared)
+
+
+def test_scalar_tensor_as_assert_condition():
+ """Assert condition can be 0-d tensor"""
+
+ @R.function(pure=False)
+ def func(cond: R.Tensor([], "bool"), x: R.Tensor((1,), "float32")):
+ _ = R.assert_op(cond)
+ out = R.add(x, x)
+ return out
+
+ assert_op = func.body.blocks[0].bindings[0].value
+ condition = assert_op.args[0]
+ assert isinstance(condition, relax.Var)
+ tvm.ir.assert_structural_equal(condition.struct_info, R.Tensor([], "bool"))
+
+
+def test_prim_value_as_assert_condition():
+ """In addition to scalar tensor, can use R.Prim condition"""
+
+ @R.function(pure=False)
+ def func(cond: R.Prim("bool"), x: R.Tensor((1,), "float32")):
+ _ = R.assert_op(cond)
+ out = R.add(x, x)
+ return out
+
+ assert_op = func.body.blocks[0].bindings[0].value
+ condition = assert_op.args[0]
+ assert isinstance(condition, relax.Var)
+ tvm.ir.assert_structural_equal(condition.struct_info, R.Prim("bool"))
+
+
+def test_computed_prim_value_as_assert_condition():
+ """The R.Prim condition may be computed within the function"""
+
+ @R.function(pure=False)
+ def func(x: R.Tensor(["N"], "float32")):
+ N = T.int64()
+ _ = R.assert_op(R.prim_value(N % 16 == 0))
+ out = R.call_packed("fast_vectorized_impl", x,
sinfo_args=[x.struct_info])
+ return out
+
+ N = func.params[0].struct_info.shape[0]
+ assert_op = func.body.blocks[0].bindings[0].value
+ condition = assert_op.args[0]
+ assert isinstance(condition, relax.PrimValue)
+ tvm.ir.assert_structural_equal(N % 16 == 0, condition.value)
+ tvm.ir.assert_structural_equal(condition.struct_info, R.Prim(value=N % 16
== 0))
+
+
+def test_tir_expr_as_assert_condition():
+ """Syntactic sugar, wrap PrimExpr as PrimValue"""
+
+ @R.function(pure=False, private=True)
+ def sugared(x: R.Tensor(["N"], "float32")):
+ N = T.int64()
+ _ = R.assert_op(N % 16 == 0)
+ out = R.call_packed("fast_vectorized_impl", x,
sinfo_args=[x.struct_info])
+ return out
+
+ @R.function(pure=False, private=True)
+ def unsugared(x: R.Tensor(["N"], "float32")):
+ N = T.int64()
+ _ = R.assert_op(R.prim_value(N % 16 == 0))
+ out = R.call_packed("fast_vectorized_impl", x,
sinfo_args=[x.struct_info])
+ return out
+
+ tvm.ir.assert_structural_equal(unsugared, sugared)
+
+
def test_erase_to_well_defined_removes_internal_vars():
@R.function
def foo(x: R.Tensor):
@@ -1664,9 +1807,9 @@ def test_context_aware_parsing():
class Module:
@T.prim_func
def add(
- X: T.Buffer(T.int64(8), "float32"),
+ X: T.Buffer([T.int64(2), T.int64(4)], "float32"),
Y: T.Buffer((), "float32"),
- Z: T.Buffer(T.int64(8), "float32"),
+ Z: T.Buffer([T.int64(2), T.int64(4)], "float32"),
):
T.evaluate(0)
diff --git a/tests/python/relax/test_vm_codegen_tir.py
b/tests/python/relax/test_vm_codegen_tir.py
index 21e192955b..9a4817f5fd 100644
--- a/tests/python/relax/test_vm_codegen_tir.py
+++ b/tests/python/relax/test_vm_codegen_tir.py
@@ -72,7 +72,7 @@ def test_tir_call():
H[T.int64(0)] = H[T.int64(0)] + T.int64(1)
@R.function(pure=False)
- def foo(x: R.Tensor):
+ def foo(x: R.Tensor([4], "int64")):
R.func_attr({"global_symbol": "foo"})
_ = Before.shape_func(x)
return x
diff --git a/tests/python/tir-analysis/test_tir_analysis_is_pure_function.py
b/tests/python/tir-analysis/test_tir_analysis_is_pure_function.py
new file mode 100644
index 0000000000..6555ae3f77
--- /dev/null
+++ b/tests/python/tir-analysis/test_tir_analysis_is_pure_function.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm.testing
+from tvm.script import tir as T
+
+from tvm.tir.analysis import is_pure_function, assert_pure_function
+
+
+class CheckPureFunction:
+ def test_check_purity(self):
+ assert is_pure_function(self.func)
+
+ def test_assert_purity(self):
+ assert_pure_function(self.func)
+
+
+class CheckImpureFunction:
+ def test_check_purity(self):
+ assert not is_pure_function(self.func)
+
+ def test_assert_purity(self):
+ with pytest.raises(AssertionError):
+ assert_pure_function(self.func)
+
+
+class TestNoOp(CheckPureFunction):
+ @T.prim_func
+ def func():
+ pass
+
+
+class TestReturnValue(CheckPureFunction):
+ @T.prim_func
+ def func() -> T.int32:
+ T.ret(42)
+
+
+class TestComputeValueAndReturn(CheckPureFunction):
+ @T.prim_func
+ def func(N: T.int32, M: T.int32) -> T.int32:
+ T.ret(N * M)
+
+
+class TestReadBufferArgument(CheckPureFunction):
+ @T.prim_func
+ def func(A: T.Buffer(16, "float32")) -> T.float32:
+ T.ret(A[0])
+
+
+class TestWriteToBufferArgument(CheckImpureFunction):
+ @T.prim_func
+ def func(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
+ for i in range(16):
+ B[i] = A[i]
+
+
+class TestWriteToInternalAllocation(CheckPureFunction):
+ @T.prim_func
+ def func(A: T.Buffer([16, 16], "float32")) -> T.float32:
+ Sum = T.decl_buffer([], "float32")
+ Sum[()] = 0.0
+ for i, j in T.grid(16, 16):
+ Sum[()] = Sum[()] + A[i, j]
+
+ T.ret(Sum[()])
+
+
+class TestCallPureBuiltin(CheckPureFunction):
+ @T.prim_func
+ def func(x: T.float32) -> T.float32:
+ T.ret(T.cos(x))
+
+
+class TestCallPureExtern(CheckPureFunction):
+ @T.prim_func
+ def func():
+ T.call_pure_extern("some_pure_extern_func_name", dtype="void")
+
+
+class TestCallImpureExtern(CheckImpureFunction):
+ @T.prim_func
+ def func():
+ T.call_extern("some_impure_extern_func_name", dtype="void")
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/tir-base/test_tir_specialize.py
b/tests/python/tir-base/test_tir_specialize.py
index 0422887233..cead775e97 100644
--- a/tests/python/tir-base/test_tir_specialize.py
+++ b/tests/python/tir-base/test_tir_specialize.py
@@ -330,12 +330,11 @@ def test_specialize_buffer_var_to_expr():
tvm.ir.assert_structural_equal(expected, after)
-def test_specialization_removes_struct_info():
- """Reset struct info in specialization
+def test_specialization_updates_struct_info():
+ """Update struct info in specialization
- While a PrimFunc usually doesn't have a `relax.StructInfo`, the
- field can be populated in some edge cases. If that PrimFunc is
- specialized, the struct info should be reset.
+ A PrimFunc may have a `relax.StructInfo`. If that PrimFunc is
+ specialized, the struct info should be updated.
"""
@T.prim_func(private=True)
@@ -346,24 +345,20 @@ def test_specialization_removes_struct_info():
def expected() -> T.int32:
T.ret(50)
- sinfo = tvm.relax.FuncStructInfo(
+ sinfo_before = tvm.relax.FuncStructInfo(
[tvm.relax.PrimStructInfo("int32")], tvm.relax.PrimStructInfo("int32")
)
- tvm.relax.expr._update_struct_info(before, sinfo)
+ tvm.ir.assert_structural_equal(before.struct_info, sinfo_before)
+
+ sinfo_expected = tvm.relax.FuncStructInfo([],
tvm.relax.PrimStructInfo("int32"))
+ tvm.ir.assert_structural_equal(expected.struct_info, sinfo_expected)
n = before.params[0]
param_map = {n: 5}
after = before.specialize(param_map)
- tvm.ir.assert_structural_equal(expected, after)
- assert before.struct_info is not None
-
- # PrimFuncs do not expose the `struct_info_` field. Checking the
- # `struct_info` field when it isn't set raises an exception. This
- # is the desired behavior, since the struct info before
- # specialization is no longer valid.
- with pytest.raises(tvm.TVMError):
- after.struct_info
+ tvm.ir.assert_structural_equal(after, expected)
+ tvm.ir.assert_structural_equal(after.struct_info, sinfo_expected)
if __name__ == "__main__":
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py
b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index 074603681f..465ffa5cb6 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -340,5 +340,114 @@ def test_thread_binding_dtype():
assert loop_j.thread_binding.var.dtype == "int32"
+def test_inferred_sinfo_with_prim_args():
+ """A PrimFunc may have inferred StructInfo"""
+
+ @T.prim_func
+ def func(M: T.int32, N: T.int32) -> T.int32:
+ T.ret(M * N)
+
+ expected = tvm.relax.FuncStructInfo(
+ [
+ tvm.relax.PrimStructInfo("int32"),
+ tvm.relax.PrimStructInfo("int32"),
+ ],
+ tvm.relax.PrimStructInfo("int32"),
+ purity=True,
+ )
+ tvm.ir.assert_structural_equal(func.struct_info, expected)
+
+
+def test_inferred_sinfo_with_buffer_args():
+ """PrimFunc buffer arguments are inferred as R.Tensor"""
+
+ @T.prim_func
+ def func(A: T.Buffer([16, 16], "float32"), B: T.Buffer([256], "int32")) ->
T.float32:
+ T.ret(T.float32(42.0))
+
+ expected = tvm.relax.FuncStructInfo(
+ [
+ tvm.relax.TensorStructInfo([16, 16], "float32"),
+ tvm.relax.TensorStructInfo([256], "int32"),
+ ],
+ tvm.relax.PrimStructInfo("float32"),
+ purity=True,
+ )
+ tvm.ir.assert_structural_equal(func.struct_info, expected)
+
+
+def test_inferred_sinfo_with_internal_allocation():
+ """A pure function may still write to internal allocations.
+
+ Whether a function writes to internal allocations is not a visible
+ effect, and does not impact the purity of a function.
+ """
+
+ @T.prim_func
+ def func(A: T.Buffer([16, 16], "float32")) -> T.float32:
+ Sum = T.decl_buffer([], "float32")
+ Sum[()] = 0.0
+ for i, j in T.grid(16, 16):
+ Sum[()] = Sum[()] + A[i, j]
+
+ T.ret(Sum[()])
+
+ expected = tvm.relax.FuncStructInfo(
+ [
+ tvm.relax.TensorStructInfo([16, 16], "float32"),
+ ],
+ tvm.relax.PrimStructInfo("float32"),
+ purity=True,
+ )
+ tvm.ir.assert_structural_equal(func.struct_info, expected)
+
+
+def test_inferred_sinfo_with_output_buffer():
+ """A pure function may not write to an argument buffer
+
+ If an argument buffer is written to, the function must be impure.
+ """
+
+ @T.prim_func
+ def func(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
+ for i in range(16):
+ B[i] = A[i]
+
+ expected = tvm.relax.FuncStructInfo(
+ [
+ tvm.relax.TensorStructInfo([16], "float32"),
+ tvm.relax.TensorStructInfo([16], "float32"),
+ ],
+ tvm.relax.TupleStructInfo([]),
+ purity=False,
+ )
+ tvm.ir.assert_structural_equal(func.struct_info, expected)
+
+
+def test_inferred_sinfo_with_dynamic_buffer():
+ """The inferred StructInfo may contain dynamic shapes"""
+
+ @T.prim_func
+ def func(a_handle: T.handle, b_handle: T.handle):
+ M = T.int64()
+ N = T.int64()
+ A = T.match_buffer(a_handle, [M, N], "float32")
+ B = T.match_buffer(b_handle, [M * N], "float32")
+ for i, j in T.grid(M, N):
+ B[i * N + j] = A[i, j]
+
+ M = tvm.tir.Var("M", "int64")
+ N = tvm.tir.Var("N", "int64")
+ expected = tvm.relax.FuncStructInfo(
+ [
+ tvm.relax.TensorStructInfo([M, N], "float32"),
+ tvm.relax.TensorStructInfo([M * N], "float32"),
+ ],
+ tvm.relax.TupleStructInfo([]),
+ purity=False,
+ )
+ tvm.ir.assert_structural_equal(func.struct_info, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()