Mousius commented on a change in pull request #9077:
URL: https://github.com/apache/tvm/pull/9077#discussion_r714807766



##########
File path: include/tvm/runtime/ndarray.h
##########
@@ -38,9 +38,19 @@
 #include <vector>
 
 namespace tvm {
-namespace runtime {
 
-typedef DLDevice Device;
+// alias DLDevice
+using Device = DLDevice;
+
+// A 'null' device type, does not correspond to any DLDeviceType enum.
+// TODO(mbs): This is to help us as we transition away from representing the 
'homogenous' case
+// as a singleton target map indexed by the invalid DLDeviceType '0'.
+constexpr DLDeviceType kNullDeviceType = static_cast<DLDeviceType>(0);
+
+// An 'invalid' device type, does not correspond to any DLDeviceType enum.

Review comment:
       Should these not be added as valid `enum` variants to `dlpack.h`? 
Otherwise there's a risk that dlpack will change and these values will become incompatible with it.

##########
File path: src/relay/op/memory/device_copy.h
##########
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/attrs/device_copy.h

Review comment:
       ```suggestion
    * \file relay/op/memory/device_copy.h
   ```

##########
File path: python/tvm/relay/op/annotation/annotation.py
##########
@@ -33,21 +43,26 @@ def on_device(data, device):
     device : Union[:py:class:`Device`, str]
         The device type to annotate.
 
+    is_fixed : bool
+        If true, annotation does not imply a device_copy may be inserted.
+        (This parameter is used internally by the compiler and unit tests and
+        should not need to be set in user programs.)
+
     Returns
     -------
     result : tvm.relay.Expr
         The annotated expression.
     """
-    if isinstance(device, _Device):
-        device = device.device_type
-    elif isinstance(device, str):
-        device = _nd.device(device).device_type
-    else:
-        raise ValueError(
-            "device is expected to be the type of Device or "
-            "str, but received %s" % (type(device))
-        )
-    return _make.on_device(data, device)
+    return _make.on_device(data, _device_to_int(device), is_fixed)
+
+
+# for testing only

Review comment:
       If this is for testing, can it not be moved to the relevant test file?

##########
File path: include/tvm/relay/attrs/annotation.h
##########
@@ -32,14 +32,55 @@ namespace tvm {
 namespace relay {
 
 /*!
- * \brief Options for the device annotation operators.
+ * \brief Attributes for the "on_device" operator.
+ *
+ * The relay call
+ * \code
+ *   on_device(expr, device_type=2)
+ * \endcode
+ * denotes that the result of \p expr should be stored on the device with \p 
DLDeviceType 2
+ * (i.e. \p kDLCuda). Semantically the operator is the identity function.
  */
 struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
+  // TODO(mbs): Replace device types with TargetDevice.
+  /*! \brief Device type on which argument expression should be evaluated. */
   int device_type;
+  /*!
+   * \brief If true, the result device must also be \p device_type and device 
planning should
+   * not insert any "device_copy" calls to respect this annotation.
+   *
+   * This is used by the device planning pass itself when annotating the 
planned program.
+   */
+  bool is_fixed;
 
   TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
     TVM_ATTR_FIELD(device_type)
-        .describe("The virutal device/context type that an expression is 
annotated with.")
+        .describe("The type of the virtual device which should hold the 
expression result.")
+        .set_default(0);
+    TVM_ATTR_FIELD(is_fixed)
+        .describe("If true, do not insert a \"device_copy\" call to respect 
this annotation.")
+        .set_default(false);
+  }
+};
+
+/*!
+ * \brief Attributes for Relay function definitions which capture the devices 
for the
+ * function parameters and result.
+ */
+struct FunctionOnDeviceAttrs : public tvm::AttrsNode<FunctionOnDeviceAttrs> {
+  constexpr static const char* kFunctionAttrsKey = "on_device";
+
+  /*! \brief Device type on which each of the function's arguments already 
resides. */
+  Array<Integer> param_device_types;
+  // TODO(mbs): Replace device types with TargetDevice.

Review comment:
       Do we need these `TODO`s? I'm assuming you're also tracking this work 
elsewhere.

##########
File path: src/relay/op/annotation/annotation.cc
##########
@@ -36,15 +38,40 @@
 namespace tvm {
 namespace relay {
 
-// relay.annotation.on_device
 TVM_REGISTER_NODE_TYPE(OnDeviceAttrs);
 
+const Op& OnDeviceOp() {
+  static const Op& op = Op::Get("on_device");
+  return op;
+}
+
+Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) {
+  auto attrs = make_object<OnDeviceAttrs>();
+  attrs->device_type = device_type;
+  attrs->is_fixed = is_fixed;
+  Span span = expr->span;
+  return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), 
/*type_args=*/{}, span);
+}
+
+Expr OptOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) {
+  if (device_type == kInvalidDeviceType) {
+    return expr;
+  }
+  if (expr->IsInstance<OpNode>() || expr->IsInstance<GlobalVarNode>() ||
+      expr->IsInstance<VarNode>() || expr->IsInstance<ConstructorNode>()) {

Review comment:
       It may be worth documenting this condition in code:
   
   ```suggestion
     bool is_device_polymorphic = expr->IsInstance<OpNode>() || 
expr->IsInstance<ConstructorNode>();
     bool has_implied_device = expr->IsInstance<VarNode>() || 
expr->IsInstance<GlobalVarNode>();
     if (is_device_polymorphic || has_implied_device) {
   ```

##########
File path: include/tvm/relay/attrs/annotation.h
##########
@@ -32,14 +32,55 @@ namespace tvm {
 namespace relay {
 
 /*!
- * \brief Options for the device annotation operators.
+ * \brief Attributes for the "on_device" operator.
+ *
+ * The relay call
+ * \code
+ *   on_device(expr, device_type=2)
+ * \endcode
+ * denotes that the result of \p expr should be stored on the device with \p 
DLDeviceType 2
+ * (i.e. \p kDLCuda). Semantically the operator is the identity function.
  */
 struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
+  // TODO(mbs): Replace device types with TargetDevice.
+  /*! \brief Device type on which argument expression should be evaluated. */
   int device_type;
+  /*!
+   * \brief If true, the result device must also be \p device_type and device 
planning should
+   * not insert any "device_copy" calls to respect this annotation.
+   *
+   * This is used by the device planning pass itself when annotating the 
planned program.
+   */
+  bool is_fixed;
 
   TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
     TVM_ATTR_FIELD(device_type)
-        .describe("The virutal device/context type that an expression is 
annotated with.")
+        .describe("The type of the virtual device which should hold the 
expression result.")
+        .set_default(0);
+    TVM_ATTR_FIELD(is_fixed)
+        .describe("If true, do not insert a \"device_copy\" call to respect 
this annotation.")
+        .set_default(false);
+  }
+};
+
+/*!
+ * \brief Attributes for Relay function definitions which capture the devices 
for the
+ * function parameters and result.
+ */
+struct FunctionOnDeviceAttrs : public tvm::AttrsNode<FunctionOnDeviceAttrs> {
+  constexpr static const char* kFunctionAttrsKey = "on_device";

Review comment:
       Historically, these would go into `include/tvm/ir/function.h` as the list 
of attributes on a function. Is there a reason for them being here instead?

##########
File path: python/tvm/relay/op/annotation/annotation.py
##########
@@ -22,7 +22,17 @@
 from .. import op as reg
 
 
-def on_device(data, device):
+def _device_to_int(device):
+    if isinstance(device, _Device):
+        return device.device_type
+    if isinstance(device, str):
+        return _nd.device(device).device_type
+    raise ValueError(

Review comment:
       I think this is just refactored from the code below, but do we have a test 
for `ValueError` being raised?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to