electriclilies commented on a change in pull request #9693:
URL: https://github.com/apache/tvm/pull/9693#discussion_r767017167



##########
File path: src/printer/text_printer.cc
##########
@@ -36,49 +37,71 @@ static const char* kSemVer = "0.0.5";
 Doc TextPrinter::PrintMod(const IRModule& mod) {
   Doc doc;
   int counter = 0;
+
+  // We'll print in alphabetical order to make a/b diffs easier to work with.

Review comment:
       Nice!

##########
File path: src/relay/op/call/call.cc
##########
@@ -63,23 +63,23 @@ bool CallLoweredRel(const Array<Type>& types, int 
num_inputs, const Attrs& attrs
 
 const Op& CallLoweredOp() { return Op::Get("call_lowered"); }
 
-Expr CallLowered(Expr func, Array<Expr> inputs, Attrs attrs, Array<Type> 
type_args, Span span) {
-  // Right now, call_lowered only supports func being a global var pointing to 
the lowered
-  // function.
-  ICHECK(func.as<GlobalVarNode>())
-      << "Function to call should be GlobalVarNode, but got:" << std::endl
-      << PrettyPrint(func);
-  ICHECK(attrs.as<CallLoweredAttrs>())
-      << "Expected attributes to be CallLoweredAttrs, but got " << 
attrs->GetTypeKey();
-  return Call(CallLoweredOp(), {std::move(func), Tuple(std::move(inputs))}, 
std::move(attrs),
-              std::move(type_args), std::move(span));
+Call CallLowered(GlobalVar lowered_func, Array<Expr> args, CallLoweredAttrs 
call_lowered_attrs,
+                 Span span) {
+  auto attrs = make_object<CallLoweredAttrs>(std::move(call_lowered_attrs));

Review comment:
       What's the difference between make_object<CallLoweredAttrs> and just 
declaring a new CallLoweredAttrs?

##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -336,10 +337,9 @@ class VMFunctionCompiler : 
DeviceAwareExprFunctor<void(const Expr& n)> {
    * in emitted code. Note that the host device is always at index 0.
    */
   Index GetDeviceIndex(const SEScope& se_scope) {
-    VLOG(2) << "getting device index for " << se_scope;
+    ICHECK(!se_scope->IsFullyUnconstrained());

Review comment:
       This will be helpful for debugging, thanks!

##########
File path: src/relay/op/memory/on_device.cc
##########
@@ -45,53 +43,74 @@ const Op& OnDeviceOp() {
   return op;
 }
 
-Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed) {
-  ICHECK(!se_scope->IsFullyUnconstrained());
+Call OnDevice(Expr body, SEScope se_scope, bool constrain_result, bool 
constrain_body) {
+  ICHECK((!constrain_result && !constrain_body) || 
!se_scope->IsFullyUnconstrained());
   auto attrs = make_object<OnDeviceAttrs>();
-  attrs->se_scope = std::move(se_scope);
-  attrs->is_fixed = is_fixed;
-  Span span = expr->span;
-  return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), 
/*type_args=*/{},
+  attrs->se_scope =
+      (constrain_result || constrain_body) ? std::move(se_scope) : 
SEScope::FullyUnconstrained();
+  attrs->constrain_result = constrain_result;
+  attrs->constrain_body = constrain_body;
+  Span span = body->span;  // about to be moved

Review comment:
       Is there any difference between putting body->span in the constructor 
directly and what you've done here (creating a new variable and std::moving it 
into the constructor)? (Not asking for a change, just curious)

##########
File path: include/tvm/relay/attrs/on_device.h
##########
@@ -54,44 +54,48 @@ namespace relay {
  *   multiply(device_copy(add(%x, %y), src_se_scope=GPU, dst_se_scope=CPU), %z)
  * \endcode
  *
- * The Relay call
- * \code
- *   on_device(sub_expr, se_scope=S, is_fixed=True)
- * \endcode
- * is similar to the above, however the annotation itself must appear in an 
expression on the
- * same \p SEScope \p S. The compiler will check the \p SEScopes are 
consistent, and will not
- * insert any "device_copy" call. This form of annotation shouldn't be 
necessary in user programs.
- * However it is needed by the \p PlanDevices pass to fully specify the 
results of device planning
- * so that the pass is idempotent.
- *
- * E.g.: The following program is equivalent to the above:
- * \code
- *   let %a = on_device(add(%x, %y), se_scope=GPU, is_fixed=True)
- *   multiply(device_copy(%a, src_se_scope=GPU, dst_se_scope=CPU), %z)
- * \endcode
- * The "on_device" annotation with \p is_fixed=True indicates unambiguously 
that \p %a is stored
- * on the GPU.
+ * The \p constraint_body (default true) and \p constraint_result (default 
false) fields can be
+ * used by passes for finer-grained control over how the \p SEScope constraint 
should be applied.
  */
 struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
   /*!
-   * \brief (Virtual) \p SEScope on which the result of the argument 
expression should be stored.
+   * \brief The \p SEScope to constraint to apply to the body, result, or both 
body and result
+   * of the "on_device" call.
    */
   SEScope se_scope = SEScope::FullyUnconstrained();
+
+  /*!
+   * \brief If fales (the default), the result of the "on_device" call is not 
constrained to be

Review comment:
       fales => false

##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -844,7 +932,8 @@ class DeviceCapturer : public ExprMutator {
         // match.
         return VisitExpr(device_copy_props.body);
       } else {
-        return VisitChild(/*lexical_se_scope=*/dst_se_scope,
+        return VisitChild(/*lexical_se_scope=*/
+                          dst_se_scope,

Review comment:
       unneeded whitespace change

##########
File path: src/relay/op/memory/on_device.h
##########
@@ -39,14 +39,50 @@ namespace relay {
 const Op& OnDeviceOp();
 
 /*!
- * \brief Wraps \p expr in an "on_device" CallNode for \p se_scope and \p 
is_fixed.
+ * \brief Wraps \p body in an "on_device" CallNode for \p se_scope.
  *
  * See \p OnDeviceAttrs for an overview.
  */
-Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed);
+Call OnDevice(Expr body, SEScope se_scope, bool constrain_result = false,
+              bool constrain_body = true);
+
+/*! \brief Result of \p GetOnDeviceProps. */
+struct OnDeviceProps {
+  Expr body;  // = null
+  SEScope se_scope = SEScope::FullyUnconstrained();
+  bool constrain_result = false;
+  bool constrain_body = false;
+
+  OnDeviceProps() = default;
+
+  OnDeviceProps(Expr body, SEScope se_scope, bool constrain_result, bool 
constrain_body)
+      : body(std::move(body)),
+        se_scope(std::move(se_scope)),
+        constrain_result(constrain_result),
+        constrain_body(constrain_body) {}
+
+  bool is_fixed() const { return constrain_result && constrain_body; }
+  bool is_normal() const { return !constrain_result && constrain_body; }
+};
+
+/*!
+ * \brief As for OnDevice, but taking all fields other than \p body from \p 
props.
+ */

Review comment:
       Can you update this comment to something like:
   "Wraps body with OnDevice, but takes all fields other than body from props."

   (I found "As for OnDevice" confusing.)

##########
File path: src/relay/op/memory/on_device.cc
##########
@@ -45,53 +43,74 @@ const Op& OnDeviceOp() {
   return op;
 }
 
-Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed) {
-  ICHECK(!se_scope->IsFullyUnconstrained());
+Call OnDevice(Expr body, SEScope se_scope, bool constrain_result, bool 
constrain_body) {
+  ICHECK((!constrain_result && !constrain_body) || 
!se_scope->IsFullyUnconstrained());
   auto attrs = make_object<OnDeviceAttrs>();
-  attrs->se_scope = std::move(se_scope);
-  attrs->is_fixed = is_fixed;
-  Span span = expr->span;
-  return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), 
/*type_args=*/{},
+  attrs->se_scope =
+      (constrain_result || constrain_body) ? std::move(se_scope) : 
SEScope::FullyUnconstrained();
+  attrs->constrain_result = constrain_result;
+  attrs->constrain_body = constrain_body;
+  Span span = body->span;  // about to be moved
+  return Call(OnDeviceOp(), {std::move(body)}, Attrs(std::move(attrs)), 
/*type_args=*/{},
               std::move(span));
 }
 
 
TVM_REGISTER_GLOBAL("relay.op.annotation._make.OnDevice").set_body_typed(OnDevice);
 
-Expr MaybeOnDevice(Expr expr, SEScope se_scope, bool is_fixed) {
+Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result, bool 
constrain_body) {
   if (se_scope->IsFullyUnconstrained()) {
     // Nothing to annotate with.
-    return expr;
+    return body;
   }
-  if (expr->IsInstance<OpNode>() || expr->IsInstance<ConstructorNode>()) {
+  if (body->IsInstance<OpNode>() || body->IsInstance<ConstructorNode>()) {
     // These operators are device polymorphic so no annotation is required.
-    return expr;
+    return body;
   }
-  if (expr->IsInstance<GlobalVarNode>() || expr->IsInstance<VarNode>()) {
+  if (body->IsInstance<GlobalVarNode>() || body->IsInstance<VarNode>()) {
     // The device can be recovered from the binding site of the global or 
local variable.
-    return expr;
+    return body;
   }
-  if (expr->IsInstance<FunctionNode>()) {
+  if (body->IsInstance<FunctionNode>()) {
     // If a primitive function then it is device polymorphic. Otherwise the 
device is captured
     // by the function's "result_se_scope" attribute.
-    return expr;
+    return body;
   }
-  OnDeviceProps props = GetOnDeviceProps(expr);
+  OnDeviceProps props = GetOnDeviceProps(body);
   if (props.body.defined()) {
-    // Don't nest on_devices.
-    // If the inner and outer device types differ then we need to be careful:
-    //  - If the inner on_device is_fixed then it disagrees with the outer.
-    //  - If the outer on_device is_fixed then it implies a hidden device_copy
-    // Otherwise just use the inner device type and ignore the outer.
-    ICHECK(props.se_scope == se_scope || (!is_fixed && !props.is_fixed));
-    return OnDevice(props.body, se_scope, is_fixed || props.is_fixed);
+    // The user is asking for
+    //   on_device(on_device(body, se_scope=inner), se_scope=outer)
+    //   ^         ^         ^
+    //   outer     middle    inner
+    // First recover the implied constraints (if any) for outer and inner, and 
check they don't
+    // contradict.
+    const SEScope& inner = props.se_scope;
+    const SEScope& outer = se_scope;
+    bool constrain_outer = constrain_result;
+    bool constrain_inner = props.constrain_body;
+    if (constrain_outer && constrain_inner) {
+      ICHECK(inner == outer)
+          << "Cannot constrain result and body of nested on_device calls to 
different SEScopes";
+    }
+    // There are two possible ways the middle sub-expression may be 
constrained, check they don't
+    // contradict.
+    bool constrain_middle_via_outer = constrain_body;
+    bool constrain_middle_via_inner = props.constrain_result;
+    if (constrain_middle_via_outer && constrain_middle_via_inner) {
+      ICHECK(inner == outer)
+          << "Cannot constrain intermediate result of nested on_device calls 
to different SEScopes";
+    }
+    // We can now ignore the intermediate constraints, if any.
+    return OnDevice(props.body, (constrain_inner || constrain_outer) ? outer : 
inner,

Review comment:
       If the inner is constrained, we return outer?

##########
File path: src/relay/op/memory/on_device.h
##########
@@ -39,14 +39,50 @@ namespace relay {
 const Op& OnDeviceOp();
 
 /*!
- * \brief Wraps \p expr in an "on_device" CallNode for \p se_scope and \p 
is_fixed.
+ * \brief Wraps \p body in an "on_device" CallNode for \p se_scope.
  *
  * See \p OnDeviceAttrs for an overview.
  */
-Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed);
+Call OnDevice(Expr body, SEScope se_scope, bool constrain_result = false,
+              bool constrain_body = true);
+
+/*! \brief Result of \p GetOnDeviceProps. */
+struct OnDeviceProps {
+  Expr body;  // = null
+  SEScope se_scope = SEScope::FullyUnconstrained();
+  bool constrain_result = false;
+  bool constrain_body = false;
+
+  OnDeviceProps() = default;
+
+  OnDeviceProps(Expr body, SEScope se_scope, bool constrain_result, bool 
constrain_body)
+      : body(std::move(body)),
+        se_scope(std::move(se_scope)),
+        constrain_result(constrain_result),
+        constrain_body(constrain_body) {}
+
+  bool is_fixed() const { return constrain_result && constrain_body; }
+  bool is_normal() const { return !constrain_result && constrain_body; }
+};
+
+/*!
+ * \brief As for OnDevice, but taking all fields other than \p body from \p 
props.
+ */
+inline Call OnDeviceWithProps(Expr body, const OnDeviceProps& props) {
+  return OnDevice(std::move(body), props.se_scope, props.constrain_result, 
props.constrain_body);
+}
 
 /*!
- * \brief Wraps \p expr in an "on_device" CallNode for \p se_scope and \p 
is_fixed if the
+ * \brief As for OnDevice, but don't constrain the body or result to any 
particular virtual device.

Review comment:
       Again, "As for OnDevice" is a bit confusing.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to