mbs-octoml commented on a change in pull request #9038:
URL: https://github.com/apache/tvm/pull/9038#discussion_r717948406



##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -0,0 +1,1986 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/device_planner.cc
+ * \brief Determines a unique device to hold the result of every Relay 
sub-expression.
+ *
+ * We say a Relay expression E is 'on device D' if the result of executing E 
is stored on D.
+ * Currently we only track the 'device_type' of D and not its 'device id'. We 
do not track the
+ * specific target associated with D (this is recovered independently via a 
TargetMap), and we
+ * do not track the storage scope within D (this is yet to be implemented).
+ *
+ * Note that 'stored on device D' is almost but not quite the same as 
'executes on device D',
+ * see below.
+ *
+ * This pass assumes the module already contains some "on_device" and/or 
"device_copy" CallNodes:
+ *  - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 
'src_dev_type' and
+ *    'dst_dev_type' device type, which constrain the argument and context of 
the call
+ *     respectively. It is ok if source and destination devices are the same; such no-op copies
+ *     will be removed after accounting for the device preference.
+ *  - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 
'device_type', which
+ *    constrains the argument of the call, but (usually, see below) leaves the 
context
+ *    unconstrained. These are called 'annotations' in the rest of the code, 
have no operational
+ *    significance by themselves, but may trigger the insertion of a new 
"device_copy".
+ *  - In two situations the result of an "on_device" CallNode may also be 
constrained to the
+ *    given device:
+ *     - The "on_device" call occurs at the top-level of a function body, or 
occurs as an
+ *       immediately let-bound expression. In this situation the extra degree 
of freedom in
+ *       the function result and let-binding leads to surprising device 
copies, so we simply
+ *       force the function result or let-bound variable to the given device.
+ *     - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which 
indicates we inserted
+ *       it ourselves during an earlier invocation of this pass. This helps 
make this pass
+ *       idempotent.
+ *
+ * We proceed in four phases:
+ *
+ * Phase 0
+ * -------
+ * We rewrite the program to handle some special cases:
+ *  - "on_device" calls at the top-level of a function or immediately let-bound are rewritten
+ *    to have \code is_fixed=true \endcode.
+ *  - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written
+ *    \code on_device(expr.0, device_type=d) \endcode. I.e. we prefer to copy the projection from
+ *    the tuple rather than project from a copy of the tuple. We'll do this by rewriting.
+ *
+ * Phase 1
+ * -------
+ * We flow constraints from the "on_device" and "device_copy" calls (and some 
special ops, see
+ * below) to all other Relay sub-expressions. (For idempotence we also respect 
any existing
+ * "on_device" function attributes we introduce below.)
+ *
+ * For a primitive such as \code add(e1, e2) \endcode all arguments and 
results must be on the
+ * same device. However each call site can use a different device. In other 
words primitives are
+ * 'device polymorphic' since we compile and execute them for each required 
device.
+ *
+ * For most Relay expressions the device for the overall expression is the same as the device
+ * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple
+ * itself, the condition and arms of an \p if must all be on the same device as the overall if,
+ * and so on.
+ *
+ * Some special ops (or 'dialects') are handled:
+ *  - Relay supports computing the shape of tensors and operators at runtime 
using "shape_of",
+ *    "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, 
but the tensors
+ *    they describe may reside on any device.
+ *  - Explicit memory allocation is done using "alloc_storage" and "alloc_tensor". Again
+ *    shapes reside on the CPU, but the allocated tensors may reside on any device.
+ *
+ * Two Relay expressions have special handling:
+ *  - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the
+ *    overall let. However the result of \p e1 may be on a different device.
+ *  - For a function \code fn(x, y) { body } \endcode the result of the function must be on the
+ *    same device as \p body. However parameters \p x and \p y may be on different devices, even
+ *    different from each other. Every call to the function must use the same choice of parameter
+ *    and result devices -- there is no 'device polymorphism' for Relay functions.
+ *
+ * Phase 2
+ * -------
+ * After flowing constraints we apply some defaulting heuristics (using a 
global default device)
+ * to fix the device for any as-yet unconstrained sub-expressions.
+ *  - Unconstrained function result devices default to the global default device.
+ *  - Unconstrained function parameter devices default to the device for the function result.
+ *  - Unconstrained let-bound expression devices default to the device for the overall let.
+ * TODO(mbs): I may have over-innovated here and we simply want to bind all free domains to
+ * the global default device. Worth a design doc with motivating examples I think.
+ *
+ * Phase 3
+ * -------
+ * Finally, the result of this analysis is reified into the result as:
+ *  - Additional "on_device" attributes (an Attrs resolving to a \p 
FunctionOnDeviceAttrs) for
+ *    every function (both top-level and local). These describe the devices 
for the function's
+ *    parameters and the result.
+ *  - Additional "device_copy" CallNodes where a copy is required in order to 
respect the
+ *    intent of the original "on_device" CallNodes.
+ *  - Additional "on_device" CallNodes where the device type of an expression 
does not match
+ *    that of the lexically enclosing "on_device" CallNode or function 
attribute. In practice
+ *    this means "on_device" CallNodes may appear in two places:
+ *     - On a let-bound expression if its device differs from the overall let 
expression.
+ *     - On a call argument if its device differs from the call result. In 
particular, the
+ *       argument to a "device_copy" call will always be wrapped in an 
"on_device". (That may
+ *       seem pedantic but simplifies downstream handling.)
+ *    However since we make it easy to track devices for variables we never 
wrap an "on_device"
+ *    around a var or global var. These uses of "on_device" imply both the 
argument and result are
+ *    on the same device. We signal this by setting the 'is_fixed' 
OnDeviceAttrs field to true,
+ *    which helps make this pass idempotent.
+ *
+ * A helper \p LexicalOnDeviceMixin class can be used by downstream transforms 
to recover the device
+ * for any expression for their own use, e.g. during memory planning. All 
downstream passes must
+ * preserve the lexical scoping of the "on_device" CallNodes. In particular 
conversion to ANF
+ * must respect the lexical scoping convention:
+ * \code
+ * f(on_device(g(h(a, b), c), device_type=CPU))
+ * ==>
+ * let %x0 = on_device(h(a, b), device_type=CPU)
+ * let %x1 = on_device(g(%x0, c), device_type=CPU)
+ * f(on_device(%x1, device_type=CPU))
+ * \endcode
+ *
+ * This pass should be run before FuseOps so that the latter can use device-specific fusion rules.
+ *
+ * 'Stored on' vs 'Executes on'
+ * ----------------------------
+ * Obviously for a primitive call \code add(x, y) \endcode we can execute the 
primitive on the
+ * same device as will hold its result. Thus 'executes on' is the same as 
'stored on' for
+ * primitives.
+ *
+ * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are
+ * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific
+ * devices, thus the notion of 'executes on' is moot. AOT backends on the other hand need to
+ * know exactly which device (possibly one of a number of available 'CPU'-like 
devices) is
+ * responsible for execution. Currently that's handled independently by the \p 
AnnotateTargets
+ * pass, but we'd like to fold that into device planning here to ensure 
everything is consistent.
+ *
+ * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay
+ * expression (e.g. an if expression) on one device even though the tensor data resides on
+ * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on'
+ * to be 'stored on' exactly. In particular, for a Relay function, we'd like 
to be able to just
+ * compile the function body for the function's result device.
+ *
+ * This works after conversion to ANF provided the compilation for a let 
expression is prepared
+ * to make a cross-device call. However we leave it to a downstream 
transformation to heuristically
+ * minimize cross-device calls by moving device copies out of functions. E.g.:
+ * \code
+ *   def @f() {  // execute on CPU
+ *     let x = on_device(...GPU computation..., device_type=GPU);
+ *     device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU)
+ *   }
+ *   def @main() {
+ *     ... call @f() on CPU ...
+ *   }
+ * \endcode
+ * could be rewritten to:
+ * \code
+ *   def @f() {  // execute on GPU
+ *     let x = ...GPU computation...;
+ *     ...GPU computation...
+ *   }
+ *   def @main() {
+ *     let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU)
+ *     ... use x on CPU ...
+ *   }
+ * \endcode
+ *
+ * Higher-order shenanigans
+ * ------------------------
+ * Relay is a 'mostly' higher-order language -- we can let-bind functions, 
pass functions
+ * as arguments (even anonymous functions), return functions, evaluate 
conditional expressions
+ * over functions, and so on. We handle this during constraint solving using 
the domain:
+ * \code
+ *   D  ::= <specific device type>   -- first-order
+ *        | fn(D,...,D):D            -- higher-order
+ * \endcode
+ * In this way we can determine the device for all function parameters and 
results. E.g. for
+ * \code
+ *   let f = fn(x, y) { ... }
+ *   let g = fn(f, z) { f(z, z) }
+ *   g(f, on_device(..., device_type=CPU))
+ * \endcode
+ * the parameters \p x and \p y will be on the CPU.
+ *
+ * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 
must evaluate to a
+ * function. Our analysis must guarantee that the function's parameters and 
result devices are
+ * consistent for \p e2, \p e3, and the context of the call. But:
+ *  - Which device holds the closure result of evaluating \p e1 ?
+ *  - If \p e2 is of function type, what does that mean when we say every 
function parameter
+ *    is on a device?
+ *  - If \p e1 returns a function, what does that mean when we say every 
function result is
+ *    on a device?
+ *
+ * Since higher-order aspects are later compiled away (by 'defunctionalization'
+ * aka 'firstification') we'd prefer not to have to answer any of those 
questions. In particular,
+ * we really don't want our domain \p D to allow for yet another device for 
the function closure.
+ * So we'll just force the 'device for a function' to be the same as the 
device for the function's
+ * result using the notion of the 'result domain' for a domain:
+ * \code
+ *   result_domain(<specific device type>) = <specific device type>
+ *   result_domain(fn(D1,...,Dn):Dr)       = result_domain(Dr)
+ * \endcode
+ *
+ * Similarly the domain does not have entries for tuples, references, or ADTs. 
Whenever the
+ * analysis encounters a function inside one of those it simply forces all 
argument and result
+ * devices for the function to match the device for the first-order 
expression. For example,
+ * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the 
inner function
+ * parameters and result must similarly be on the GPU.
+ *
+ * -------
+ * | AOR |  This pass supports all of Relay.
+ * -------
+ *    ^
+ *    |
+ *    `-- Mark's stamp of completeness :-)
+ *
+ * TODO(mbs):
+ *  * Though on_device is the identity for all types we can't wrap it around 
functions/constructors
+ *    taking type args (or at least not without changing type_infer.cc to see 
through them).
+ *    This is not currently handled generally.
+ *  * Proper diagnostics for unification failure using spans.
+ *  * Make sure the pass is idempotent even after FuseOps etc.
+ *  * Support application of constructors properly. Are they device 
polymorphic?
+ *  * Replace DLDeviceType with TargetDevice, and unify 'target annotation' 
with 'device planning'.
+ *  * Support running the pass post FuseOps (so need to understand primitive functions, both
+ *    outlined and inlined) and post the VM transforms (probably need to support more intrinsic
+ *    forms?).
+ *  * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish 
between the default
+ *    device for primitives vs the default device for the rest of Relay.
+ *  * We'll probably need some support for partial 'device polymorphism' for 
functions once we
+ *    incorporate targets and memory scopes into the domain. For example it's 
ok for the function
+ *    body to be executed on different device ids provided they have the same 
target and memory
+ *    scope.
+ *  * Might be simpler to just let every type have a device annotation rather 
than work in
+ *    a separate domain?
+ *  * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary 
copies.
+ *  * The original device_annotation.cc RewriteAnnotatedOps removed all 
"on_device" calls
+ *    in tuples at the top level of function bodies or main expression, 
irrespective of the
+ *    "on_device" body. What's up with that?
+ */
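
Reviewer-style aside: the `result_domain` rule in the header comment can be illustrated in isolation. The following is a minimal sketch under simplified assumptions, not the code in this PR. `Dev`, `Domain`, `FirstOrder`, `Fn`, `ResultDomain` and `ExampleResultDevice` are hypothetical names introduced only for illustration, and `Dev` merely stands in for `DLDeviceType`.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for DLDeviceType; the values are illustrative only.
enum class Dev { kInvalid = 0, kCPU = 1, kGPU = 2 };

struct Domain;
using DomainPtr = std::shared_ptr<Domain>;

// First-order when args_and_result is empty; otherwise fn(D1,...,Dn):Dr,
// with the result domain Dr stored as the last element.
struct Domain {
  Dev dev = Dev::kInvalid;
  std::vector<DomainPtr> args_and_result;
};

// Builds a first-order domain bound to the given device.
DomainPtr FirstOrder(Dev dev) {
  auto d = std::make_shared<Domain>();
  d->dev = dev;
  return d;
}

// Builds a higher-order domain fn(args...):result.
DomainPtr Fn(std::vector<DomainPtr> args, DomainPtr result) {
  auto d = std::make_shared<Domain>();
  args.push_back(std::move(result));
  d->args_and_result = std::move(args);
  return d;
}

// result_domain(<device>)         = <device>
// result_domain(fn(D1,...,Dn):Dr) = result_domain(Dr)
DomainPtr ResultDomain(DomainPtr d) {
  while (!d->args_and_result.empty()) {
    d = d->args_and_result.back();
  }
  return d;
}

// fn(GPU):(fn(GPU):CPU) is a function returning a function; its result
// domain is the innermost result, i.e. the CPU.
Dev ExampleResultDevice() {
  auto inner = Fn({FirstOrder(Dev::kGPU)}, FirstOrder(Dev::kCPU));
  auto outer = Fn({FirstOrder(Dev::kGPU)}, inner);
  return ResultDomain(outer)->dev;
}
```

Under this reading, the 'device for a function' never needs its own slot in the domain: it is always recovered by walking down the result positions.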
+
+#include "./device_planner.h"
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/pattern_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/object.h>
+
+#include <unordered_map>
+
+#include "../op/annotation/annotation.h"
+#include "../op/memory/device_copy.h"
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+namespace {
+
+/*!
+ * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR 
primitives rather
+ * than the original "device_copy" operator.
+ *
+ * See te_compiler.cc for where this rewriting occurs.
+ */
+DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) {
+  auto tir_call_attrs = call_node->attrs.as<TIRCallAttrs>();
+  if (tir_call_attrs == nullptr) {
+    return {};
+  }
+  if (tir_call_attrs->metadata.count("source_device") != 1 ||
+      tir_call_attrs->metadata.count("dst_device") != 1) {
+    return {};
+  }
+  ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1";
+  return {
+      call_node->args[0],
+      static_cast<DLDeviceType>(
+          Downcast<Integer>(tir_call_attrs->metadata["source_device"])->value),
+      static_cast<DLDeviceType>(Downcast<Integer>(tir_call_attrs->metadata["dst_device"])->value)};
+}
+
+class DeviceDomain;
+using DeviceDomainPtr = std::shared_ptr<DeviceDomain>;
+
+/******
+****** Domains
+******/
+
+/*!
+ * \brief Represents the domain over which we collect equality constraints.
+ *
+ * \code
+ *   D ::= ?x?                  -- first order, free
+ *       | <device_type>        -- first order, bound
+ *       | fn(D1, ..., Dn):Dr   -- higher order
+ * \endcode
+ *
+ * We require a function value to be on the same device as its result. To 
support that we need
+ * a notion of the 'result domain' of a domain:
+ * \code
+ *   result_domain(?x?)                = ?x?
+ *   result_domain(<device_type>)      = <device_type>
+ *   result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr)
+ * \endcode
+ */
+class DeviceDomain {
+ public:
+  /*!
+   * \brief Constructs a first-order domain of \p device_type, which may be
+   * \p kInvalidDeviceType to indicate the domain is free.
+   */
+  explicit DeviceDomain(DLDeviceType device_type) : device_type_(device_type) 
{}
+
+  /*!
+   * \brief Constructs a higher-order domain, where \p args_and_result contain 
the
+   * function argument and result domains in order.
+   */
+  explicit DeviceDomain(std::vector<DeviceDomainPtr> args_and_result)
+      : device_type_(kInvalidDeviceType), 
args_and_result_(std::move(args_and_result)) {}
+
+  /*! \brief Returns true if domain is first-order and free. */
+  bool is_free() const { return device_type_ == kInvalidDeviceType && 
args_and_result_.empty(); }
+
+  /*! \brief Returns true if domain is higher-order. */
+  bool is_higher_order() const { return !args_and_result_.empty(); }
+
+  DLDeviceType first_order_device_type() const {
+    ICHECK(args_and_result_.empty());
+    return device_type_;
+  }
+
+  size_t function_arity() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.size() - 1UL;
+  }
+
+  DeviceDomainPtr function_param(size_t i) const {
+    ICHECK(!args_and_result_.empty());
+    ICHECK_LT(i + 1, args_and_result_.size());
+    return args_and_result_[i];
+  }
+
+  DeviceDomainPtr function_result() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.back();
+  }
+
+ private:
+  /*!
+   * \brief If this is a function domain then always kInvalidDeviceType. Otherwise will be
+   * kInvalidDeviceType if the domain is still free, or the specific concrete device if the
+   * domain is bound.
+   */
+  const DLDeviceType device_type_;
+
+  /*!
+   * \brief If this is a function domain then the sub-domains for each of the 
function's
+   * arguments, and the domain for its result. Otherwise empty.
+   */
+  const std::vector<DeviceDomainPtr> args_and_result_;
+
+  friend struct DeviceDomainHash;
+  friend struct DeviceDomainEqual;
+  friend class DeviceDomains;
+};
+
+// Ye olde boost hash mixer.
+constexpr size_t mix(size_t h1, size_t h2) {
+  return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+}
+
+// The following hash and equality helpers give each free first-order domain pointer its own
+// distinct identity.
+struct DeviceDomainHash {
+  size_t operator()(const DeviceDomainPtr& domain) const {
+    if (domain->is_free()) {
+      // Give each free first-order domain its own identity.
+      return static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get()));
+    } else {
+      size_t h = domain->args_and_result_.size();
+      h = mix(h, std::hash<int>()(static_cast<int>(domain->device_type_)));
+      for (const auto& sub_domain_ptr : domain->args_and_result_) {
+        h = mix(h, DeviceDomainHash()(sub_domain_ptr));
+      }
+      return h;
+    }
+  }
+};
+
+struct DeviceDomainEqual {
+ public:
+  bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) 
const {
+    if (lhs->args_and_result_.size() != rhs->args_and_result_.size()) {
+      // Mismatched arities are never equal.
+      // (Though we'll never ask to do such a comparison explicitly, the hash map
+      // may do so implicitly due to hash collisions.)
+      return false;
+    }
+    if (lhs->is_free() && rhs->is_free()) {
+      // Compare first-order free domains by their address.
+      return lhs.get() == rhs.get();
+    }
+    if (lhs->args_and_result_.empty()) {
+      // Compare first-order domains by their device type -- free vs bound will compare as false.
+      return lhs->device_type_ == rhs->device_type_;
+    } else {
+      // Compare higher-order domains pointwise.
+      for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
+        if (!(*this)(lhs->args_and_result_[i], rhs->args_and_result_[i])) {
+          return false;
+        }
+      }
+      return true;
+    }
+  }
+};
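
Reviewer-style aside: the identity convention used by the two helpers above (free domains compare by address, bound domains by value) can be checked with a much smaller model. This is only a sketch, assuming first-order domains and a hypothetical `Node` type where `device == 0` means "free". `Node`, `NodeHash`, `NodeEqual` and `CountDistinct` are illustrative names, not part of this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <unordered_set>
#include <vector>

// Hypothetical simplified first-order domain: device == 0 means "free".
struct Node {
  int device = 0;
};
using NodePtr = std::shared_ptr<Node>;

struct NodeHash {
  std::size_t operator()(const NodePtr& n) const {
    if (n->device == 0) {
      // Free: hash by address so every free node is its own key.
      return std::hash<const Node*>()(n.get());
    }
    // Bound: hash by value so all nodes with the same device collide on purpose.
    return std::hash<int>()(n->device);
  }
};

struct NodeEqual {
  bool operator()(const NodePtr& a, const NodePtr& b) const {
    if (a->device == 0 && b->device == 0) {
      // Two free nodes are equal only if they are the same object.
      return a.get() == b.get();
    }
    // Bound vs bound compares devices; free vs bound is always false.
    return a->device == b->device;
  }
};

// Counts how many distinct keys the nodes produce under the rules above.
std::size_t CountDistinct(const std::vector<NodePtr>& nodes) {
  std::unordered_set<NodePtr, NodeHash, NodeEqual> keys(nodes.begin(), nodes.end());
  return keys.size();
}
```

The property that matters for any map keyed this way is that equal keys always hash identically: two equal free nodes share an address, and two equal bound nodes share a device.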
+
+/*!
+ * \brief Tracks the device domains for a set of expressions w.r.t. an 
equivalence relation
+ * built up by calls to \p Unify.
+ */
+class DeviceDomains {
+ public:
+  DeviceDomains() = default;
+
+  /*!
+   * \brief Returns a domain appropriate for \p type whose result domain is bound
+   * to \p device_type. If \p device_type is \p kInvalidDeviceType then the entire domain
+   * will be free.
+   */
+  static DeviceDomainPtr MakeDomain(const Type& type, DLDeviceType 
device_type) {
+    if (const auto* func_type_node = type.as<FuncTypeNode>()) {
+      std::vector<DeviceDomainPtr> args_and_result;
+      args_and_result.reserve(func_type_node->arg_types.size() + 1);
+      for (const auto& arg_type : func_type_node->arg_types) {
+        args_and_result.emplace_back(MakeDomain(arg_type, kInvalidDeviceType));
+      }
+      args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, 
device_type));
+      return std::make_shared<DeviceDomain>(std::move(args_and_result));
+    } else {
+      return std::make_shared<DeviceDomain>(device_type);
+    }
+  }
+
+  /*!
+   * \brief Returns a higher-order domain with \p arg_and_results.
+   */
+  static DeviceDomainPtr MakeDomain(std::vector<DeviceDomainPtr> 
arg_and_results) {
+    return std::make_shared<DeviceDomain>(std::move(arg_and_results));
+  }
+
+  /*! \brief Returns a domain for \p type whose result domain is bound to \p device_type. */
+  static DeviceDomainPtr ForDeviceType(const Type& type, DLDeviceType 
device_type) {
+    ICHECK_NE(device_type, kInvalidDeviceType);
+    return MakeDomain(type, device_type);
+  }
+
+  /*! \brief Returns a free domain appropriate for \p type. */
+  static DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, 
kInvalidDeviceType); }
+
+  /*! \brief Returns the domain representing the equivalence class containing 
\p domain. */
+  DeviceDomainPtr Lookup(DeviceDomainPtr domain) {
+    DeviceDomainPtr root = domain;
+    while (true) {
+      auto itr = domain_to_equiv_.find(root);
+      if (itr == domain_to_equiv_.end()) {
+        break;
+      }
+      ICHECK_NE(itr->second, root);
+      root = itr->second;
+      ICHECK_NOTNULL(root);
+    }
+    // Path compression.
+    while (domain != root) {
+      auto itr = domain_to_equiv_.find(domain);
+      ICHECK(itr != domain_to_equiv_.end());
+      domain = itr->second;
+      ICHECK_NOTNULL(domain);
+      itr->second = root;
+    }
+    return root;
+  }
+
+  /*!
+   * \brief Returns the domain accounting for all bound devices in \p lhs and 
\p rhs.
+   *
+   * Throws \p Error on failure.
+   */
+  DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) 
{
+    // TODO(mbs): Proper diagnostics.
+    ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size())
+        << "Device domains:" << std::endl
+        << ToString(lhs) << std::endl
+        << "and" << std::endl
+        << ToString(rhs) << std::endl
+        << "do not have the same kind and can't be unified.";
+    if (rhs->is_free()) {
+      return lhs;
+    } else if (lhs->is_free()) {
+      return rhs;
+    } else if (lhs->args_and_result_.empty()) {
+      // Must have consistent device types for first order domains.
+      if (lhs->device_type_ != rhs->device_type_) {
+        // TODO(mbs): Proper diagnostics.
+        std::ostringstream os;
+        os << "Inconsistent device types " << lhs->device_type_ << " and " << 
rhs->device_type_;
+        throw Error(os.str());
+      }
+      return lhs;
+    } else {
+      // Recurse for higher-order.
+      std::vector<DeviceDomainPtr> args_and_result;
+      args_and_result.reserve(lhs->args_and_result_.size());
+      for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
+        args_and_result.emplace_back(Unify(lhs->args_and_result_[i], 
rhs->args_and_result_[i]));
+      }
+      return MakeDomain(std::move(args_and_result));
+    }
+  }
+
+  /*!
+   * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. 
Fails if \p lhs and \p
+   * rhs disagree on bound device type.
+   *
+   * Throws \p Error on failure.
+   */
+  // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but
+  // given we have refs to functions I'm prepared to be surprised.
+  DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) {
+    lhs = Lookup(lhs);
+    rhs = Lookup(rhs);
+    auto joined_domain = Join(lhs, rhs);
+    if (!DeviceDomainEqual()(lhs, joined_domain)) {
+      domain_to_equiv_.emplace(lhs, joined_domain);
+    }
+    if (!DeviceDomainEqual()(rhs, joined_domain)) {
+      domain_to_equiv_.emplace(rhs, joined_domain);
+    }
+    return joined_domain;
+  }
+
+  /*!
+   * \brief Unifies \p lhs and \p rhs. If \p lhs is first-order and \p rhs is 
higher-order,
+   * require all arguments and result of \p rhs to unify with \p lhs. 
Otherwise same as
+   * \p Unify.
+   *
+   * Throws \p Error on failure.
+   */
+  void UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+    if (!lhs->is_higher_order() && rhs->is_higher_order()) {
+      Collapse(lhs, rhs);
+    } else {
+      Unify(lhs, rhs);
+    }
+  }
+
+  /*! \brief Returns true if a domain is known for \p expr. */
+  bool contains(const Expr& expr) const { return 
expr_to_domain_.count(expr.get()); }
+
+  /*! \brief Returns the domain representing \p expr. */
+  DeviceDomainPtr DomainFor(const Expr& expr) {
+    ICHECK(expr.defined());
+    auto itr = expr_to_domain_.find(expr.get());
+    if (itr != expr_to_domain_.end()) {
+      return Lookup(itr->second);
+    }
+    auto domain = Free(expr->checked_type());
+    expr_to_domain_.emplace(expr.get(), domain);
+    return domain;
+  }
+
+  /*!
+   * \brief Returns the domain representing the callee (ie 'op') in \p call 
expression. If the
+   * callee is a primitive or special operation we handle it specially. 
Otherwise defers to \p
+   * DomainFor(call->op).
+   *
+   * This special handling is needed:
+   * - To handle the "on_device" and "device_copy" ops which constrain devices 
to the given devices.
+   * - To handle some special ops which constrain devices to the CPU.
+   * - To allow the same primitive to be called on different devices at 
different call sites.
+   * Since each call to the op can have a different domain we index the ops by 
the call expression
+   * rather than the op itself.
+   */
+  DeviceDomainPtr DomainForCallee(const Call& call) {
+    auto itr = call_to_callee_domain_.find(call.get());
+    if (itr != call_to_callee_domain_.end()) {
+      return Lookup(itr->second);
+    }
+    std::vector<DeviceDomainPtr> args_and_result;
+
+    auto on_device_props = GetOnDeviceProps(call.get());
+    auto device_copy_props = GetDeviceCopyProps(call.get());
+    if (!device_copy_props.body.defined()) {
+      device_copy_props = GetPrimitiveDeviceCopyProps(call.get());
+    }
+
+    if (on_device_props.body.defined()) {
+      // on_device(expr, device_type=<t>, is_fixed=false)
+      // on_device : fn(<t>):?x?
+      //
+      // on_device(expr, device_type=<t>, is_fixed=true)
+      // on_device: fn(<t>):<t>
+      args_and_result.emplace_back(
+          ForDeviceType(on_device_props.body->checked_type(), 
on_device_props.device_type));
+      if (on_device_props.is_fixed) {
+        args_and_result.emplace_back(args_and_result.front());
+      } else {
+        args_and_result.emplace_back(Free(on_device_props.body->checked_type()));
+      }
+    } else if (device_copy_props.body.defined()) {
+      // device_copy(expr, src_dev_type=<s>, dst_dev_type=<d>)
+      // device_copy: fn(<s>):<d>
+      args_and_result.emplace_back(
+          ForDeviceType(device_copy_props.body->checked_type(), 
device_copy_props.src_dev_type));
+      args_and_result.emplace_back(
+          ForDeviceType(device_copy_props.body->checked_type(), 
device_copy_props.dst_dev_type));
+    } else if (call->op == alloc_storage_op) {
+      ICHECK_EQ(call->args.size(), 2U);
+      // alloc_storage(size, alignment, device_type=<t>)
+      // alloc_storage: fn(<cpu>, <cpu>):<t>
+      const auto* attrs = call->attrs.as<AllocStorageAttrs>();
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(
+          ForDeviceType(call->checked_type(), 
static_cast<DLDeviceType>(attrs->device_type)));
+    } else if (call->op == alloc_tensor_op) {
+      ICHECK_EQ(call->args.size(), 3U);
+      // alloc_tensor(storage, offset, shape)
+      // alloc_tensor: fn(?x?, <cpu>, <cpu>):?x?
+      auto free_domain = Free(call->checked_type());
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(free_domain);
+    } else if (call->op == shape_func_op) {
+      ICHECK_EQ(call->args.size(), 3U);
+      // shape_func(func, inputs, outputs, is_input=[...])
+      // shape_func: fn(..., <cpu>, <cpu>):<cpu>
+      // where ... is a free domain appropriate for func's type
+      args_and_result.emplace_back(Free(call->args[0]->checked_type()));
+      // TODO(mbs): I think this should be on the cpu only when is_input = [false], but
+      // what do we do when we have multiple arguments with different is_input values?
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+    } else if (call->op == shape_of_op) {
+      ICHECK_EQ(call->args.size(), 1U);
+      // shape_of(tensor)
+      // shape_of: fn(?x?):<cpu>
+      args_and_result.emplace_back(Free(call->args[0]->checked_type()));
+      args_and_result.emplace_back(cpu_domain_);
+    } else if (call->op == invoke_tvm_op) {
+      ICHECK_EQ(call->args.size(), 3U);
+      // invoke_tvm_op(op, inputs, outputs)
+      // invoke_tvm_op: fn(..., ?x?, ?x?):?x?
+      // where ... is a free domain appropriate for op's type
+      auto free_domain = Free(call->checked_type());
+      args_and_result.emplace_back(Free(call->args[0]->checked_type()));
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(free_domain);
+    } else if (call->op == reshape_tensor_op) {
+      ICHECK_EQ(call->args.size(), 2U);
+      // reshape_tensor(data, shape)
+      // reshape_tensor: fn(?x?, <cpu>):?x?
+      auto free_domain = Free(call->checked_type());
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(free_domain);
+    } else if (call->op->IsInstance<OpNode>()) {
+      // <primitive>(arg1, ..., argn)
+      // <primitive>: fn(?x?, ..., ?x?):?x?
+      // (all args and result must be first-order).
+      auto free_domain = Free(arb_);
+      for (size_t i = 0; i < call->args.size(); ++i) {
+        args_and_result.emplace_back(free_domain);
+      }
+      args_and_result.emplace_back(free_domain);
+    } else {
+      // Defer to normal case where op can be an arbitrary expression.
+      return DomainFor(call->op);
+    }
+    auto domain = MakeDomain(std::move(args_and_result));
+    call_to_callee_domain_.emplace(call.get(), domain);
+    return domain;
+  }
+
+  /*! \brief Unifies the domains for expressions \p lhs and \p rhs. */
+  void UnifyExprExact(const Expr& lhs, const Expr& rhs) {
+    auto lhs_domain = DomainFor(lhs);
+    auto rhs_domain = DomainFor(rhs);
+    try {
+      Unify(lhs_domain, rhs_domain);
+    } catch (const Error& e) {
+      // TODO(mbs): Proper diagnostics.
+      LOG(FATAL) << "Incompatible devices for expressions:" << std::endl
+                 << PrettyPrint(lhs) << std::endl
+                 << "with device:" << std::endl
+                 << ToString(lhs_domain) << std::endl
+                 << "and:" << std::endl
+                 << PrettyPrint(rhs) << std::endl
+                 << "with device:" << std::endl
+                 << ToString(rhs_domain) << std::endl
+                 << e.what();
+    }
+  }
+
+  /*!
+   * \brief Unifies the domain for \p expr with \p expected_domain.
+   */
+  void UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain) {
+    auto actual_domain = DomainFor(expr);
+    try {
+      Unify(actual_domain, expected_domain);
+    } catch (const Error& e) {
+      // TODO(mbs): Proper diagnostics.
+      LOG(FATAL) << "Incompatible devices for expression:" << std::endl
+                 << PrettyPrint(expr) << std::endl
+                 << "with actual device:" << std::endl
+                 << ToString(actual_domain) << std::endl
+                 << "and expected device:" << std::endl
+                 << ToString(expected_domain) << std::endl
+                 << e.what();
+    }
+  }
+
+  /*!
+   * \brief Unifies the domain for \p expr with \p expected_domain.
+   * If \p expected_domain is higher-order but \p expr is first-order, require all arguments
+   * and the result of \p expected_domain to have the same domain as for \p expr.
+   */
+  void UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain) {
+    auto actual_domain = DomainFor(expr);
+    try {
+      UnifyCollapsed(actual_domain, expected_domain);
+    } catch (const Error& e) {
+      // TODO(mbs): Proper diagnostics.
+      LOG(FATAL) << "Incompatible devices for expression:" << std::endl
+                 << PrettyPrint(expr) << std::endl
+                 << "with actual device:" << std::endl
+                 << ToString(actual_domain) << std::endl
+                 << "and expected device:" << std::endl
+                 << ToString(expected_domain) << std::endl
+                 << e.what();
+    }
+  }
+
+  /*! \brief Returns true if \p domain contains any free sub-domains. */
+  bool AnyFree(DeviceDomainPtr domain) {
+    domain = Lookup(domain);
+    if (domain->is_free()) {
+      return true;
+    }
+    for (const auto& sub_domain : domain->args_and_result_) {
+      if (AnyFree(sub_domain)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  /*!
+   * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain.
+   * This can be used to handle functions within tuples, references and ADTs since we don't
+   * attempt to track anything beyond 'the device' for expressions of those first-order types.
+   *
+   * Throws \p Error on failure.
+   */
+  void Collapse(const DeviceDomainPtr& first_order_domain,
+                const DeviceDomainPtr& higher_order_domain) {
+    for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) {
+      Unify(higher_order_domain->function_param(i), first_order_domain);
+    }
+    Unify(higher_order_domain->function_result(), first_order_domain);
+  }
+
+  /*! \brief Force all free domains in \p domain to default to \p default_device_type. */
+  void SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type) {
+    ICHECK_NE(default_device_type, kInvalidDeviceType);
+    domain = Lookup(domain);
+    if (domain->is_free()) {
+      // Will never throw since lhs is free.
+      Unify(domain, std::make_shared<DeviceDomain>(default_device_type));
+    } else if (!domain->args_and_result_.empty()) {
+      for (const auto& sub_domain : domain->args_and_result_) {
+        SetDefault(sub_domain, default_device_type);
+      }
+    }
+  }
+
+  /*!
+   * \brief If \p domain is higher-order and its result domain is free, force it to
+   * \p default_device_type. Then force any remaining free domains to the result domain
+   * (freshly defaulted or original). If \p domain is first-order, behaves the same as
+   * \p SetDefault.
+   */
+  void SetResultDefaultThenParams(const DeviceDomainPtr& domain, DLDeviceType default_device_type) {
+    if (!domain->is_higher_order()) {
+      SetDefault(domain, default_device_type);
+      return;
+    }
+    DLDeviceType result_device_type = ResultDeviceType(domain);
+    if (result_device_type == kInvalidDeviceType) {
+      // If the function result device is still free use the given default.
+      result_device_type = default_device_type;
+    }
+    // Default any remaining free parameters to the function result device.
+    SetDefault(domain, result_device_type);
+  }
+
+  /*! \brief Returns one-line description of \p domain for debugging. */
+  std::string ToString(DeviceDomainPtr domain) {
+    domain = Lookup(domain);
+    std::ostringstream os;
+    if (domain->is_free()) {
+      // first-order free
+      os << "?" << static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get())) << "?";
+    } else if (domain->args_and_result_.empty()) {
+      // first-order bound
+      os << "<" << domain->device_type_ << ">";
+    } else {
+      // higher-order
+      os << "fn(";
+      for (size_t i = 0; i + 1 < domain->args_and_result_.size(); ++i) {
+        if (i > 0) {
+          os << ",";
+        }
+        os << ToString(domain->args_and_result_[i]);
+      }
+      os << "):" << ToString(domain->args_and_result_.back());
+    }
+    return os.str();
+  }
+
+  /*! \brief Returns description of entire system of constraints for debugging */
+  std::string ToString() {
+    std::ostringstream os;
+    for (const auto& pair : expr_to_domain_) {
+      os << "expression:" << std::endl
+         << PrettyPrint(GetRef<Expr>(pair.first)) << std::endl
+         << "domain:" << std::endl
+         << ToString(pair.second) << std::endl
+         << std::endl;
+    }
+    for (const auto& pair : call_to_callee_domain_) {
+      os << "call:" << std::endl
+         << PrettyPrint(GetRef<Call>(pair.first)) << std::endl
+         << "callee domain:" << std::endl
+         << ToString(pair.second) << std::endl
+         << std::endl;
+    }
+    return os.str();
+  }
+
+  /*!
+   * \brief Returns the result domain for \p domain (see defn in DeviceDomain comment).
+   */
+  DeviceDomainPtr ResultDomain(DeviceDomainPtr domain) {
+    domain = Lookup(domain);
+    while (!domain->args_and_result_.empty()) {
+      domain = Lookup(domain->args_and_result_.back());
+    }
+    return domain;
+  }
+
+  /*!
+   * \brief Returns the result (possibly free) device type for \p domain (see defn in DeviceDomain
+   * comment).
+   */
+  DLDeviceType ResultDeviceType(const DeviceDomainPtr& domain) {
+    return ResultDomain(domain)->first_order_device_type();
+  }
+
+ private:
+  /*! \brief Intrinsics we need to handle specially. */
+  const Op& alloc_storage_op = Op::Get("memory.alloc_storage");
+  const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor");
+  const Op& shape_of_op = Op::Get("vm.shape_of");
+  const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op");
+  const Op& shape_func_op = Op::Get("vm.shape_func");
+  const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor");
+  /*! \brief The CPU device type for special operators such as dynamic shape functions. */
+  const DLDeviceType cpu_device_type_ = kDLCPU;
+  /*! \brief Placeholder for any first-order type. */
+  Type arb_ = TupleType();
+  /*! \brief The domain for first-order expressions on the CPU. */
+  DeviceDomainPtr cpu_domain_ = ForDeviceType(arb_, cpu_device_type_);
+
+  /*! \brief Maps expressions to their domains as determined during analysis. */
+  std::unordered_map<const ExprNode*, DeviceDomainPtr> expr_to_domain_;
+
+  /*!
+   * \brief Maps call expressions to the domains for their callee where the callee is a primitive.
+   */
+  std::unordered_map<const CallNode*, DeviceDomainPtr> call_to_callee_domain_;
+
+  /*! \brief Maps device domains to their equivalent domains as determined during unification. */
+  std::unordered_map<DeviceDomainPtr, DeviceDomainPtr, DeviceDomainHash, DeviceDomainEqual>
+      domain_to_equiv_;
+};
+
+/******
+****** Phase 0
+******/
+
+/*!
+ * \brief Rewrites "on_device" calls to handle some special cases.
+ */
+class RewriteOnDevices : public ExprMutator {
+ public:
+  RewriteOnDevices() = default;
+
+ private:
+  Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
+    Expr tuple = VisitExpr(tuple_get_item_node->tuple);
+    // TODO(mbs): Avoid copy.
+    Expr tuple_get_item =
+        TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span);
+    auto props = GetOnDeviceProps(tuple);
+    if (props.body.defined() && !props.is_fixed) {
+      VLOG(1) << "wrapping tuple get item:" << std::endl
+              << PrettyPrint(GetRef<TupleGetItem>(tuple_get_item_node)) << std::endl
+              << "with \"on_device\" for device " << props.device_type;
+      return OnDevice(tuple_get_item, props.device_type, /*is_fixed=*/false);
+    } else {
+      return tuple_get_item;
+    }
+  }
+
+  Expr VisitExpr_(const LetNode* let_node) final {
+    auto expr = GetRef<Expr>(let_node);
+    std::vector<std::tuple<Var, Expr, Span>> bindings;
+    while (const auto* inner_let_node = expr.as<LetNode>()) {
+      Expr inner_let = GetRef<Let>(inner_let_node);
+      Expr value = VisitExpr(inner_let_node->value);
+      auto props = GetOnDeviceProps(value);
+      if (props.body.defined() && !props.is_fixed) {
+        VLOG(1) << "revising let-bound expression of let:" << std::endl
+                << PrettyPrint(expr) << std::endl
+                << "to be fixed to device " << props.device_type;
+        value = OnDevice(props.body, props.device_type, /*is_fixed=*/true);
+      }
+      bindings.emplace_back(inner_let_node->var, value, inner_let_node->span);
+      expr = inner_let_node->body;
+    }
+    expr = VisitExpr(expr);
+    // TODO(mbs): Avoid copy.
+    for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) {
+      expr = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), expr,
+                 /*span=*/std::get<2>(*itr));
+    }
+    return expr;
+  }
+
+  Expr VisitExpr_(const FunctionNode* function_node) final {
+    Expr body = VisitExpr(function_node->body);
+    auto props = GetOnDeviceProps(body);
+    if (props.body.defined() && !props.is_fixed) {
+      VLOG(1) << "revising body of function:" << std::endl
+              << PrettyPrint(GetRef<Function>(function_node)) << std::endl
+              << "to be fixed to device " << props.device_type;
+      body = OnDevice(props.body, props.device_type, /*is_fixed=*/true);
+    }
+    // TODO(mbs): Avoid copy
+    return Function(function_node->params, body, function_node->ret_type,
+                    function_node->type_params, function_node->attrs, function_node->span);
+  }
+};
+
+/******
+****** Phase 1
+******/
+
+/*!
+ * \brief Collects the system of device constraints for all sub-expressions in a module.
+ * It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter.
+ */
+class DeviceAnalyzer : public ExprVisitor {
+ public:
+  explicit DeviceAnalyzer(IRModule mod)
+      : mod_(std::move(mod)), domains_(std::make_unique<DeviceDomains>()) {}
+
+  /*!
+   * \brief Returns the expression-to-device-domain map for all expressions in all the global
+   * function definitions in the module. Expressions may have free domains; these will be resolved
+   * by \p DeviceDefaulter below.
+   */
+  std::unique_ptr<DeviceDomains> Analyze() {
+    VLOG_CONTEXT << "DeviceAnalyzer";
+    for (const auto& pair : mod_->functions) {
+      VLOG(1) << "collecting constraints for '" << PrettyPrint(pair.first) << "'";
+      domains_->UnifyExprExact(pair.first, pair.second);
+      VisitExpr(pair.second);
+    }
+    return std::move(domains_);
+  }
+
+ private:
+  void VisitExpr_(const CallNode* call_node) final {
+    auto call = GetRef<Call>(call_node);
+
+    // Find the higher-order domain for the callee. See DomainForCallee for the special rules
+    // for primitives.
+    VisitExpr(call_node->op);
+    auto func_domain = domains_->DomainForCallee(call);  // higher-order
+
+    // Build the domain for the function implied by its arguments and call context.
+    ICHECK_EQ(func_domain->function_arity(), call_node->args.size());
+    std::vector<DeviceDomainPtr> args_and_result_domains;
+    args_and_result_domains.reserve(call_node->args.size() + 1);
+    for (const auto& arg : call_node->args) {
+      args_and_result_domains.emplace_back(domains_->DomainFor(arg));
+      VisitExpr(arg);
+    }
+    args_and_result_domains.emplace_back(domains_->DomainFor(call));
+    auto implied_domain =
+        DeviceDomains::MakeDomain(std::move(args_and_result_domains));  // higher-order
+
+    VLOG(1) << "initial call function domain:" << std::endl
+            << domains_->ToString(func_domain) << std::endl
+            << "and implied domain:" << std::endl
+            << domains_->ToString(implied_domain) << std::endl
+            << "for call:" << std::endl
+            << PrettyPrint(call);
+
+    // The above must match.
+    try {
+      domains_->Unify(func_domain, implied_domain);  // higher-order
+    } catch (const Error& e) {
+      // TODO(mbs): Proper diagnostics.
+      LOG(FATAL) << "Function parameters and result devices do not match those of call. Call:"
+                 << std::endl
+                 << PrettyPrint(call) << std::endl
+                 << "with function devices:" << std::endl
+                 << domains_->ToString(func_domain) << std::endl
+                 << "and implied call devices:" << std::endl
+                 << domains_->ToString(implied_domain) << std::endl
+                 << e.what();
+    }
+
+    VLOG(1) << "final call function domain:" << std::endl
+            << domains_->ToString(func_domain) << std::endl
+            << "for call:" << std::endl
+            << PrettyPrint(call);
+  }
+
+  void VisitExpr_(const LetNode* let_node) final {
+    Expr expr = GetRef<Let>(let_node);
+    // Iteratively visit let nodes to avoid stack overflow.
+    while (expr->IsInstance<LetNode>()) {
+      Let let = Downcast<Let>(expr);
+      // Let var must be same device as value it is bound to.
+      domains_->UnifyExprExact(let->var, let->value);  // may be higher-order
+      // Let body must be same device as overall let.
+      domains_->UnifyExprExact(let, let->body);  // may be higher-order
+
+      VisitExpr(let->var);
+      VisitExpr(let->value);
+
+      expr = let->body;
+    }
+
+    // Visit the last body
+    VisitExpr(expr);
+  }
+
+  void VisitExpr_(const FunctionNode* function_node) final {
+    // No need to step into fused primitive functions as they are lowered individually according
+    // to the devices of all their call sites.
+    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      return;
+    }
+
+    auto function = GetRef<Function>(function_node);
+    auto func_domain = domains_->DomainFor(function);  // higher-order
+
+    // The function body domain must match the function result domain.
+    domains_->UnifyExprExact(function_node->body,
+                             func_domain->function_result());  // may be higher-order
+
+    VLOG(1) << "initial function domain:" << std::endl
+            << domains_->ToString(func_domain) << std::endl
+            << "and function body domain:" << std::endl
+            << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl
+            << "for function:" << std::endl
+            << PrettyPrint(function);
+
+    ICHECK_EQ(func_domain->function_arity(), function_node->params.size());
+    for (size_t i = 0; i < function_node->params.size(); ++i) {
+      // The parameter domains must match the function argument domains.
+      domains_->UnifyExprExact(function_node->params[i],
+                               func_domain->function_param(i));  // may be higher-order
+      VisitExpr(function_node->params[i]);
+    }
+
+    // If the function already has an "on_device" attribute then we can further
+    // constrain the function's domain to match it.
+    Optional<Attrs> opt_attrs =
+        function_node->GetAttr<Attrs>(FunctionOnDeviceAttrs::kFunctionAttrsKey);
+    if (opt_attrs) {
+      std::vector<DeviceDomainPtr> args_and_result;
+      for (size_t i = 0; i < function_node->params.size(); ++i) {
+        args_and_result.emplace_back(
+            domains_->ForDeviceType(function_node->params[i]->checked_type(),
+                                    GetFunctionParamDeviceType(function_node, i)));
+      }
+      args_and_result.emplace_back(domains_->ForDeviceType(
+          function_node->body->checked_type(), GetFunctionResultDeviceType(function_node)));
+      auto annotation_domain = domains_->MakeDomain(std::move(args_and_result));
+      try {
+        domains_->Unify(func_domain, annotation_domain);  // higher-order
+      } catch (const Error& e) {
+        // TODO(mbs): Proper diagnostics.
+        LOG(FATAL)
+            << "Function devices are incompatible with its \"on_device\" annotation. Function:"
+            << std::endl
+            << PrettyPrint(function) << std::endl
+            << "with function devices:" << std::endl
+            << domains_->ToString(func_domain) << std::endl
+            << "and annotation devices:" << std::endl
+            << domains_->ToString(annotation_domain) << std::endl
+            << e.what();
+      }
+    }
+
+    VisitExpr(function_node->body);
+
+    VLOG(1) << "final function domain:" << std::endl
+            << domains_->ToString(func_domain) << std::endl
+            << "and function body domain:" << std::endl
+            << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl
+            << "for function:" << std::endl
+            << PrettyPrint(function);
+  }
+
+  void VisitExpr_(const TupleNode* tuple_node) final {
+    Tuple tuple = GetRef<Tuple>(tuple_node);
+    for (size_t i = 0; i < tuple->fields.size(); i++) {
+      auto domain = domains_->DomainFor(tuple->fields[i]);  // may be higher-order
+      domains_->UnifyExprCollapsed(tuple, domain);          // collapse to first-order if needed
+      VisitExpr(tuple->fields[i]);
+    }
+  }
+
+  void VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
+    TupleGetItem tuple_get_item = GetRef<TupleGetItem>(tuple_get_item_node);
+    auto domain = domains_->DomainFor(tuple_get_item);  // may be higher-order
+    domains_->UnifyExprCollapsed(tuple_get_item_node->tuple,
+                                 domain);  // collapse to first-order if needed
+    VisitExpr(tuple_get_item_node->tuple);
+  }
+
+  class DevicePatternAnalyzer : public PatternVisitor {
+   public:
+    DevicePatternAnalyzer(DeviceDomains* domains, const ExprNode* adt_node)
+        : domains_(domains), adt_node_(adt_node) {}
+
+   private:
+    void VisitPattern_(const PatternVarNode* pattern_var_node) final {
+      auto var_domain = domains_->DomainFor(pattern_var_node->var);  // may be higher order
+      domains_->UnifyExprCollapsed(GetRef<Expr>(adt_node_),
+                                   var_domain);  // collapse to first-order if needed
+    }
+
+    /*! \brief (Mutable borrow of) the domains for all expressions processed so far. */
+    DeviceDomains* domains_;
+    /*! \brief The expression for the ADT we are matching over. */
+    const ExprNode* adt_node_;
+  };
+
+  void VisitPattern(const Pattern& pattern) final {}
+
+  void VisitExpr_(const MatchNode* match_node) final {
+    // For match node, we unify the value and the rhs of each clause
+    Match match = GetRef<Match>(match_node);
+    auto match_domain = domains_->DomainFor(match);  // may be higher-order
+    DevicePatternAnalyzer pattern_analyzer(domains_.get(), match->data.get());
+    domains_->UnifyExprCollapsed(match->data, match_domain);  // collapse to first-order if needed
+    for (const auto& clause : match->clauses) {
+      pattern_analyzer.VisitPattern(clause->lhs);
+      domains_->UnifyExprExact(clause->rhs, match_domain);
+      VisitExpr(clause->rhs);
+    }
+    VisitExpr(match_node->data);
+  }
+
+  void VisitExpr_(const GlobalVarNode* global_var_node) final {
+    domains_->DomainFor(GetRef<GlobalVar>(global_var_node));
+  }
+
+  void VisitExpr_(const VarNode* var_node) final { domains_->DomainFor(GetRef<Var>(var_node)); }
+
+  void VisitExpr_(const ConstantNode* constant_node) final {
+    domains_->DomainFor(GetRef<Constant>(constant_node));
+  }
+
+  void VisitExpr_(const ConstructorNode* constructor_node) final {
+    // Probably needs to be device polymorphic.
+    domains_->DomainFor(GetRef<Constructor>(constructor_node));
+  }
+
+  void VisitExpr_(const IfNode* if_node) final {
+    auto ife = GetRef<If>(if_node);
+    auto domain = domains_->DomainFor(ife);               // may be higher-order
+    domains_->UnifyExprCollapsed(if_node->cond, domain);  // collapse to first-order if needed
+    domains_->UnifyExprExact(if_node->true_branch, domain);
+    domains_->UnifyExprExact(if_node->false_branch, domain);
+    VisitExpr(if_node->cond);
+    VisitExpr(if_node->true_branch);
+    VisitExpr(if_node->false_branch);
+  }
+
+  void VisitExpr_(const OpNode* op) final {
+    // no-op, primitive operators are handled at their call-sites.
+  }
+
+  void VisitExpr_(const RefCreateNode* ref_create_node) final {
+    auto ref_create = GetRef<RefCreate>(ref_create_node);
+    auto domain = domains_->DomainFor(ref_create_node->value);  // may be higher-order
+    domains_->UnifyExprCollapsed(ref_create, domain);           // collapse to first-order if needed
+    VisitExpr(ref_create_node->value);
+  }
+
+  void VisitExpr_(const RefReadNode* ref_read_node) final {
+    auto ref_read = GetRef<RefRead>(ref_read_node);
+    auto domain = domains_->DomainFor(ref_read);               // may be higher-order
+    domains_->UnifyExprCollapsed(ref_read_node->ref, domain);  // collapse to first-order if needed
+    VisitExpr(ref_read_node->ref);
+  }
+
+  void VisitExpr_(const RefWriteNode* ref_write_node) final {
+    auto ref_write = GetRef<RefWrite>(ref_write_node);
+    auto domain = domains_->DomainFor(ref_write->value);   // may be higher-order
+    domains_->UnifyExprCollapsed(ref_write->ref, domain);  // collapse to first-order if needed
+    domains_->UnifyExprCollapsed(ref_write, domain);       // collapse to first-order if needed
+    VisitExpr(ref_write_node->ref);
+    VisitExpr(ref_write_node->value);
+  }
+
+  /*! \brief The module we are analyzing. */
+  IRModule mod_;
+  /*! \brief The domains for all expressions processed so far. */
+  std::unique_ptr<DeviceDomains> domains_;
+};
+
+/******
+****** Phase 2
+******/
+
+/*!
+ * \brief Ensures every sub-expression in a module has a device type, using both the global
+ * default and some local heuristics to avoid unnecessary additional "device_copy" CallNodes.
+ *
+ * TODO(mbs): I think this is deterministic? We do however visit the top-level defs in hashmap
+ * order.
+ */
+class DeviceDefaulter : public ExprVisitor {
+ public:
+  DeviceDefaulter(IRModule mod, std::unique_ptr<DeviceDomains> domains,
+                  DLDeviceType default_device_type)
+      : mod_(std::move(mod)),
+        domains_(std::move(domains)),
+        default_device_type_(default_device_type) {}
+
+  std::unique_ptr<DeviceDomains> Default() {
+    VLOG_CONTEXT << "DeviceDefaulter";
+    for (const auto& pair : mod_->functions) {
+      VLOG(1) << "defaulting devices for '" << PrettyPrint(pair.first) << "'";
+      VisitExpr(pair.second);
+    }
+    return std::move(domains_);
+  }
+
+ private:
+  void VisitExpr_(const FunctionNode* function_node) final {
+    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      return;
+    }
+
+    auto function = GetRef<Function>(function_node);
+    auto func_domain = domains_->DomainFor(function);  // higher-order
+    ICHECK_EQ(func_domain->function_arity(), function_node->params.size());
+    if (domains_->AnyFree(func_domain)) {
+      VLOG(1) << "before defaulting function:" << std::endl << domains_->ToString(func_domain);
+      domains_->SetResultDefaultThenParams(func_domain, default_device_type_);
+      VLOG(1) << "after defaulting function:" << std::endl << domains_->ToString(func_domain);
+    }
+    VisitExpr(function_node->body);
+  }
+
+  void VisitExpr_(const CallNode* call_node) final {
+    auto call = GetRef<Call>(call_node);
+    auto func_domain = domains_->DomainForCallee(call);  // higher-order
+    ICHECK_EQ(func_domain->function_arity(), call_node->args.size());
+    if (domains_->AnyFree(func_domain)) {
+      // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*)
+      // above. But for calls to primitives we may still need to force free domains to be
+      // defaulted.
+      VLOG(1) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain);
+      domains_->SetResultDefaultThenParams(func_domain, default_device_type_);
+      VLOG(1) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain);
+    }
+    return ExprVisitor::VisitExpr_(call_node);
+  }
+
+  void VisitExpr_(const LetNode* let_node) final {
+    Expr expr = GetRef<Let>(let_node);
+    // Iteratively visit let nodes to avoid stack overflow.
+    while (expr->IsInstance<LetNode>()) {
+      Let let = Downcast<Let>(expr);
+      // If the let-var device is still free force it to match the overall let.
+      auto let_domain = domains_->DomainFor(let);  // may be higher-order
+      DLDeviceType let_device_type = domains_->ResultDeviceType(let_domain);
+      ICHECK_NE(let_device_type, kInvalidDeviceType);
+      auto let_var_domain = domains_->DomainFor(let->var);  // may be higher-order
+      if (domains_->AnyFree(let_var_domain)) {
+        VLOG(1) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain);
+        domains_->SetDefault(let_var_domain, let_device_type);
+        VLOG(1) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain);
+      }
+      VisitExpr(let->var);
+      VisitExpr(let->value);
+      expr = let->body;
+    }
+    VisitExpr(expr);
+  }
+
+  /*! \brief The module we are processing. */
+  IRModule mod_;
+  /*! \brief The domains for all expressions.  */
+  std::unique_ptr<DeviceDomains> domains_;
+  /*! \brief The default device type. */
+  DLDeviceType default_device_type_;
+};
+
+/******
+****** Phase 3
+******/
+
+/*!
+ * \brief Inserts missing "device_copy" CallNodes, and ensures the device type of every
+ * sub-expression in a module can be easily recovered by a later transformation using simple
+ * lexical scoping rules (e.g. for memory planning).
+ *
+ * - Discard any existing "on_device" CallNodes since their job is done. Similarly, discard
+ *   any existing "device_copy" CallNodes which are no-ops.
+ *
+ * - Functions are given an "on_device" attribute bound to a FunctionOnDeviceAttrs to capture
+ *   the device type for its parameters and result.
+ *
+ * - Additional "device_copy" CallNodes are inserted wherever there's a transition between
+ *   storage device types. Since the DeviceAnalyzer phase succeeded this can only happen
+ *   where the original program explicitly allowed a transition using an "on_device" CallNode.
+ *   That is, we do not try to 'fix' a program with inconsistent devices.
+ *
+ * - Additional "on_device" CallNodes are inserted so that a later transform can discover
+ *   the device for an arbitrary sub-expression by looking only for the lexically enclosing
+ *   "on_device" CallNode or "on_device" function attribute. In particular, since function
+ *   arguments and let-bound expressions can be on a device different from the function
+ *   or let body itself we will insert "on_device" CallNodes to spell out any differences. This
+ *   applies even to the argument to a "device_copy" CallNode, which may look pedantic but
+ *   keeps downstream processing simple. The "on_device" calls should be removed before code gen,
+ *   which is easily done on-the-fly.
+ */
+class DeviceCapturer : public ExprMutator {
+ public:
+  DeviceCapturer(IRModule mod, std::unique_ptr<DeviceDomains> domains)
+      : mod_(std::move(mod)), domains_(std::move(domains)) {}
+
+  IRModule Capture() {
+    VLOG_CONTEXT << "CaptureDevices";
+    IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map);
+    for (const auto& pair : mod_->functions) {
+      VLOG(1) << "capturing devices for '" << PrettyPrint(pair.first) << "'";
+      result->Add(pair.first, Downcast<BaseFunc>(Mutate(pair.second)));
+    }
+    return result;
+  }
+
+ private:
+  // Nothing interesting for VarNode, ConstantNode, GlobalVarNode and OpNode.
+
+  Expr VisitExpr_(const TupleNode* tuple_node) final {
+    auto tuple = GetRef<Tuple>(tuple_node);
+    Array<Expr> fields;
+    fields.reserve(tuple_node->fields.size());
+    for (const auto& field : tuple_node->fields) {
+      fields.push_back(VisitChild(tuple, field));
+    }
+    // TODO(mbs): Avoid copy
+    return Tuple(std::move(fields), tuple_node->span);
+  }
+
+  Expr VisitExpr_(const FunctionNode* function_node) final {
+    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      return GetRef<Function>(function_node);
+    }
+
+    auto function = GetRef<Function>(function_node);
+    auto func_domain = domains_->DomainFor(function);  // higher-order
+    VLOG(1) << "capturing function:" << std::endl
+            << PrettyPrint(function) << std::endl
+            << "with domain:" << std::endl
+            << domains_->ToString(func_domain);
+
+    // Gather the parameter and result device types for the "on_device" function attribute.
+    ICHECK_EQ(func_domain->function_arity(), function_node->params.size());
+    DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain);
+    ICHECK_NE(result_device_type, kInvalidDeviceType);
+    Array<Integer> param_device_types;
+    param_device_types.reserve(function_node->params.size());
+    for (size_t i = 0; i < function_node->params.size(); ++i) {
+      DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i));
+      ICHECK_NE(param_device_type, kInvalidDeviceType);
+      param_device_types.push_back(param_device_type);
+    }
+
+    // Rewrite the body. Note that the body may have begun with an "on_device" so
+    // be prepared to insert a "device_copy".
+    Expr body = VisitChild(
+        /*lexical_device_type=*/result_device_type,
+        /*expected_device_type=*/result_device_type,
+        /*child_device_type=*/GetDeviceType(function_node->body), function_node->body);
+
+    // TODO(mbs): Avoid copy
+    Function func = Function(function_node->params, body, function_node->ret_type,
+                             function_node->type_params, function_node->attrs, function_node->span);
+    return FunctionOnDevice(func, param_device_types, result_device_type);
+  }
+
+  Expr VisitExpr_(const CallNode* call_node) final {

Review comment:
       I added some simple examples to each of the internal pass class comments.
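
       For readers of this thread, here is a rough standalone sketch of the unify-then-default idea that the quoted `Unify`, `Lookup` and `SetDefault` calls rely on. It is not the code in this PR: it handles first-order domains only, and the `Domain`, `Domains`, `kFree`, `kCpu` and `kGpu` names are hypothetical simplifications. It also assumes `Lookup` follows equivalence links the way a union-find structure would, which the quoted hunk does not show.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <unordered_map>

// Hypothetical stand-ins for illustration only; not the types in device_planner.cc.
constexpr int kFree = -1;  // plays the role of kInvalidDeviceType
constexpr int kCpu = 1;    // illustrative device type codes
constexpr int kGpu = 2;

struct Domain {
  int device_type;  // kFree while the device is still unconstrained
};
using DomainPtr = std::shared_ptr<Domain>;

class Domains {
 public:
  static DomainPtr Free() { return std::make_shared<Domain>(Domain{kFree}); }
  static DomainPtr ForDevice(int device_type) {
    return std::make_shared<Domain>(Domain{device_type});
  }

  // Follows equivalence links to the representative domain, compressing the path.
  DomainPtr Lookup(DomainPtr domain) {
    DomainPtr root = domain;
    for (auto it = equiv_.find(root); it != equiv_.end(); it = equiv_.find(root)) {
      root = it->second;
    }
    while (domain != root) {
      DomainPtr next = equiv_[domain];
      equiv_[domain] = root;
      domain = next;
    }
    return root;
  }

  // Makes lhs and rhs equivalent, or throws if they are bound to different devices.
  void Unify(DomainPtr lhs, DomainPtr rhs) {
    lhs = Lookup(lhs);
    rhs = Lookup(rhs);
    if (lhs == rhs) return;
    if (lhs->device_type == kFree) {
      equiv_[lhs] = rhs;  // a free domain simply adopts the other side
    } else if (rhs->device_type == kFree) {
      equiv_[rhs] = lhs;
    } else if (lhs->device_type != rhs->device_type) {
      throw std::runtime_error("incompatible devices");
    } else {
      equiv_[lhs] = rhs;  // same device: merge the two equivalence classes
    }
  }

  // Binds domain to default_device_type only if it is still free.
  void SetDefault(DomainPtr domain, int default_device_type) {
    domain = Lookup(domain);
    if (domain->device_type == kFree) {
      Unify(domain, ForDevice(default_device_type));  // cannot throw: lhs is free
    }
  }

  int DeviceTypeOf(DomainPtr domain) { return Lookup(domain)->device_type; }

 private:
  // Maps a domain to an equivalent domain, keyed by pointer identity.
  std::unordered_map<DomainPtr, DomainPtr> equiv_;
};
```

       In this sketch, unifying two free domains and then binding either of them to `kGpu` binds both, while a domain that no constraint ever touches stays free until `SetDefault` assigns it the default. That mirrors the split between the analysis and defaulting phases described in the class comments.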



