mbs-octoml commented on code in PR #11173:
URL: https://github.com/apache/tvm/pull/11173#discussion_r865012034


##########
src/target/compilation_config.cc:
##########
@@ -53,194 +73,179 @@ VirtualDevice 
CompilationConfigNode::CanonicalVirtualDevice(
                                                     target, 
virtual_device->memory_scope));
 }
 
-void CompilationConfigNode::EstablishDefaultVirtualDevices(const 
transform::PassContext& pass_ctx) {
+void CompilationConfigNode::Init(const transform::PassContext& pass_ctx,
+                                 const Array<Target>& raw_targets) {
+  VLOG_CONTEXT << "CompilationConfig";
+  CHECK_GT(raw_targets.size(), 0U) << "Require at least one target";
+
   //
-  // Gather the hints as to what our default device type for the 'host' should 
be, and
-  // create an appropriate target if we don't already have one.
+  // Decide on the host target.
   //
-  DLDeviceType host_device_type;
-  if (host_target.defined()) {
-    CHECK(!host_target->host.defined()) << "Host targets are not expected to 
have hosts";
-    host_device_type = 
static_cast<DLDeviceType>(host_target->kind->device_type);
-    VLOG(1) << "Using the given host target " << host_target->ToDebugString() 
<< " of device type "
-            << host_device_type << " for the host target";
-    for (const auto& primitive_target : primitive_targets) {
-      if (primitive_target->host.defined() &&
-          !StructuralEqual()(primitive_target->host, host_target)) {
-        VLOG(1) << "The primitive target " << primitive_target->ToDebugString()
-                << " already has a host which disagrees with the desired host 
target. It "
-                << "will be ignored.";
-      }
-    }
-  } else if (primitive_targets.size() == 1 && 
primitive_targets.front()->host.defined()) {
-    host_target = primitive_targets.front()->GetHost().value();
-    CHECK(!host_target->host.defined()) << "Host targets are not expected to 
have hosts";
-    host_device_type = 
static_cast<DLDeviceType>(host_target->kind->device_type);
-    VLOG(1) << "Using the host of the unique primitive target, namely "
-            << host_target->ToDebugString() << " of device type " << 
host_device_type
-            << " for the host target";
-  } else if (primitive_targets.size() == 1 &&
-             primitive_targets.front()->kind->device_type == kDLCPU) {
-    // In the homogenous case without an explicit host target just use the 
given target so long as
-    // it's a CPU.
-    host_device_type = kDLCPU;
-    host_target = primitive_targets.front();
-    VLOG(1) << "Using the unique primitive target " << 
host_target->ToDebugString()
-            << " of device type " << host_device_type << " for the host 
target";
+
+  // Any CPU-like targets?
+  auto cpu_itr = std::find_if(raw_targets.begin(), raw_targets.end(), [](const 
Target& target) {
+    // TODO(tvm-team): AoT only works with kDLCPU device type. We can remove 
kDLHexagon
+    // here once we refactored kDLHexagon to kDLCPU.
+    return target->kind->device_type == kDLCPU || target->kind->device_type == 
kDLHexagon;
+  });
+
+  // Any targets with a host?
+  auto has_host_itr = std::find_if(raw_targets.begin(), raw_targets.end(),
+                                   [](const Target& target) { return 
target->host.defined(); });
+
+  if (has_host_itr != raw_targets.end()) {
+    // RULE A: If any raw target has a host, use the first such host for all 
the primitive
+    // targets.
+    host_target = Target((*has_host_itr)->GetHost().value(), 
/*host=*/Target());
+    VLOG(1) << "The target " << (*has_host_itr)->ToDebugString() << " supplies 
a host target "
+            << host_target->ToDebugString() << " of device type " << 
host_target->kind->device_type;
+  } else if (cpu_itr != raw_targets.end()) {
+    // RULE B: If any raw target is for a CPU-like device then also use that 
as the host.
+    host_target = Target(*cpu_itr, /*host=*/Target());
+    VLOG(1) << "Using target " << host_target->ToDebugString() << " of 
CPU-like device type "
+            << host_target->kind->device_type << " as the host target";
   } else {
-    // Fallback.
-    host_device_type = kDLCPU;
-    // Even if the list of available targets already includes one for kDLCPU 
we won't use it
-    // in the hetrogeneous case since its options may not be appropriate for 
host code
-    // (eg shape functions). Instead, create a fresh default Target.
-    host_target = MakeDefaultTarget(host_device_type);
-    VLOG(1) << "Using the default target " << host_target->ToDebugString() << 
" of device type "
-            << host_device_type << " for the host target";
+    // RULE C: Otherwise, create a default CPU host target.
+    host_target = MakeDefaultCPUTarget();
+    VLOG(1) << "Created a default target " << host_target->ToDebugString() << 
" of device type "
+            << host_target->kind->device_type << " for the host target";
   }
   ICHECK(host_target.defined());
   ICHECK(!host_target->host.defined());
 
-  if (host_device_type != kDLCPU) {
-    // I think we're on thin ice here until we've audited the code base for 
assumed kDLCPU.
-    VLOG(1) << "The host target is not a CPU.";
+  if (host_target->kind->device_type != kDLCPU) {
+    // I think we're on thin ice here until we've audited the code base for 
assumed CPU hosts.
+    VLOG(1) << "The host target is not a CPU. This is probably not going to 
work.";
   }
 
   //
   // Establish the host VirtualDevice.
   //
-  host_virtual_device =
-      virtual_device_cache_.Unique(VirtualDevice(host_device_type,
-                                                 /*virtual_device_id=*/0, 
host_target));
+  host_virtual_device = virtual_device_cache_.Unique(
+      VirtualDevice(static_cast<DLDeviceType>(host_target->kind->device_type),
+                    /*virtual_device_id=*/0, host_target));
+  ICHECK(host_virtual_device.defined());
+  ICHECK(host_virtual_device->target.defined());
 
   //
-  // Now that we've settled on a host, we can set it as the host on all 
primitive targets.
+  // Now that we've settled on a host, we can set it as the host on all the 
raw targets.
   //
-  Array<Target> new_primitve_targets;
-  new_primitve_targets.reserve(primitive_targets.size());
-  for (const auto& primitive_target : primitive_targets) {
-    new_primitve_targets.push_back(Target(primitive_target, host_target));
+  primitive_targets.clear();
+  primitive_targets.reserve(raw_targets.size());
+  for (const auto& raw_target : raw_targets) {
+    if (raw_target->host.defined() && !StructuralEqual()(raw_target->host, 
host_target)) {
+      VLOG(1) << "The target " << raw_target->ToDebugString()
+              << " already has a host which disagrees with the desired host 
target. It "
+              << "will be overridden.";
+    }
+    primitive_targets.push_back(Target(raw_target, host_target));
   }
-  primitive_targets = new_primitve_targets;
+  ICHECK_GT(primitive_targets.size(), 0U);
 
   //
-  // Gather the hints as to what our default device type for primitives should 
be.
+  // Check the primitive_targets are ordered correctly re 
Target::IsExternalCodegenFor.
+  //
+  std::unordered_set<DLDeviceType> primitive_target_device_types;
+  for (const auto& target : primitive_targets) {
+    
primitive_target_device_types.emplace(static_cast<DLDeviceType>(target->kind->device_type));
+  }
+  for (DLDeviceType device_type : primitive_target_device_types) {
+    Target first_primitive_target;
+    for (const auto& current_primitive_target : primitive_targets) {
+      if (current_primitive_target->kind->device_type != device_type) {
+        continue;
+      }
+      if (!first_primitive_target.defined()) {
+        first_primitive_target = current_primitive_target;
+        continue;
+      }
+      
CHECK(current_primitive_target.IsExternalCodegenFor(first_primitive_target))

Review Comment:
   I think that should be illegal, so I'll add a check. Thanks!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to