mbs-octoml commented on a change in pull request #9038:
URL: https://github.com/apache/tvm/pull/9038#discussion_r717101958



##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -0,0 +1,1986 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/device_planner.cc
+ * \brief Determines a unique device to hold the result of every Relay sub-expression.
+ *
+ * We say a Relay expression E is 'on device D' if the result of executing E is stored on D.
+ * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the
+ * specific target associated with D (this is recovered independently via a TargetMap), and we
+ * do not track the storage scope within D (this is yet to be implemented).
+ *
+ * Note that 'stored on device D' is almost but not quite the same as 'executes on device D';
+ * see below.
+ *
+ * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes:
+ *  - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and
+ *    'dst_dev_type' device type, which constrain the argument and context of the call
+ *    respectively. It is ok if the source and destination devices are the same; such no-op
+ *    copies will be removed after accounting for the device preference.
+ *  - "on_device" CallNodes (with an \p OnDeviceAttrs attribute) specify a 'device_type', which
+ *    constrains the argument of the call, but (usually, see below) leaves the context
+ *    unconstrained. These are called 'annotations' in the rest of the code. They have no
+ *    operational significance by themselves, but may trigger the insertion of a new
+ *    "device_copy".
+ *  - In two situations the result of an "on_device" CallNode may also be constrained to the
+ *    given device:
+ *     - The "on_device" call occurs at the top-level of a function body, or occurs as an
+ *       immediately let-bound expression. In this situation the extra degree of freedom in
+ *       the function result and let-binding leads to surprising device copies, so we simply
+ *       force the function result or let-bound variable to the given device.
+ *     - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted
+ *       it ourselves during an earlier invocation of this pass. This helps make this pass
+ *       idempotent.
+ *
+ * We proceed in four phases:
+ *
+ * Phase 0
+ * -------
+ * We rewrite the programs to handle some special cases:
+ *  - "on_device" calls at the top-level of a function body or immediately let-bound are
+ *    rewritten to have \code is_fixed=true \endcode.
+ *  - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written
+ *    \code on_device(expr.0, device_type=d) \endcode. I.e. we prefer to copy the projection
+ *    from the tuple rather than project from a copy of the tuple. We'll do this by rewriting.
+ *
+ * Phase 1
+ * -------
+ * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see
+ * below) to all other Relay sub-expressions. (For idempotence we also respect any existing
+ * "on_device" function attributes we introduce below.)
+ *
+ * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the
+ * same device. However each call site can use a different device. In other words primitives are
+ * 'device polymorphic' since we compile and execute them for each required device.
+ *
+ * For most Relay expressions the device for the overall expression is the same as the device
+ * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple
+ * itself, the condition and arms of an \p if must all be on the same device as the overall if,
+ * and so on.
+ *
+ * Some special ops (or 'dialects') are handled:
+ *  - Relay supports computing the shape of tensors and operators at runtime using "shape_of",
+ *    "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors
+ *    they describe may reside on any device.
+ *  - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor" operators.
+ *    Again shapes reside on the CPU, but the allocated tensors may reside on any device.
+ *
+ * Two Relay expressions have special handling:
+ *  - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the
+ *    overall let. However the result of \p e1 may be on a different device.
+ *  - For a function \code fn(x, y) { body } \endcode the result of the function must be on the
+ *    same device as \p body. However parameters \p x and \p y may be on different devices, even
+ *    different from each other. Every call to the function must use the same choice of parameter
+ *    and result devices -- there is no 'device polymorphism' for Relay functions.
+ *
+ * Phase 2
+ * -------
+ * After flowing constraints we apply some defaulting heuristics (using a global default device)
+ * to fix the device for any as-yet unconstrained sub-expressions.
+ *  - Unconstrained function result devices default to the global default device.
+ *  - Unconstrained function parameter devices default to the device for the function result.
+ *  - Unconstrained let-bound expression devices default to the device for the overall let.
+ * TODO(mbs): I may have over-innovated here and we simply want to bind all free domains to
+ * the global default device. Worth a design doc with motivating examples I think.
+ *
+ * Phase 3
+ * -------
+ * Finally, the result of this analysis is reified into the result as:
+ *  - Additional "on_device" attributes (an Attrs resolving to a \p FunctionOnDeviceAttrs) for
+ *    every function (both top-level and local). These describe the devices for the function's
+ *    parameters and the result.
+ *  - Additional "device_copy" CallNodes where a copy is required in order to respect the
+ *    intent of the original "on_device" CallNodes.
+ *  - Additional "on_device" CallNodes where the device type of an expression does not match
+ *    that of the lexically enclosing "on_device" CallNode or function attribute. In practice
+ *    this means "on_device" CallNodes may appear in two places:
+ *     - On a let-bound expression if its device differs from the overall let expression.
+ *     - On a call argument if its device differs from the call result. In particular, the
+ *       argument to a "device_copy" call will always be wrapped in an "on_device". (That may
+ *       seem pedantic but simplifies downstream handling.)
+ *    However since we make it easy to track devices for variables we never wrap an "on_device"
+ *    around a var or global var. These uses of "on_device" imply both the argument and result
+ *    are on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs field to
+ *    true, which helps make this pass idempotent.
+ *
+ * A helper \p LexicalOnDeviceMixin class can be used by downstream transforms to recover the
+ * device for any expression for their own use, e.g. during memory planning. All downstream
+ * passes must preserve the lexical scoping of the "on_device" CallNodes. In particular
+ * conversion to ANF must respect the lexical scoping convention:
+ * \code
+ * f(on_device(g(h(a, b), c), device_type=CPU))
+ * ==>
+ * let %x0 = on_device(h(a, b), device_type=CPU)
+ * let %x1 = on_device(g(%x0), device_type=CPU)
+ * f(on_device(%x1, device_type=CPU))
+ * \endcode
+ *
+ * This pass should be run before FuseOps so that it can use device-specific fusion rules.
+ *
+ * 'Stored on' vs 'Executes on'
+ * ----------------------------
+ * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the
+ * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for
+ * primitives.
+ *
+ * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are
+ * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific
+ * devices, thus the notion of 'executes on' is moot. AOT backends on the other hand need to
+ * know exactly which device (possibly one of a number of available 'CPU'-like devices) is
+ * responsible for execution. Currently that's handled independently by the \p AnnotateTargets
+ * pass, but we'd like to fold that into device planning here to ensure everything is consistent.
+ *
+ * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay
+ * expression (eg an if expression) on one device even though the tensor data resides on
+ * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on'
+ * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just
+ * compile the function body for the function's result device.
+ *
+ * This works after conversion to ANF provided the compilation for a let expression is prepared
+ * to make a cross-device call. However we leave it to a downstream transformation to
+ * heuristically minimize cross-device calls by moving device copies out of functions. E.g.:
+ * \code
+ *   def @f() {  // execute on CPU
+ *     let x = on_device(...GPU computation..., device_type=GPU);
+ *     device_copy(x, src_dev_type=GPU, dst_dev_type=CPU)
+ *   }
+ *   def @main() {
+ *     ... call @f() on CPU ...
+ *   }
+ * \endcode
+ * could be rewritten to:
+ * \code
+ *   def @f() {  // execute on GPU
+ *     let x = ...GPU computation...;
+ *     x
+ *   }
+ *   def @main() {
+ *     let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU)
+ *     ... use x on CPU ...
+ *   }
+ * \endcode
+ *
+ * Higher-order shenanigans
+ * ------------------------
+ * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions
+ * as arguments (even anonymous functions), return functions, evaluate conditional expressions
+ * over functions, and so on. We handle this during constraint solving using the domain:
+ * \code
+ *   D  ::= <specific device type>   -- first-order
+ *        | fn(D,...,D):D            -- higher-order
+ * \endcode
+ * In this way we can determine the device for all function parameters and results. E.g. for
+ * \code
+ *   let f = fn(x, y) { ... }
+ *   let g = fn(f, z) { f(z, z) }
+ *   g(f, on_device(..., device_type=CPU))
+ * \endcode
+ * the parameters \p x and \p y will be on the CPU.
+ *
+ * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 must evaluate to a
+ * function. Our analysis must guarantee that the function's parameter and result devices are
+ * consistent for \p e2, \p e3, and the context of the call. But:
+ *  - Which device holds the closure result of evaluating \p e1 ?
+ *  - If \p e2 is of function type, what does that mean when we say every function parameter
+ *    is on a device?
+ *  - If \p e1 returns a function, what does that mean when we say every function result is
+ *    on a device?
+ *
+ * Since higher-order aspects are later compiled away (by 'defunctionalization'
+ * aka 'firstification') we'd prefer not to have to answer any of those questions. In particular,
+ * we really don't want our domain \p D to allow for yet another device for the function closure.
+ * So we'll just force the 'device for a function' to be the same as the device for the
+ * function's result, using the notion of the 'result domain' for a domain:
+ * \code
+ *   result_domain(<specific device type>) = <specific device type>
+ *   result_domain(fn(D1,...,Dn):Dr)       = result_domain(Dr)
+ * \endcode
+ *
+ * Similarly the domain does not have entries for tuples, references, or ADTs. Whenever the
+ * analysis encounters a function inside one of those it simply forces all argument and result
+ * devices for the function to match the device for the first-order expression. For example,
+ * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the inner function's
+ * parameters and result must similarly be on the GPU.
+ *
+ * -------
+ * | AOR |  This pass supports all of Relay.
+ * -------
+ *    ^
+ *    |
+ *    `-- Mark's stamp of completeness :-)
+ *
+ * TODO(mbs):
+ *  * Though on_device is the identity for all types we can't wrap it around
+ *    functions/constructors taking type args (or at least not without changing type_infer.cc
+ *    to see through them). This is not currently handled generally.
+ *  * Proper diagnostics for unification failure using spans.
+ *  * Make sure the pass is idempotent even after FuseOps etc.
+ *  * Support application of constructors properly. Are they device polymorphic?
+ *  * Replace DLDeviceType with TargetDevice, and unify 'target annotation' with
+ *    'device planning'.
+ *  * Support running the pass post FuseOps (so need to understand primitive functions, both
+ *    outlined and inlined) and post the VM transforms (probably need to support more intrinsic
+ *    forms?).
+ *  * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default
+ *    device for primitives vs the default device for the rest of Relay.
+ *  * We'll probably need some support for partial 'device polymorphism' for functions once we
+ *    incorporate targets and memory scopes into the domain. For example it's ok for the
+ *    function body to be executed on different device ids provided they have the same target
+ *    and memory scope.
+ *  * Might be simpler to just let every type have a device annotation rather than work in
+ *    a separate domain?
+ *  * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies.
+ *  * The original device_annotation.cc RewriteAnnotatedOps removed all "on_device" calls
+ *    in tuples at the top level of function bodies or main expression, irrespective of the
+ *    "on_device" body. What's up with that?
+ */
+
+#include "./device_planner.h"
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/pattern_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/object.h>
+
+#include <unordered_map>
+
+#include "../op/annotation/annotation.h"
+#include "../op/memory/device_copy.h"
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+namespace {
+
+/*!
+ * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather
+ * than the original "device_copy" operator.
+ *
+ * See te_compiler.cc for where this rewriting occurs.
+ */
+DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) {
+  auto tir_call_attrs = call_node->attrs.as<TIRCallAttrs>();
+  if (tir_call_attrs == nullptr) {
+    return {};
+  }
+  if (tir_call_attrs->metadata.count("source_device") != 1 ||
+      tir_call_attrs->metadata.count("dst_device") != 1) {
+    return {};
+  }
+  ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1";
+  return {
+      call_node->args[0],
+      static_cast<DLDeviceType>(
+          Downcast<Integer>(tir_call_attrs->metadata["source_device"])->value),
+      static_cast<DLDeviceType>(Downcast<Integer>(tir_call_attrs->metadata["dst_device"])->value)};
+}
+
+class DeviceDomain;
+using DeviceDomainPtr = std::shared_ptr<DeviceDomain>;
+
+/******
+****** Domains
+******/
+
+/*!
+ * \brief Represents the domain over which we collect equality constraints.
+ *
+ * \code
+ *   D ::= ?x?                  -- first order, free
+ *       | <device_type>        -- first order, bound
+ *       | fn(D1, ..., Dn):Dr   -- higher order
+ * \endcode
+ *
+ * We require a function value to be on the same device as its result. To support that we need
+ * a notion of the 'result domain' of a domain:
+ * \code
+ *   result_domain(?x?)                = ?x?
+ *   result_domain(<device_type>)      = <device_type>
+ *   result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr)
+ * \endcode
+ */
+class DeviceDomain {
+ public:
+  /*!
+   * \brief Constructs a first-order domain of \p device_type, which may be
+   * \p kInvalidDeviceType to indicate the domain is free.
+   */
+  explicit DeviceDomain(DLDeviceType device_type) : device_type_(device_type) {}
+
+  /*!
+   * \brief Constructs a higher-order domain, where \p args_and_result contain the
+   * function argument and result domains in order.
+   */
+  explicit DeviceDomain(std::vector<DeviceDomainPtr> args_and_result)
+      : device_type_(kInvalidDeviceType), args_and_result_(std::move(args_and_result)) {}
+
+  /*! \brief Returns true if domain is first-order and free. */
+  bool is_free() const { return device_type_ == kInvalidDeviceType && args_and_result_.empty(); }
+
+  /*! \brief Returns true if domain is higher-order. */
+  bool is_higher_order() const { return !args_and_result_.empty(); }
+
+  DLDeviceType first_order_device_type() const {
+    ICHECK(args_and_result_.empty());
+    return device_type_;
+  }
+
+  size_t function_arity() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.size() - 1UL;
+  }
+
+  DeviceDomainPtr function_param(size_t i) const {
+    ICHECK(!args_and_result_.empty());
+    ICHECK_LT(i + 1, args_and_result_.size());
+    return args_and_result_[i];
+  }
+
+  DeviceDomainPtr function_result() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.back();
+  }
+
+ private:
+  /*!
+   * \brief If this is a function domain then always kInvalidDevice. Otherwise will be
+   * kInvalidDevice if the domain is still free, or the specific concrete device if the domain
+   * is bound.
+   */
+  const DLDeviceType device_type_;
+
+  /*!
+   * \brief If this is a function domain then the sub-domains for each of the function's
+   * arguments, and the domain for its result. Otherwise empty.
+   */
+  const std::vector<DeviceDomainPtr> args_and_result_;
+
+  friend struct DeviceDomainHash;
+  friend struct DeviceDomainEqual;
+  friend class DeviceDomains;
+};
+
+// Ye olde boost hash mixer.
+constexpr size_t mix(size_t h1, size_t h2) {
+  return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+}
+
+// The following hash and equality helpers give each free first-order domain pointer its own
+// distinct identity.
+struct DeviceDomainHash {
+  size_t operator()(const DeviceDomainPtr& domain) const {
+    if (domain->is_free()) {
+      // Give each free first-order domain its own identity.
+      return static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get()));
+    } else {
+      size_t h = domain->args_and_result_.size();
+      h = mix(h, std::hash<int>()(static_cast<int>(domain->device_type_)));
+      for (const auto& sub_domain_ptr : domain->args_and_result_) {
+        h = mix(h, DeviceDomainHash()(sub_domain_ptr));
+      }
+      return h;
+    }
+  }
+};
+
+struct DeviceDomainEqual {
+ public:
+  bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const {
+    if (lhs->args_and_result_.size() != rhs->args_and_result_.size()) {
+      // Mismatched arities are never equal.
+      // (Though we'll never ask to do such a comparison explicitly, the hash map
+      // may do so implicitly due to hash collisions.)
+      return false;
+    }
+    if (lhs->is_free() && rhs->is_free()) {
+      // Compare first-order free domains by their address.
+      return lhs.get() == rhs.get();
+    }
+    if (lhs->args_and_result_.empty()) {
+      // Compare first-order domains by their device type -- free vs bound will compare as false.
+      return lhs->device_type_ == rhs->device_type_;
+    } else {
+      // Compare higher-order domains pointwise.
+      for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
+        if (!(*this)(lhs->args_and_result_[i], rhs->args_and_result_[i])) {
+          return false;
+        }
+      }
+      return true;
+    }
+  }
+};
+
+/*!
+ * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation
+ * built up by calls to \p Unify.
+ */
+class DeviceDomains {
+ public:
+  DeviceDomains() = default;
+
+  /*!
+   * \brief Returns a domain appropriate for \p type whose result domain is bound
+   * to \p device_type. If \p device_type is \p kInvalidDeviceType then the entire domain
+   * will be free.
+   */
+  static DeviceDomainPtr MakeDomain(const Type& type, DLDeviceType device_type) {
+    if (const auto* func_type_node = type.as<FuncTypeNode>()) {
+      std::vector<DeviceDomainPtr> args_and_result;
+      args_and_result.reserve(func_type_node->arg_types.size() + 1);
+      for (const auto& arg_type : func_type_node->arg_types) {
+        args_and_result.emplace_back(MakeDomain(arg_type, kInvalidDeviceType));
+      }
+      args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, device_type));
+      return std::make_shared<DeviceDomain>(std::move(args_and_result));
+    } else {
+      return std::make_shared<DeviceDomain>(device_type);
+    }
+  }
+
+  /*!
+   * \brief Returns a higher-order domain with \p args_and_results.
+   */
+  static DeviceDomainPtr MakeDomain(std::vector<DeviceDomainPtr> arg_and_results) {
+    return std::make_shared<DeviceDomain>(std::move(arg_and_results));
+  }
+
+  /*! \brief Returns a domain for \p type whose result domain is bound to \p device_type. */
+  static DeviceDomainPtr ForDeviceType(const Type& type, DLDeviceType device_type) {
+    ICHECK_NE(device_type, kInvalidDeviceType);
+    return MakeDomain(type, device_type);
+  }
+
+  /*! \brief Returns a free domain appropriate for \p type. */
+  static DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, kInvalidDeviceType); }
+
+  /*! \brief Returns the domain representing the equivalence class containing \p domain. */
+  DeviceDomainPtr Lookup(DeviceDomainPtr domain) {
+    DeviceDomainPtr root = domain;
+    while (true) {
+      auto itr = domain_to_equiv_.find(root);
+      if (itr == domain_to_equiv_.end()) {
+        break;
+      }
+      ICHECK_NE(itr->second, root);
+      root = itr->second;
+      ICHECK_NOTNULL(root);
+    }
+    // Path compression.
+    while (domain != root) {
+      auto itr = domain_to_equiv_.find(domain);
+      ICHECK(itr != domain_to_equiv_.end());
+      domain = itr->second;
+      ICHECK_NOTNULL(domain);
+      itr->second = root;
+    }
+    return root;
+  }
+
+  /*!
+   * \brief Returns the domain accounting for all bound devices in \p lhs and \p rhs.
+   *
+   * Throws \p Error on failure.
+   */
+  DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+    // TODO(mbs): Proper diagnostics.
+    ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size())
+        << "Device domains:" << std::endl
+        << ToString(lhs) << std::endl
+        << "and" << std::endl
+        << ToString(rhs) << std::endl
+        << "do not have the same kind and can't be unified.";
+    if (rhs->is_free()) {
+      return lhs;
+    } else if (lhs->is_free()) {
+      return rhs;
+    } else if (lhs->args_and_result_.empty()) {
+      // Must have consistent device types for first order domains.
+      if (lhs->device_type_ != rhs->device_type_) {
+        // TODO(mbs): Proper diagnostics.
+        std::ostringstream os;
+        os << "Inconsistent device types " << lhs->device_type_ << " and " << rhs->device_type_;
+        throw Error(os.str());
+      }
+      return lhs;
+    } else {
+      // Recurse for higher-order.
+      std::vector<DeviceDomainPtr> args_and_result;
+      args_and_result.reserve(lhs->args_and_result_.size());
+      for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
+        args_and_result.emplace_back(Unify(lhs->args_and_result_[i], rhs->args_and_result_[i]));
+      }
+      return MakeDomain(std::move(args_and_result));
+    }
+  }
+
+  /*!
+   * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Fails if \p lhs and
+   * \p rhs disagree on bound device type.
+   *
+   * Throws \p Error on failure.
+   */
+  // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but
+  // given we have refs to functions I'm prepared to be surprised.
+  DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) {
+    lhs = Lookup(lhs);
+    rhs = Lookup(rhs);
+    auto joined_domain = Join(lhs, rhs);
+    if (!DeviceDomainEqual()(lhs, joined_domain)) {
+      domain_to_equiv_.emplace(lhs, joined_domain);
+    }
+    if (!DeviceDomainEqual()(rhs, joined_domain)) {
+      domain_to_equiv_.emplace(rhs, joined_domain);
+    }
+    return joined_domain;
+  }
+
+  /*!
+   * \brief Unifies \p lhs and \p rhs. If \p lhs is first-order and \p rhs is higher-order,
+   * require all arguments and result of \p rhs to unify with \p lhs. Otherwise same as
+   * \p Unify.
+   *
+   * Throws \p Error on failure.
+   */
+  void UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+    if (!lhs->is_higher_order() && rhs->is_higher_order()) {
+      Collapse(lhs, rhs);
+    } else {
+      Unify(lhs, rhs);
+    }
+  }
+
+  /*! \brief Returns true if a domain is known for \p expr. */
+  bool contains(const Expr& expr) const { return expr_to_domain_.count(expr.get()); }
+
+  /*! \brief Returns the domain representing \p expr. */
+  DeviceDomainPtr DomainFor(const Expr& expr) {
+    ICHECK(expr.defined());
+    auto itr = expr_to_domain_.find(expr.get());
+    if (itr != expr_to_domain_.end()) {
+      return Lookup(itr->second);
+    }
+    auto domain = Free(expr->checked_type());
+    expr_to_domain_.emplace(expr.get(), domain);
+    return domain;
+  }
+
+  /*!
+   * \brief Returns the domain representing the callee (i.e. 'op') in \p call expression. If the
+   * callee is a primitive or special operation we handle it specially. Otherwise defers to \p
+   * DomainFor(call->op).
+   *
+   * This special handling is needed:

Review comment:
       Agree - this at least gets things into a state where we could do that, or allow a table
to be passed in, or something. Most of this is a large refactor from the
VisitExpr_(CallNode*) in context_analysis.cc, but without needing to repeat all the args
handling. I was actually happy with that cleanup and want to lock that in :-)



