tqchen commented on code in PR #14394:
URL: https://github.com/apache/tvm/pull/14394#discussion_r1183160296
##########
src/relax/utils.cc:
##########
@@ -111,6 +113,75 @@ bool IsLeafOrTuple(const Expr& expr) {
expr.as<OpNode>() || expr.as<TupleNode>();
}
+bool IsImpureCall(const Call& call) {
+ if (auto op_ptr = call->op.as<OpNode>()) {
+ auto op = GetRef<Op>(op_ptr);
+ static auto purity_map = Op::GetAttrMap<Bool>("FPurity");
+ ICHECK(purity_map.count(op)) << "Cannot find the registered purity of this
op: " << op->name;
+ return !(purity_map[op]->value);
+ }
+ // the StructInfo must be FuncStructInfo
+ auto func_struct_info = GetStructInfoAs<FuncStructInfoNode>(call->op);
+ return !func_struct_info->purity;
+}
+
+Call WrapCallPure(const Call& call) {
Review Comment:
Same here: once we have per-op purity deduction, we likely don't
need to rely on wrapping/unwrapping in most cases.
##########
src/relax/op/op.cc:
##########
@@ -73,6 +73,40 @@ StructInfo InferStructInfoShapeOf(const Call& call, const
BlockBuilder& ctx) {
return ShapeStructInfo(tensor_shape->values);
}
+// call_pure_packed
+
+StructInfo InferStructInfoCallPurePacked(const Call& call, const BlockBuilder&
ctx) {
+ if (call->args.size() < 1) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "call_pure_packed must be called with at least one
argument");
+ }
+
+ // the callee must be an opaque function
+ auto callee = call->args[0];
+ ICHECK(!callee.as<OpNode>()) << "call_pure_packed cannot be used with an op
node";
+ auto opt = MatchStructInfo<FuncStructInfo>(callee);
+ ICHECK(opt) << "Callee must have a function struct info";
+ FuncStructInfo finfo = opt.value();
+ ICHECK(finfo->IsOpaque()) << "call_pure_packed must be called with an opaque
function, but "
+ << callee << " is not opaque";
+
+ // derives the struct info of the result as it would for a call to the inner
args
+ auto hypothetical_call = UnwrapCallPure(call);
Review Comment:
We can define a derivation rule directly, similar to the one for call packed,
instead of relying on wrapping/unwrapping.
##########
src/relax/utils.cc:
##########
@@ -111,6 +113,75 @@ bool IsLeafOrTuple(const Expr& expr) {
expr.as<OpNode>() || expr.as<TupleNode>();
}
+bool IsImpureCall(const Call& call) {
+ if (auto op_ptr = call->op.as<OpNode>()) {
+ auto op = GetRef<Op>(op_ptr);
+ static auto purity_map = Op::GetAttrMap<Bool>("FPurity");
+ ICHECK(purity_map.count(op)) << "Cannot find the registered purity of this
op: " << op->name;
+ return !(purity_map[op]->value);
+ }
+ // the StructInfo must be FuncStructInfo
+ auto func_struct_info = GetStructInfoAs<FuncStructInfoNode>(call->op);
+ return !func_struct_info->purity;
+}
+
+Call WrapCallPure(const Call& call) {
+ static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed");
+ static const Op& call_pure_dps_packed_op =
Op::Get("relax.call_pure_dps_packed");
+ static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
+ static const Op& invoke_closure_op = Op::Get("relax.invoke_closure");
+ static const Op& invoke_pure_closure_op =
Op::Get("relax.invoke_pure_closure");
+
+ Call ret;
+ if (call->op == call_dps_packed_op) {
+ ret = std::move(Call(call_pure_dps_packed_op, call->args, call->attrs,
call->sinfo_args));
+ } else if (call->op == invoke_closure_op) {
+ ret = std::move(Call(invoke_pure_closure_op, call->args, call->attrs,
call->sinfo_args));
+ } else {
+ Array<Expr> call_args = {call->op};
+ for (auto arg : call->args) {
+ call_args.push_back(arg);
+ }
+ ret = std::move(Call(call_pure_packed_op, call_args, call->attrs,
call->sinfo_args));
+ }
+
+ // transfer over struct info if we can
+ if (call->struct_info_) {
+ UpdateStructInfo(ret, GetStructInfo(call));
+ }
+ return ret;
+}
+
+Call UnwrapCallPure(const Call& call) {
+ static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed");
Review Comment:
Now that we have per-operator purity, do we still need generic unwrapping?
We can likely simplify this by directly detecting and deriving purity from the
is_pure attribute of these functions.
##########
include/tvm/relax/utils.h:
##########
@@ -81,6 +81,42 @@ TVM_DLL bool IsBoolStructInfo(const StructInfo& sinfo, bool
permit_unknown_rank
*/
TVM_DLL bool IsLeafOrTuple(const Expr& expr);
+/*!
+ * \brief Check if the given Call node is an impure operation. If the callee
is a general
+ * expression, this simply requires checking the purity field of the
FuncStructInfo. If it is an Op,
+ * then this checks the `fPurity` field.
+ *
+ * \param call The input call
+ *
+ * \return True iff the call is impure (definitely or possibly results in a
visible side effect).
+ * That is, a call is considered pure only if definitely does not result in
a visible side effect.
+ */
+TVM_DLL bool IsImpureCall(const Call& call);
+
+/*!
+ * \brief Wrap the Call node in the call_pure op, transferring over the
attributes and sinfo_args.
+ *
+ * \param call The input call
+ *
+ * \return A Call to the call_pure op that wraps the original call.
+ *
+ * \note Transfers over StructInfo from the input to the return value.
+ */
+TVM_DLL Call WrapCallPure(const Call& call);
Review Comment:
How about we assume `call_dps_packed` is pure? If we want to be extra
careful, we can name it `call_dps_pure_packed`. I cannot think of a case where
an impure DPS call would be very useful.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]