This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6b20caee2d [Bugfix] [Relay] Insertion of "device_copy" CallNode to 
Resolve Device Conflict on Unconstrained Nodes (#15090)
6b20caee2d is described below

commit 6b20caee2d4098222f7c05a894c148b09e1df911
Author: lecoan <[email protected]>
AuthorDate: Thu Jun 22 08:33:16 2023 +0800

    [Bugfix] [Relay] Insertion of "device_copy" CallNode to Resolve Device 
Conflict on Unconstrained Nodes (#15090)
    
    * Fix: add a new subpass in PlanDevices to add device_copy op for 
conflicted inputs
    
    * Fix some spelling errors in comments
    
    * Fix some spelling errors in comments
---
 src/relay/transforms/device_planner.cc       | 227 ++++++++++++++++++++++++++-
 tests/python/relay/test_pass_plan_devices.py |  47 ++++++
 2 files changed, 268 insertions(+), 6 deletions(-)

diff --git a/src/relay/transforms/device_planner.cc 
b/src/relay/transforms/device_planner.cc
index c9050c730d..80ae66ea9e 100644
--- a/src/relay/transforms/device_planner.cc
+++ b/src/relay/transforms/device_planner.cc
@@ -60,7 +60,7 @@
  *    'result_virtual_device' function attributes we introduce below. This is 
so the pass is
  * idempotent and can be re-run to flow additional memory scope constraints.
  *
- * We proceed in four phases:
+ * We proceed in five phases:
  *
  * Phase 0
  * -------
@@ -77,6 +77,13 @@
  *
  * Phase 1
  * -------
+ * We iteratively process the programs and find nodes with conflicting virtual 
devices. If the
+ * virtual devices ( \p d1 and \p d2 ) are joinable, they are replaced with a 
joined device \p d. If
+ * they are unjoinable, a "device_copy" CallNode is inserted to copy the node 
output to the second
+ * device.
+ *
+ * Phase 2
+ * -------
  * We flow constraints from the "on_device" and "device_copy" calls, PrimFunc 
buffer memory scopes,
  * and some special ops, to all other Relay sub-expressions.
  *
@@ -109,7 +116,7 @@
  * devices from their original Relay Function representations. However we know 
all calls to those
  * functions are device-consistent, thus no information is lost.
  *
- * Phase 2
+ * Phase 3
  * -------
  * After flowing constraints we apply some defaulting heuristics (using a 
global default \p
  * VirtualDevice) to fix the device for any as-yet unconstrained 
sub-expressions.
@@ -121,7 +128,7 @@
  * This requires a formal notion of 'choicepoint' inside the compiler which 
can integrate with
  * automation.
  *
- * Phase 3
+ * Phase 4
  * -------
  * Finally, the result of this analysis is reified into the result as:
  *  - Additional "param_virtual_devices" (an \p Array<VirtualDevice>) and 
"result_virtual_device"
@@ -404,6 +411,201 @@ class RewriteOnDevices : public ExprMutator {
 
 /* =============== Phase 1 =============== */
 
+/*!
+ * \brief Add "device_copy" calls for nodes that have conflicting virtual 
devices.
+ *
+ * For example, suppose an IRModule contains the following expression:
+ * \code
+ *   %0 = add(%a, %b);
+ *   %1 = on_device(%0, virtual_device=d1);
+ *   %2 = add(%b, %c);
+ *   %3 = on_device(%2, virtual_device=d2);
+ * \endcode
+ * In the above example, node %b has two possible virtual devices: \p d1 and 
\p d2.
+ *
+ * - If \p d1 and \p d2 are joinable, replace \p d1 and \p d2 with the joined 
device \p d:
+ * \code
+ *   %0 = add(%a, %b);
+ *   %1 = on_device(%0, virtual_device=d);
+ *   %2 = add(%b, %c);
+ *   %3 = on_device(%2, virtual_device=d);
+ * \endcode
+ *
+ * - If \p d1 and \p d2 are unjoinable, insert a "device_copy" CallNode to 
copy \p %b to \p d2:
+ * \code
+ *   %0 = add(%a, %b);
+ *   %1 = on_device(%0, virtual_device=d1);
+ *   %2 = device_copy(%b, src_virtual_device=d1, dst_virtual_device=d2);
+ *   %3 = add(%2, %c);
+ *   %4 = on_device(%3, virtual_device=d2);
+ * \endcode
+ */
+struct DeviceContext {
+  VirtualDevice VirtualDeviceFor(const ExprNode* expr) {
+    auto itr = expr_to_device.find(expr);
+    if (itr != expr_to_device.end()) {
+      return itr->second;
+    }
+    auto default_dev = VirtualDevice::FullyUnconstrained();
+    expr_to_device.emplace(expr, default_dev);
+    return default_dev;
+  }
+
+  bool Update(const ExprNode* expr, VirtualDevice dev) {
+    bool success = true;
+    auto pair = expr_to_device.emplace(expr, dev);
+    if (!pair.second) {
+      auto replaced_item = pair.first;
+      auto joined_dev = VirtualDevice::Join(replaced_item->second, dev);
+      if (joined_dev == nullptr) {
+        success = false;
+      } else {
+        replaced_item->second = joined_dev.value();
+      }
+    }
+    return success;
+  }
+
+  bool IsConflicted(const ExprNode* expr) {
+    auto itr = conflicted_nodes.find(expr);
+    return itr != conflicted_nodes.end();
+  }
+
+  std::unordered_set<const ExprNode*> conflicted_nodes;
+  std::unordered_map<const ExprNode*, VirtualDevice> expr_to_device;
+};
+
+/*!
+ * \brief Flow the device constraints over the module and find all the 
conflicted nodes. The
+ * conflicted nodes only contain nodes that have no explicit constraints. For 
example, "on_device"
+ * nodes are not considered as conflicted.
+ */
+class ConflictedNodeFinder : ExprVisitor {
+ public:
+  explicit ConflictedNodeFinder(IRModule mod)
+      : mod_(std::move(mod)), dev_ctx_(std::make_unique<DeviceContext>()) {}
+
+  std::unique_ptr<DeviceContext> Finder() {
+    VLOG_CONTEXT << "ConflictedNodeFinder";
+    for (const auto& kv : mod_->functions) {
+      if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
+        VisitExpr(GetRef<Function>(function_node));
+      }
+    }
+    for (auto const node : dev_ctx_->conflicted_nodes) {
+      if (node->IsInstance<CallNode>()) {
+        auto call = Downcast<Call>(GetRef<Expr>(node));
+        // "DeviceCapturer" will insert "device_copy" for "on_device" calls.
+        // Therefore, "on_device" should not be considered as conflicted.
+        if (call->op == OnDeviceOp()) {
+          dev_ctx_->conflicted_nodes.erase(node);
+        }
+      }
+    }
+    return std::move(dev_ctx_);
+  }
+
+ private:
+  void VisitExpr_(const CallNode* call_node) final {
+    VLOG(2) << "Initial call node: " << std::endl << 
PrettyPrint(GetRef<Call>(call_node));
+    auto call_dev = dev_ctx_->VirtualDeviceFor(call_node);
+    auto body_dev = call_dev;
+
+    auto on_dev_props = GetOnDeviceProps(call_node);
+    auto dev_cp_props = GetDeviceCopyProps(call_node);
+    if (call_node->op == OnDeviceOp()) {
+      if (on_dev_props.constrain_body) {
+        body_dev = on_dev_props.virtual_device;
+      }
+      if (on_dev_props.constrain_result) {
+        call_dev = on_dev_props.virtual_device;
+      }
+    } else if (call_node->op == DeviceCopyOp()) {
+      body_dev = dev_cp_props.src_virtual_device;
+      call_dev = dev_cp_props.dst_virtual_device;
+    }
+
+    if (!dev_ctx_->Update(call_node, call_dev) && call_node->op != 
OnDeviceOp()) {
+      LOG(FATAL) << "Mismatched device type after iterating args. Implied 
device: " << std::endl
+                 << PrettyPrint(call_dev) << "and practial device:" << 
std::endl
+                 << PrettyPrint(dev_ctx_->VirtualDeviceFor(call_node)) << 
std::endl
+                 << "With CallNode: " << std::endl
+                 << PrettyPrint(GetRef<Call>(call_node));
+    }
+
+    for (auto& arg : call_node->args) {
+      VLOG(3) << "Handle call node arg: " << std::endl << PrettyPrint(arg);
+      if (!dev_ctx_->Update(arg.get(), body_dev)) {
+        VLOG(2) << "Conflicted node found:" << std::endl
+                << PrettyPrint(GetRef<Expr>(arg.get())) << std::endl
+                << "With corresponding Callee:" << std::endl
+                << PrettyPrint(GetRef<Call>(call_node));
+        dev_ctx_->conflicted_nodes.emplace(arg.get());
+      }
+    }
+    for (auto& expr : call_node->args) {
+      VisitExpr(expr);
+    }
+  }
+
+  IRModule mod_;
+  std::unique_ptr<DeviceContext> dev_ctx_;
+};
+
+/*!
+ * \brief Insert "device_copy" CallNode for all the conflicted nodes found by 
\p
+ * ConflictedNodeFinder.
+ */
+class ConflictedNodeRewriter : ExprMutator {
+ public:
+  ConflictedNodeRewriter(IRModule mod, CompilationConfig config,
+                         std::unique_ptr<DeviceContext> dev_ctx)
+      : mod_(mod), config_(config), dev_ctx_(std::move(dev_ctx)) {}
+
+  IRModule Rewrite() {
+    VLOG_CONTEXT << "ConflictedNodeRewriter";
+    IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), 
mod_->source_map,
+                    mod_->attrs);
+    for (const auto& kv : mod_->functions) {
+      if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
+        auto func = Mutate(GetRef<Function>(function_node));
+        result->Add(kv.first, Downcast<Function>(func));
+      } else {
+        result->Add(kv.first, kv.second);
+      }
+    }
+
+    return result;
+  }
+
+ private:
+  Expr VisitExpr_(const CallNode* call_node) final {
+    VLOG(3) << "Initial call node:" << std::endl << 
PrettyPrint(GetRef<Call>(call_node));
+    auto call = Downcast<Call>(ExprMutator::VisitExpr_(call_node));
+    tvm::Array<Expr> call_args;
+    call_args.reserve(call_node->args.size());
+    for (auto arg : call->args) {
+      if (dev_ctx_->IsConflicted(arg.get())) {
+        auto src_dev = 
config_->CanonicalVirtualDevice(dev_ctx_->VirtualDeviceFor(arg.get()));
+        auto dst_dev = 
config_->CanonicalVirtualDevice(dev_ctx_->VirtualDeviceFor(call_node));
+        call_args.push_back(MaybeDeviceCopy(arg, src_dev, dst_dev));
+        VLOG(2) << "Adding DeviceCopy Op: " << std::endl << 
PrettyPrint(call_args.back());
+      } else {
+        call_args.push_back(arg);
+      }
+    }
+    auto new_call = WithFields(GetRef<Call>(call_node), call_node->op, 
call_args);
+    VLOG(3) << "Final call node:" << std::endl << 
PrettyPrint(GetRef<Call>(call_node));
+    return new_call;
+  }
+
+  IRModule mod_;
+  CompilationConfig config_;
+  std::unique_ptr<DeviceContext> dev_ctx_;
+};
+
+/* =============== Phase 2 =============== */
+
 /*
  * \brief Collects the system of device constraints for all sub-expressions in 
a module.
  * It is possible some devices remain free and will need to be defaulted by \p 
DeviceDefaulter.
@@ -707,7 +909,7 @@ class DeviceAnalyzer : public MixedModeVisitor {
   std::unique_ptr<DeviceDomains> domains_;
 };
 
-/* =============== Phase 2 =============== */
+/* =============== Phase 3 =============== */
 
 /*!
  * \brief Calls to 'free' "on_device" annotations (ie where both 
constrain_body=false and
@@ -865,7 +1067,7 @@ class DeviceDefaulter : public ExprVisitor {
   std::unique_ptr<DeviceDomains> domains_;
 };
 
-/* =============== Phase 3 =============== */
+/* =============== Phase 4 =============== */
 /*!
  * \brief Inserts missing "device_copy" CallNodes, and ensures the device type 
of every
  * sub-expression in a module can be easily recovered by a later 
transformation using simple
@@ -1276,6 +1478,17 @@ tvm::transform::Pass Rewrite() {
   return tvm::relay::transform::CreateFunctionPass(pass_func, 0, 
"PlanDevicesRewrite", {});
 }
 
+/*! \brief Check the conflicted nodes and add "device_copy" calls. */
+tvm::transform::Pass Check(CompilationConfig config) {
+  return tvm::transform::CreateModulePass(
+      [config = std::move(config)](IRModule mod,
+                                   tvm::transform::PassContext pass_cnxt) -> 
IRModule {
+        auto dev_ctx = ConflictedNodeFinder(mod).Finder();
+        return ConflictedNodeRewriter(mod, config, 
std::move(dev_ctx)).Rewrite();
+      },
+      /*opt_level=*/0, "PlanDevicesCheckConflicts", {});
+}
+
 /*! \brief Run the remaining phases. */
 tvm::transform::Pass PlanDevicesCore(CompilationConfig config) {
   return tvm::transform::CreateModulePass(
@@ -1308,7 +1521,9 @@ tvm::transform::Pass PlanDevicesCore(CompilationConfig 
config) {
 tvm::transform::Pass PlanDevices(CompilationConfig config) {
   std::vector<Pass> passes;
   passes.emplace_back(Rewrite());
-  passes.emplace_back(PlanDevicesCore(std::move(config)));
+  passes.emplace_back(Check(config));
+  passes.emplace_back(InferType());
+  passes.emplace_back(PlanDevicesCore(config));
   return tvm::transform::Sequential(passes, "PlanDevices");
 }
 
diff --git a/tests/python/relay/test_pass_plan_devices.py 
b/tests/python/relay/test_pass_plan_devices.py
index 3ff49389cb..937ece1f82 100644
--- a/tests/python/relay/test_pass_plan_devices.py
+++ b/tests/python/relay/test_pass_plan_devices.py
@@ -1830,5 +1830,52 @@ def test_primitive():
     print(mod)
 
 
+def test_conflicated_inputs():
+    metatable = {"VirtualDevice": [CPU, GPU]}
+
+    def input():
+        return tvm.relay.parse(
+            """
+            #[version = "0.0.5"]
+            def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
+                        %c: Tensor[(5, 7), float32]) {
+                %0 = add(%a, %b);
+                %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
+                %2 = add(%b, %c);
+                %3 = on_device(%2, virtual_device=meta[VirtualDevice][1]);
+                subtract(%1, %3)
+            }
+            """,
+            "from_string",
+            None,
+            metatable,
+        )
+
+    def expected():
+        return tvm.relay.parse(
+            """
+            #[version = "0.0.5"]
+            def @main(%a {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 
7), float32],
+                        %b {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 
7), float32],
+                        %c {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 
7), float32]) {
+                %0 = add(%a, %b);
+                %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], 
constrain_result=True);
+                %2 = device_copy(%b, 
src_virtual_device=meta[VirtualDevice][0], 
dst_virtual_device=meta[VirtualDevice][1]);
+                %3 = device_copy(%1, 
src_virtual_device=meta[VirtualDevice][0], 
dst_virtual_device=meta[VirtualDevice][1]);
+                %4 = add(%2, %c);
+                subtract(%3, %4)
+            }
+            """,
+            "from_string",
+            None,
+            metatable,
+        )
+
+    def ref(a, b, c):
+        return np.subtract(np.add(a, b), np.add(b, c))
+
+    exercise(input(), expected(), ref, rands((5, 7), 3))
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to