mbs-octoml commented on a change in pull request #9693:
URL: https://github.com/apache/tvm/pull/9693#discussion_r768007923
##########
File path: src/relay/op/memory/on_device.cc
##########
@@ -45,53 +43,74 @@ const Op& OnDeviceOp() {
return op;
}
-Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed) {
- ICHECK(!se_scope->IsFullyUnconstrained());
+Call OnDevice(Expr body, SEScope se_scope, bool constrain_result, bool constrain_body) {
+ ICHECK((!constrain_result && !constrain_body) || !se_scope->IsFullyUnconstrained());
auto attrs = make_object<OnDeviceAttrs>();
- attrs->se_scope = std::move(se_scope);
- attrs->is_fixed = is_fixed;
- Span span = expr->span;
- return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{},
+ attrs->se_scope =
+ (constrain_result || constrain_body) ? std::move(se_scope) : SEScope::FullyUnconstrained();
+ attrs->constrain_result = constrain_result;
+ attrs->constrain_body = constrain_body;
+ Span span = body->span; // about to be moved
+ return Call(OnDeviceOp(), {std::move(body)}, Attrs(std::move(attrs)), /*type_args=*/{},
std::move(span));
}
TVM_REGISTER_GLOBAL("relay.op.annotation._make.OnDevice").set_body_typed(OnDevice);
-Expr MaybeOnDevice(Expr expr, SEScope se_scope, bool is_fixed) {
+Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result, bool constrain_body) {
if (se_scope->IsFullyUnconstrained()) {
// Nothing to annotate with.
- return expr;
+ return body;
}
- if (expr->IsInstance<OpNode>() || expr->IsInstance<ConstructorNode>()) {
+ if (body->IsInstance<OpNode>() || body->IsInstance<ConstructorNode>()) {
// These operators are device polymorphic so no annotation is required.
- return expr;
+ return body;
}
- if (expr->IsInstance<GlobalVarNode>() || expr->IsInstance<VarNode>()) {
+ if (body->IsInstance<GlobalVarNode>() || body->IsInstance<VarNode>()) {
// The device can be recovered from the binding site of the global or local variable.
- return expr;
+ return body;
}
- if (expr->IsInstance<FunctionNode>()) {
+ if (body->IsInstance<FunctionNode>()) {
+ // If a primitive function then it is device polymorphic. Otherwise the device is captured
// by the function's "result_se_scope" attribute.
- return expr;
+ return body;
}
- OnDeviceProps props = GetOnDeviceProps(expr);
+ OnDeviceProps props = GetOnDeviceProps(body);
if (props.body.defined()) {
- // Don't nest on_devices.
- // If the inner and outer device types differ then we need to be careful:
- // - If the inner on_device is_fixed then it disagrees with the outer.
- // - If the outer on_device is_fixed then it implies a hidden device_copy
- // Otherwise just use the inner device type and ignore the outer.
- ICHECK(props.se_scope == se_scope || (!is_fixed && !props.is_fixed));
- return OnDevice(props.body, se_scope, is_fixed || props.is_fixed);
+ // The user is asking for
+ // on_device(on_device(body, se_scope=inner), se_scope=outer)
+ // ^ ^ ^
+ // outer middle inner
+ // First recover the implied constraints (if any) for outer and inner, and check they don't
+ // contradict.
+ const SEScope& inner = props.se_scope;
+ const SEScope& outer = se_scope;
+ bool constrain_outer = constrain_result;
+ bool constrain_inner = props.constrain_body;
+ if (constrain_outer && constrain_inner) {
+ ICHECK(inner == outer)
+ << "Cannot constrain result and body of nested on_device calls to different SEScopes";
+ }
+ // There are two possible ways the middle sub-expression may be constrained, check they don't
+ // contradict.
+ bool constrain_middle_via_outer = constrain_body;
+ bool constrain_middle_via_inner = props.constrain_result;
+ if (constrain_middle_via_outer && constrain_middle_via_inner) {
+ ICHECK(inner == outer)
+ << "Cannot constrain intermediate result of nested on_device calls to different SEScopes";
+ }
+ // We can now ignore the intermediate constraints, if any.
+ return OnDevice(props.body, (constrain_inner || constrain_outer) ? outer : inner,
Review comment:
Actually it is correct -- added a comment to explain.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]