Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/09/24 18:20:02 UTC

[GitHub] [tvm] areusch commented on a change in pull request #9038: [Relay] Merge analysis/context_analysis.cc and transforms/device_annotation.cc

areusch commented on a change in pull request #9038:
URL: https://github.com/apache/tvm/pull/9038#discussion_r715801087



##########
File path: include/tvm/relay/attrs/annotation.h
##########
@@ -32,14 +32,55 @@ namespace tvm {
 namespace relay {
 
 /*!
- * \brief Options for the device annotation operators.
+ * \brief Attributes for the "on_device" operator.
+ *
+ * The relay call
+ * \code
+ *   on_device(expr, device_type=2)
+ * \endcode
+ * denotes that the result of \p expr should be stored on the device with \p DLDeviceType 2
+ * (i.e. \p kDLCUDA). Semantically the operator is the identity function.
  */
 struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
+  // TODO(mbs): Replace device types with TargetDevice.
+  /*! \brief Device type on which argument expression should be evaluated. */
   int device_type;
+  /*!
+   * \brief If true, the result device must also be \p device_type and device planning should

Review comment:
       not sure i quite understand--isn't this what the struct-level brief above says it does? what happens if is_fixed == false? if we are going to insert a device copy node, that implies the data must be needed elsewhere. how can we merely attach an attribute to a result saying a copy is simply out of the question?

##########
File path: python/tvm/relay/op/annotation/annotation.py
##########
@@ -33,21 +43,26 @@ def on_device(data, device):
     device : Union[:py:class:`Device`, str]
         The device type to annotate.
 
+    is_fixed : bool
+        If true, annotation does not imply a device_copy may be inserted.
+        (This parameter is used internally by the compiler and unit tests and
+        should not need to be set in user programs.)
+
     Returns
     -------
     result : tvm.relay.Expr
         The annotated expression.
     """
-    if isinstance(device, _Device):
-        device = device.device_type
-    elif isinstance(device, str):
-        device = _nd.device(device).device_type
-    else:
-        raise ValueError(
-            "device is expected to be the type of Device or "
-            "str, but received %s" % (type(device))
-        )
-    return _make.on_device(data, device)
+    return _make.on_device(data, _device_to_int(device), is_fixed)
+
+
+# for testing only
+def function_on_device(function, param_devices, result_device):

Review comment:
       should we call this testonly_function_on_device then or place in `python/tvm/relay/op/annotation/test_utils.py`?
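
For reference, the annotation under discussion is used from Python roughly as follows (a
minimal sketch; the import path and the device-string handling follow the diff above, while the
variable names and shapes are purely illustrative):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(3, 4), dtype="float32")
    y = relay.var("y", shape=(3, 4), dtype="float32")
    # Ask for the sum to be stored on the GPU; the string "cuda" resolves to
    # device type 2 (kDLCUDA), matching on_device(expr, device_type=2) in the C++ docs.
    add_on_gpu = relay.annotation.on_device(relay.add(x, y), "cuda")
    # is_fixed is left at its default here, since the docstring above says user
    # programs should not need to set it.
    func = relay.Function([x, y], add_on_gpu)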

##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -0,0 +1,1986 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/device_planner.cc
+ * \brief Determines a unique device to hold the result of every Relay sub-expression.
+ *
+ * We say a Relay expression E is 'on device D' if the result of executing E is stored on D.
+ * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the
+ * specific target associated with D (this is recovered independently via a TargetMap), and we
+ * do not track the storage scope within D (this is yet to be implemented).
+ *
+ * Note that 'stored on device D' is almost but not quite the same as 'executes on device D',
+ * see below.
+ *
+ * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes:
+ *  - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and
+ *    'dst_dev_type' device type, which constrain the argument and context of the call
+ *     respectively. It is ok if the source and destination devices are the same; such no-op copies
+ *     will be removed after accounting for the device preference.
+ *  - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which
+ *    constrains the argument of the call, but (usually, see below) leaves the context
+ *    unconstrained. These are called 'annotations' in the rest of the code; they have no operational
+ *    significance by themselves, but may trigger the insertion of a new "device_copy".
+ *  - In two situations the result of an "on_device" CallNode may also be constrained to the
+ *    given device:
+ *     - The "on_device" call occurs at the top-level of a function body, or occurs as an
+ *       immediately let-bound expression. In this situation the extra degree of freedom in
+ *       the function result and let-binding leads to surprising device copies, so we simply
+ *       force the function result or let-bound variable to the given device.
+ *     - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted
+ *       it ourselves during an earlier invocation of this pass. This helps make this pass
+ *       idempotent.
+ *
+ * We proceed in four phases:
+ *
+ * Phase 0
+ * -------
+ * We rewrite the programs to handle some special cases:
+ *  - "on_device" calls at the top level of a function or immediately let-bound are rewritten
+ *    to have \code is_fixed=true \endcode.
+ *  - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written

Review comment:
       what's the difference here? are you saying we assume all results are tuples of length 1?
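
For reference, the two tuple-projection forms the file comment is contrasting can be written
with the Python API roughly as follows (illustrative only; the tuple, shapes, and device are
made up for the example):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2,), dtype="float32")
    y = relay.var("y", shape=(2,), dtype="float32")
    tup = relay.Tuple([x, y])

    # As the user wrote it: project field 0 out of an annotated tuple ...
    before = relay.TupleGetItem(relay.annotation.on_device(tup, "cuda"), 0)
    # ... which Phase 0 rewrites so only the projected field carries the annotation,
    # i.e. only that field (not the whole tuple) would ever need to be copied.
    after = relay.annotation.on_device(relay.TupleGetItem(tup, 0), "cuda")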

##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -0,0 +1,1986 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/device_planner.cc
+ * \brief Determines a unique device to hold the result of every Relay sub-expression.
+ *
+ * We say a Relay expression E is 'on device D' if the result of executing E is stored on D.
+ * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the
+ * specific target associated with D (this is recovered independently via a TargetMap), and we
+ * do not track the storage scope within D (this is yet to be implemented).
+ *
+ * Note that 'stored on device D' is almost but not quite the same as 'executes on device D',
+ * see below.
+ *
+ * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes:
+ *  - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and
+ *    'dst_dev_type' device type, which constrain the argument and context of the call
+ *     respectively. It is ok if the source and destination devices are the same; such no-op copies
+ *     will be removed after accounting for the device preference.
+ *  - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which
+ *    constrains the argument of the call, but (usually, see below) leaves the context
+ *    unconstrained. These are called 'annotations' in the rest of the code; they have no operational
+ *    significance by themselves, but may trigger the insertion of a new "device_copy".
+ *  - In two situations the result of an "on_device" CallNode may also be constrained to the
+ *    given device:
+ *     - The "on_device" call occurs at the top-level of a function body, or occurs as an
+ *       immediately let-bound expression. In this situation the extra degree of freedom in
+ *       the function result and let-binding leads to surprising device copies, so we simply
+ *       force the function result or let-bound variable to the given device.
+ *     - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted
+ *       it ourselves during an earlier invocation of this pass. This helps make this pass
+ *       idempotent.
+ *
+ * We proceed in four phases:
+ *
+ * Phase 0
+ * -------
+ * We rewrite the programs to handle some special cases:
+ *  - "on_device" calls at the top level of a function or immediately let-bound are rewritten
+ *    to have \code is_fixed=true \endcode.
+ *  - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written
+ *    \code on_device(expr.0, device_type=d) \endcode. I.e. we prefer to copy the projection from
+ *    the tuple rather than project from a copy of the tuple. We'll do this by rewriting.
+ *
+ * Phase 1
+ * -------
+ * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see
+ * below) to all other Relay sub-expressions. (For idempotence we also respect any existing
+ * "on_device" function attributes we introduce below.)
+ *
+ * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the
+ * same device. However each call site can use a different device. In other words primitives are
+ * 'device polymorphic' since we compile and execute them for each required device.
+ *
+ * For most Relay expressions the device for the overall expression is the same as the device
+ * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple
+ * itself, the condition and arms of an \p if must all be on the same device as the overall if,
+ * and so on.
+ *
+ * Some special ops (or 'dialects') are handled:
+ *  - Relay supports computing the shape of tensors and operators at runtime using "shape_of",
+ *    "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors
+ *    they describe may reside on any device.
+ *  - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor". Again
+ *    shapes reside on the CPU, but the allocated tensors may reside on any device.
+ *
+ * Two Relay expressions have special handling:
+ *  - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the
+ *    overall let. However the result of \p e1 may be on a different device.
+ *  - For a function \code fn(x, y) { body } \endcode the result of the function must be on the
+ * same device as \p body. However parameters \p x and \p y may be on different devices, even
+ *    different from each other. Every call to the function must use the same choice of parameter
+ *    and result devices -- there is no 'device polymorphism' for Relay functions.
+ *
+ * Phase 2
+ * -------
+ * After flowing constraints we apply some defaulting heuristics (using a global default device)
+ * to fix the device for any as-yet unconstrained sub-expressions.
+ *  - Unconstrained function result devices default to the global default device.
+ *  - Unconstrained function parameter devices default to the device for the function result.
+ *  - Unconstrained let-bound expression devices default to the device for the overall let.
+ * TODO(mbs): I may have over-innovated here and we simply want to bind all free domains to
+ * the global default device. Worth a design doc with motivating examples I think.
+ *
+ * Phase 3
+ * -------
+ * Finally, the result of this analysis is reified into the result as:
+ *  - Additional "on_device" attributes (an Attrs resolving to a \p FunctionOnDeviceAttrs) for
+ *    every function (both top-level and local). These describe the devices for the function's
+ *    parameters and the result.
+ *  - Additional "device_copy" CallNodes where a copy is required in order to respect the
+ *    intent of the original "on_device" CallNodes.
+ *  - Additional "on_device" CallNodes where the device type of an expression does not match
+ *    that of the lexically enclosing "on_device" CallNode or function attribute. In practice
+ *    this means "on_device" CallNodes may appear in two places:
+ *     - On a let-bound expression if its device differs from the overall let expression.
+ *     - On a call argument if its device differs from the call result. In particular, the
+ *       argument to a "device_copy" call will always be wrapped in an "on_device". (That may
+ *       seem pedantic but simplifies downstream handling.)
+ *    However since we make it easy to track devices for variables we never wrap an "on_device"
+ *    around a var or global var. These uses of "on_device" imply both the argument and result are
+ *    on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs field to true,
+ *    which helps make this pass idempotent.
+ *
+ * A helper \p LexicalOnDeviceMixin class can be used by downstream transforms to recover the device
+ * for any expression for their own use, e.g. during memory planning. All downstream passes must
+ * preserve the lexical scoping of the "on_device" CallNodes. In particular conversion to ANF
+ * must respect the lexical scoping convention:
+ * \code
+ * f(on_device(g(h(a, b), c), device_type=CPU))
+ * ==>
+ * let %x0 = on_device(h(a, b), device_type=CPU)
+ * let %x1 = on_device(g(%x0), device_type=CPU)
+ * f(on_device(%x1, device_type=CPU))
+ * \endcode
+ *
+ * This pass should be run before FuseOps so that it can use device-specific fusion rules.
+ *
+ * 'Stored on' vs 'Executes on'
+ * ----------------------------
+ * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the
+ * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for
+ * primitives.
+ *
+ * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are
+ * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific
+ * devices, thus the notion of 'executes on' is moot. AOT backends on the other hand need to
+ * know exactly which device (possibly one of a number of available 'CPU'-like devices) is
+ * responsible for execution. Currently that's handled independently by the \p AnnotateTargets
+ * pass, but we'd like to fold that into device planning here to ensure everything is consistent.
+ *
+ * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay
+ * expression (e.g. an if expression) on one device even though the tensor data resides on
+ * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on'
+ * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just
+ * compile the function body for the function's result device.
+ *
+ * This works after conversion to ANF provided the compilation for a let expression is prepared
+ * to make a cross-device call. However we leave it to a downstream transformation to heuristically
+ * minimize cross-device calls by moving device copies out of functions. E.g.:
+ * \code
+ *   def @f() {  // execute on CPU
+ *     let x = on_device(...GPU computation..., device_type=GPU);
+ *     device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU)
+ *   }
+ *   def @main() {
+ *     ... call @f() on CPU ...
+ *   }
+ * \endcode
+ * could be rewritten to:
+ * \code
+ *   def @f() {  // execute on GPU
+ *     let x = ...GPU computation...;
+ *     ...GPU computation...
+ *   }
+ *   def @main() {
+ *     let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU)
+ *     ... use x on CPU ...
+ *   }
+ * \endcode
+ *
+ * Higher-order shenanigans
+ * ------------------------
+ * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions
+ * as arguments (even anonymous functions), return functions, evaluate conditional expressions
+ * over functions, and so on. We handle this during constraint solving using the domain:
+ * \code
+ *   D  ::= <specific device type>   -- first-order
+ *        | fn(D,...,D):D            -- higher-order
+ * \endcode
+ * In this way we can determine the device for all function parameters and results. E.g. for
+ * \code
+ *   let f = fn(x, y) { ... }
+ *   let g = fn(f, z) { f(z, z) }
+ *   g(f, on_device(..., device_type=CPU))
+ * \endcode
+ * the parameters \p x and \p y will be on the CPU.
+ *
+ * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 must evaluate to a
+ * function. Our analysis must guarantee that the function's parameters and result devices are
+ * consistent for \p e2, \p e3, and the context of the call. But:
+ *  - Which device holds the closure result of evaluating \p e1 ?
+ *  - If \p e2 is of function type, what does that mean when we say every function parameter
+ *    is on a device?
+ *  - If \p e1 returns a function, what does that mean when we say every function result is
+ *    on a device?
+ *
+ * Since higher-order aspects are later compiled away (by 'defunctionalization'
+ * aka 'firstification') we'd prefer not to have to answer any of those questions. In particular,
+ * we really don't want our domain \p D to allow for yet another device for the function closure.
+ * So we'll just force the 'device for a function' to be the same as the device for the function's
+ * result using the notion of the 'result domain' for a domain:
+ * \code
+ *   result_domain(<specific device type>) = <specific device type>
+ *   result_domain(fn(D1,...,Dn):Dr)       = result_domain(Dr)
+ * \endcode
+ *
+ * Similarly the domain does not have entries for tuples, references, or ADTs. Whenever the
+ * analysis encounters a function inside one of those it simply forces all argument and result
+ * devices for the function to match the device for the first-order expression. For example,
+ * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the inner function
+ * parameters and result must similarly be on the GPU.
+ *
+ * -------
+ * | AOR |  This pass supports all of Relay.
+ * -------
+ *    ^
+ *    |
+ *    `-- Mark's stamp of completeness :-)
+ *
+ * TODO(mbs):
+ *  * Though on_device is the identity for all types we can't wrap it around functions/constructors
+ *    taking type args (or at least not without changing type_infer.cc to see through them).
+ *    This is not currently handled generally.
+ *  * Proper diagnostics for unification failure using spans.
+ *  * Make sure the pass is idempotent even after FuseOps etc.
+ *  * Support application of constructors properly. Are they device polymorphic?
+ *  * Replace DLDeviceType with TargetDevice, and unify 'target annotation' with 'device planning'.
+ *  * Support running the pass post FuseOps (so need to understand primitive functions, both
+ *    outlined and inlined) and post the VM transforms (probably need to support more intrinsic
+ *    forms?).
+ *  * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default
+ *    device for primitives vs the default device for the rest of Relay.
+ *  * We'll probably need some support for partial 'device polymorphism' for functions once we
+ *    incorporate targets and memory scopes into the domain. For example it's ok for the function
+ *    body to be executed on different device ids provided they have the same target and memory
+ *    scope.
+ *  * Might be simpler to just let every type have a device annotation rather than work in
+ *    a separate domain?
+ *  * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies.
+ *  * The original device_annotation.cc RewriteAnnotatedOps removed all "on_device" calls
+ *    in tuples at the top level of function bodies or main expression, irrespective of the
+ *    "on_device" body. What's up with that?
+ */
+
+#include "./device_planner.h"
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/pattern_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/object.h>
+
+#include <unordered_map>
+
+#include "../op/annotation/annotation.h"
+#include "../op/memory/device_copy.h"
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+namespace {
+
+/*!
+ * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather
+ * than the original "device_copy" operator.
+ *
+ * See te_compiler.cc for where this rewriting occurs.
+ */
+DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) {
+  auto tir_call_attrs = call_node->attrs.as<TIRCallAttrs>();
+  if (tir_call_attrs == nullptr) {
+    return {};
+  }
+  if (tir_call_attrs->metadata.count("source_device") != 1 ||
+      tir_call_attrs->metadata.count("dst_device") != 1) {
+    return {};
+  }
+  ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1";
+  return {
+      call_node->args[0],
+      static_cast<DLDeviceType>(
+          Downcast<Integer>(tir_call_attrs->metadata["source_device"])->value),
+      static_cast<DLDeviceType>(Downcast<Integer>(tir_call_attrs->metadata["dst_device"])->value)};
+}
+
+class DeviceDomain;
+using DeviceDomainPtr = std::shared_ptr<DeviceDomain>;
+
+/******
+****** Domains
+******/
+
+/*!
+ * \brief Represents the domain over which we collect equality constraints.
+ *
+ * \code
+ *   D ::= ?x?                  -- first order, free
+ *       | <device_type>        -- first order, bound
+ *       | fn(D1, ..., Dn):Dr   -- higher order
+ * \endcode
+ *
+ * We require a function value to be on the same device as its result. To support that we need
+ * a notion of the 'result domain' of a domain:
+ * \code
+ *   result_domain(?x?)                = ?x?
+ *   result_domain(<device_type>)      = <device_type>
+ *   result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr)
+ * \endcode
+ */
+class DeviceDomain {
+ public:
+  /*!
+   * \brief Constructs a first-order domain of \p device_type, which may be
+   * \p kInvalidDeviceType to indicate the domain is free.
+   */
+  explicit DeviceDomain(DLDeviceType device_type) : device_type_(device_type) {}
+
+  /*!
+   * \brief Constructs a higher-order domain, where \p args_and_result contain the
+   * function argument and result domains in order.
+   */
+  explicit DeviceDomain(std::vector<DeviceDomainPtr> args_and_result)
+      : device_type_(kInvalidDeviceType), args_and_result_(std::move(args_and_result)) {}
+
+  /*! \brief Returns true if domain is first-order and free. */
+  bool is_free() const { return device_type_ == kInvalidDeviceType && args_and_result_.empty(); }
+
+  /*! \brief Returns true if domain is higher-order. */
+  bool is_higher_order() const { return !args_and_result_.empty(); }
+
+  DLDeviceType first_order_device_type() const {
+    ICHECK(args_and_result_.empty());
+    return device_type_;
+  }
+
+  size_t function_arity() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.size() - 1UL;
+  }
+
+  DeviceDomainPtr function_param(size_t i) const {
+    ICHECK(!args_and_result_.empty());
+    ICHECK_LT(i + 1, args_and_result_.size());
+    return args_and_result_[i];
+  }
+
+  DeviceDomainPtr function_result() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.back();
+  }
+
+ private:
+  /*!
+   * \brief If this is a function domain then always kInvalidDeviceType. Otherwise will be
+   * kInvalidDeviceType if the domain is still free, or the specific concrete device if the domain is
+   * bound.
+   */
+  const DLDeviceType device_type_;
+
+  /*!
+   * \brief If this is a function domain then the sub-domains for each of the function's
+   * arguments, and the domain for its result. Otherwise empty.
+   */
+  const std::vector<DeviceDomainPtr> args_and_result_;
+
+  friend struct DeviceDomainHash;
+  friend struct DeviceDomainEqual;
+  friend class DeviceDomains;
+};
+
+// Ye olde boost hash mixer.
+constexpr size_t mix(size_t h1, size_t h2) {
+  return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+}
+
+// The following hash and equality helpers give each free first-order domain pointer its own
+// distinct identity.
+struct DeviceDomainHash {
+  size_t operator()(const DeviceDomainPtr& domain) const {
+    if (domain->is_free()) {
+      // Give each free first-order domain its own identity.
+      return static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get()));
+    } else {
+      size_t h = domain->args_and_result_.size();
+      h = mix(h, std::hash<int>()(static_cast<int>(domain->device_type_)));
+      for (const auto& sub_domain_ptr : domain->args_and_result_) {
+        h = mix(h, DeviceDomainHash()(sub_domain_ptr));
+      }
+      return h;
+    }
+  }
+};
+
+struct DeviceDomainEqual {
+ public:
+  bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const {
+    if (lhs->args_and_result_.size() != rhs->args_and_result_.size()) {
+      // Mismatched arities are never equal.
+      // (Though we'll never ask to do such a comparison explicitly, the hash map
+      // may do so implicitly due to hash collisions.)
+      return false;
+    }
+    if (lhs->is_free() && rhs->is_free()) {
+      // Compare first-order free domains by their address.
+      return lhs.get() == rhs.get();
+    }
+    if (lhs->args_and_result_.empty()) {
+      // Compare first-order domains by their device type -- free vs bound will compare as false.
+      return lhs->device_type_ == rhs->device_type_;
+    } else {
+      // Compare higher-order domains pointwise.
+      for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
+        if (!(*this)(lhs->args_and_result_[i], rhs->args_and_result_[i])) {
+          return false;
+        }
+      }
+      return true;
+    }
+  }
+};
+
+/*!
+ * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation
+ * built up by calls to \p Unify.
+ */
+class DeviceDomains {
+ public:
+  DeviceDomains() = default;
+
+  /*!
+   * \brief Returns a domain appropriate for \p type whose result domain is bound
+   * to \p device_type. If \p device_type is \p kInvalidDeviceType then the entire domain
+   * will be free.
+   */
+  static DeviceDomainPtr MakeDomain(const Type& type, DLDeviceType device_type) {
+    if (const auto* func_type_node = type.as<FuncTypeNode>()) {
+      std::vector<DeviceDomainPtr> args_and_result;
+      args_and_result.reserve(func_type_node->arg_types.size() + 1);
+      for (const auto& arg_type : func_type_node->arg_types) {
+        args_and_result.emplace_back(MakeDomain(arg_type, kInvalidDeviceType));
+      }
+      args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, device_type));
+      return std::make_shared<DeviceDomain>(std::move(args_and_result));
+    } else {
+      return std::make_shared<DeviceDomain>(device_type);
+    }
+  }
+
+  /*!
+   * \brief Returns a higher-order domain with \p args_and_results.
+   */
+  static DeviceDomainPtr MakeDomain(std::vector<DeviceDomainPtr> arg_and_results) {
+    return std::make_shared<DeviceDomain>(std::move(arg_and_results));
+  }
+
+  /*! \brief Returns a domain appropriate for \p type whose result device type is \p device_type. */
+  static DeviceDomainPtr ForDeviceType(const Type& type, DLDeviceType device_type) {
+    ICHECK_NE(device_type, kInvalidDeviceType);
+    return MakeDomain(type, device_type);
+  }
+
+  /*! \brief Returns a free domain appropriate for \p type. */
+  static DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, kInvalidDeviceType); }
+
+  /*! \brief Returns the domain representing the equivalence class containing \p domain. */
+  DeviceDomainPtr Lookup(DeviceDomainPtr domain) {
+    DeviceDomainPtr root = domain;
+    while (true) {
+      auto itr = domain_to_equiv_.find(root);
+      if (itr == domain_to_equiv_.end()) {
+        break;
+      }
+      ICHECK_NE(itr->second, root);
+      root = itr->second;
+      ICHECK_NOTNULL(root);
+    }
+    // Path compression.
+    while (domain != root) {
+      auto itr = domain_to_equiv_.find(domain);
+      ICHECK(itr != domain_to_equiv_.end());
+      domain = itr->second;
+      ICHECK_NOTNULL(domain);
+      itr->second = root;
+    }
+    return root;
+  }
+
+  /*!
+   * \brief Returns the domain accounting for all bound devices in \p lhs and \p rhs.
+   *
+   * Throws \p Error on failure.
+   */
+  DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+    // TODO(mbs): Proper diagnostics.
+    ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size())
+        << "Device domains:" << std::endl
+        << ToString(lhs) << std::endl
+        << "and" << std::endl
+        << ToString(rhs) << std::endl
+        << "do not have the same kind and can't be unified.";
+    if (rhs->is_free()) {
+      return lhs;
+    } else if (lhs->is_free()) {
+      return rhs;
+    } else if (lhs->args_and_result_.empty()) {
+      // Must have consistent device types for first order domains.
+      if (lhs->device_type_ != rhs->device_type_) {
+        // TODO(mbs): Proper diagnostics.
+        std::ostringstream os;
+        os << "Inconsistent device types " << lhs->device_type_ << " and " << rhs->device_type_;

Review comment:
       shall we print more details somehow, like the sub-expression, or use Jared's expression annotating infra? or maybe this does and i'm just dense :)

##########
File path: tests/python/relay/test_pass_plan_devices.py
##########
@@ -0,0 +1,1405 @@
+# Licensed to the Apache Software Foundation (ASF) under one

Review comment:
       +1 for relay text format, i think that would greatly help
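
A test in that style might look roughly like the sketch below (illustrative only; the
pass-invocation and comparison lines are commented out because the exact names exposed by this
PR are not shown in the quoted diff):

    import tvm

    def test_plan_devices_textual():
        # Input program written in Relay text format; the expected module would be written the
        # same way, with the on_device/device_copy calls the pass is expected to produce.
        input_mod = tvm.parser.fromtext(
            """
            #[version = "0.0.5"]
            def @main(%a: Tensor[(5,), float32], %b: Tensor[(5,), float32]) {
                add(%a, %b)
            }
            """
        )
        # Hypothetical invocation and check (names assumed, not taken from the diff):
        # actual_mod = tvm.relay.transform.PlanDevices(default_device)(input_mod)
        # tvm.ir.assert_structural_equal(actual_mod, expected_mod, map_free_vars=True)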

##########
File path: src/relay/op/annotation/annotation.h
##########
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/op/annotation/annotation.h
+ * \brief Helpers for working with various 'annotation' attributes.
+ */
+#ifndef TVM_RELAY_OP_ANNOTATION_ANNOTATION_H_
+#define TVM_RELAY_OP_ANNOTATION_ANNOTATION_H_
+
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+#include <tvm/runtime/ndarray.h>
+
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Returns the "on_device" operator. */
+const Op& OnDeviceOp();
+
+/*!
+ * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed.
+ */
+Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed);
+
+/*!
+ * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed. However
+ * returns \p expr directly if:
+ *  - \p device_type is \p kInvalidDeviceType, which signals there are no device annotations
+ *    already in play.
+ *  - \p expr is an operator or primitive function literal. These are device polymorphic.
+ *  - \p expr is a global or local var. These already have an implied device.
+ *  - \p expr is a constructor. These should probably be device polymorphic but are in an
+ *    in-between state at the moment.
+ */
+Expr OptOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed);
+
+/*! \brief Result of \p GetOnDeviceProps. */
+struct OnDeviceProps {
+  Expr body;  // = null
+  DLDeviceType device_type = kInvalidDeviceType;
+  bool is_fixed = false;
+
+  OnDeviceProps() = default;
+
+  OnDeviceProps(const Expr& body, DLDeviceType deviceType, bool isFixed)
+      : body(body), device_type(deviceType), is_fixed(isFixed) {}
+};
+
+/*!
+ * \brief Returns the body expression, device type and is_fixed field for \p call_node if it is
+ * an "on_device" CallNode. Otherwise returns the null expression, \p kInvalidDeviceType and \p
+ * false.
+ */
+OnDeviceProps GetOnDeviceProps(const CallNode* call_node);
+
+/*!
+ * \brief Returns the body expression, device type and is_fixed field for \p expr if it is an
+ * "on_device" CallNode. Otherwise returns the null expression, \p kInvalidDeviceType and \p false.
+ */
+OnDeviceProps GetOnDeviceProps(const Expr& expr);
+
+/*! \brief Returns true if \p expr is an on_device CallNode. */

Review comment:
       this will CHECK-fail if there is an ill-formed on_device CallNode tho, not sure if that should get documented?

##########
File path: src/relay/op/annotation/annotation.cc
##########
@@ -56,12 +83,98 @@ RELAY_REGISTER_OP("on_device")
     .set_attr<TOpPattern>("TOpPattern", kOpaque)
     .set_attr<TOpIsStateful>("TOpIsStateful", false)
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+    .set_attr<TNonComputational>("TNonComputational", true)
     .set_attr<FTVMCompute>("FTVMCompute",
                            [](const Attrs& attrs, const Array<te::Tensor>& inputs,
                               const Type& out_type) -> Array<te::Tensor> {
                              return {topi::identity(inputs[0])};
                            });
 
+OnDeviceProps GetOnDeviceProps(const CallNode* call_node) {
+  if (call_node->op == OnDeviceOp()) {
+    ICHECK_EQ(call_node->args.size(), 1) << "on_device expects one argument";

Review comment:
       want to print the args or op node in this case to help debug? same question below

##########
File path: src/relay/op/memory/device_copy.cc
##########
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/op/memory/device_copy.cc
+ * \brief Helpers for working with "device_copy" attributes.
+ */
+
+#include "./device_copy.h"
+
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/topi/elemwise.h>
+
+#include "../../transforms/infer_layout_utils.h"
+#include "../type_relations.h"
+
+namespace tvm {
+namespace relay {
+
+// relay.device_copy
+TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs);
+
+const Op& DeviceCopyOp() {
+  static const Op& op = Op::Get("device_copy");
+  return op;
+}
+
+Expr DeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type) {
+  auto attrs = make_object<DeviceCopyAttrs>();
+  attrs->src_dev_type = src_dev_type;
+  attrs->dst_dev_type = dst_dev_type;
+  Span span = expr->span;
+  return Call(DeviceCopyOp(), {std::move(expr)}, Attrs(attrs), /*type_args=*/{}, span);
+}
+
+Expr OptDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type) {
+  if (src_dev_type == dst_dev_type) {
+    return expr;
+  }
+  ICHECK_NE(src_dev_type, kInvalidDeviceType);
+  ICHECK_NE(dst_dev_type, kInvalidDeviceType);
+  return DeviceCopy(expr, src_dev_type, dst_dev_type);
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.device_copy")
+    .set_body_typed([](Expr expr, int src_dev_type, int dst_dev_type) {
+      return DeviceCopy(expr, static_cast<DLDeviceType>(src_dev_type),
+                        static_cast<DLDeviceType>(dst_dev_type));
+    });
+
+RELAY_REGISTER_OP("device_copy")
+    .describe(R"code(
+Copy data from one tensor to another. The source and destination might be

Review comment:
       interesting...is this internal use only i guess til TargetDevice lands? should we annotate docs as such?

##########
File path: src/relay/op/annotation/annotation.cc
##########
@@ -36,15 +38,40 @@
 namespace tvm {
 namespace relay {
 
-// relay.annotation.on_device
 TVM_REGISTER_NODE_TYPE(OnDeviceAttrs);
 
+const Op& OnDeviceOp() {
+  static const Op& op = Op::Get("on_device");
+  return op;
+}
+
+Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) {
+  auto attrs = make_object<OnDeviceAttrs>();
+  attrs->device_type = device_type;
+  attrs->is_fixed = is_fixed;
+  Span span = expr->span;
+  return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, span);
+}
+
+Expr OptOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) {

Review comment:
       opt=Optional?

##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -0,0 +1,1986 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/device_planner.cc
+ * \brief Determines a unique device to hold the result of every Relay sub-expression.
+ *
+ * We say a Relay expression E is 'on device D' if the result of executing E is stored on D.
+ * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the
+ * specific target associated with D (this is recovered independently via a TargetMap), and we
+ * do not track the storage scope within D (this is yet to be implemented).
+ *
+ * Note that 'stored on device D' is almost but not quite the same as 'executes on device D',
+ * see below.
+ *
+ * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes:
+ *  - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and
+ *    'dst_dev_type' device type, which constrain the argument and context of the call
+ *     respectively. It is ok if the source and destination devices are the same; such no-op copies
+ *     will be removed after accounting for the device preference.
+ *  - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which
+ *    constrains the argument of the call, but (usually, see below) leaves the context
+ *    unconstrained. These are called 'annotations' in the rest of the code; they have no operational
+ *    significance by themselves, but may trigger the insertion of a new "device_copy".
+ *  - In two situations the result of an "on_device" CallNode may also be constrained to the
+ *    given device:
+ *     - The "on_device" call occurs at the top-level of a function body, or occurs as an
+ *       immediately let-bound expression. In this situation the extra degree of freedom in
+ *       the function result and let-binding leads to surprising device copies, so we simply
+ *       force the function result or let-bound variable to the given device.
+ *     - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted
+ *       it ourselves during an earlier invocation of this pass. This helps make this pass
+ *       idempotent.
+ *
+ * We proceed in four phases:
+ *
+ * Phase 0
+ * -------
+ * We rewrite the programs to handle some special cases:
+ *  - "on_device" calls at the top level of a function or immediately let-bound are rewritten
+ *    to have \code is_fixed=true \endcode.
+ *  - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written
+ *    \code on_device(expr.0, device_type=d) \endcode. I.e. we prefer to copy the projection from
+ *    the tuple rather than project from a copy of the tuple. We'll do this by rewriting.
+ *
+ * Phase 1
+ * -------
+ * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see
+ * below) to all other Relay sub-expressions. (For idempotence we also respect any existing
+ * "on_device" function attributes we introduce below.)
+ *
+ * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the
+ * same device. However each call site can use a different device. In other words primitives are
+ * 'device polymorphic' since we compile and execute them for each required device.
+ *
+ * For most Relay expressions the device for the overall expression is the same as the device
+ * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple
+ * itself, the condition and arms of an \p if must all be on the same device as the overall if,
+ * and so on.
+ *
+ * Some special ops (or 'dialects') are handled:
+ *  - Relay supports computing the shape of tensors and operators at runtime using "shape_of",
+ *    "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors
+ *    they describe may reside on any device.
+ *  - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor". Again
+ *    shapes reside on the CPU, but the allocated tensors may reside on any device.
+ *
+ * Two Relay expressions have special handling:
+ *  - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the
+ *    overall let. However the result of \p e1 may be on a different device.
+ *  - For a function \code fn(x, y) { body } \endcode the result of the function must be on the
+ * same device as \p body. However parameters \p x and \p y may be on different devices, even
+ *    different from each other. Every call to the function must use the same choice of parameter
+ *    and result devices -- there is no 'device polymorphism' for Relay functions.
+ *
+ * Phase 2
+ * -------
+ * After flowing constraints we apply some defaulting heuristics (using a global default device)
+ * to fix the device for any as-yet unconstrained sub-expressions.
+ *  - Unconstrained function result devices default to the global default device.
+ *  - Unconstrained function parameter devices default to the device for the function result.
+ *  - Unconstrained let-bound expression devices default to the device for the overall let.
+ * TODO(mbs): I may have over-innovated here and we simply want to bind all free domains to
+ * the global default device. Worth a design doc with motivating examples I think.
+ *
+ * Phase 3
+ * -------
+ * Finally, the result of this analysis is reified into the result as:
+ *  - Additional "on_device" attributes (an Attrs resolving to a \p FunctionOnDeviceAttrs) for
+ *    every function (both top-level and local). These describe the devices for the function's
+ *    parameters and the result.
+ *  - Additional "device_copy" CallNodes where a copy is required in order to respect the
+ *    intent of the original "on_device" CallNodes.
+ *  - Additional "on_device" CallNodes where the device type of an expression does not match
+ *    that of the lexically enclosing "on_device" CallNode or function attribute. In practice
+ *    this means "on_device" CallNodes may appear in two places:
+ *     - On a let-bound expression if its device differs from the overall let expression.
+ *     - On a call argument if its device differs from the call result. In particular, the
+ *       argument to a "device_copy" call will always be wrapped in an "on_device". (That may
+ *       seem pedantic but simplifies downstream handling.)
+ *    However since we make it easy to track devices for variables we never wrap an "on_device"
+ *    around a var or global var. These uses of "on_device" imply both the argument and result are
+ *    on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs field to true,
+ *    which helps make this pass idempotent.
+ *
+ * A helper \p LexicalOnDeviceMixin class can be used by downstream transforms to recover the device
+ * for any expression for their own use, e.g. during memory planning. All downstream passes must
+ * preserve the lexical scoping of the "on_device" CallNodes. In particular conversion to ANF
+ * must respect the lexical scoping convention:
+ * \code
+ * f(on_device(g(h(a, b), c), device_type=CPU))
+ * ==>
+ * let %x0 = on_device(h(a, b), device_type=CPU)
+ * let %x1 = on_device(g(%x0), device_type=CPU)
+ * f(on_device(%x1, device_type=CPU))
+ * \endcode
+ *
+ * This pass should be run before FuseOps so that it can use device-specific fusion rules.
+ *
+ * 'Stored on' vs 'Executes on'
+ * ----------------------------
+ * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the
+ * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for
+ * primitives.
+ *
+ * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are
+ * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific
+ * devices, thus the notion of 'executes on' is moot. AOT backends on the other hand need to
+ * know exactly which device (possibly one of a number of available 'CPU'-like devices) is
+ * responsible for execution. Currently that's handled independently by the \p AnnotateTargets
+ * pass, but we'd like to fold that into device planning here to ensure everything is consistent.
+ *
+ * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay
+ * expression (e.g. an if expression) on one device even though the tensor data resides on
+ * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on'
+ * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just
+ * compile the function body for the function's result device.
+ *
+ * This works after conversion to ANF provided the compilation for a let expression is prepared
+ * to make a cross-device call. However we leave it to a downstream transformation to heuristically
+ * minimize cross-device calls by moving device copies out of functions. E.g.:
+ * \code
+ *   def @f() {  // execute on CPU
+ *     let x = on_device(...GPU computation..., device_type=GPU);
+ *     device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU)
+ *   }
+ *   def @main() {
+ *     ... call @f() on CPU ...
+ *   }
+ * \endcode
+ * could be rewritten to:
+ * \code
+ *   def @f() {  // execute on GPU
+ *     let x = ...GPU computation...;
+ *     ...GPU computation...
+ *   }
+ *   def @main() {
+ *     let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU)
+ *     ... use x on CPU ...
+ *   }
+ * \endcode
+ *
+ * Higher-order shenanigans
+ * ------------------------
+ * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions
+ * as arguments (even anonymous functions), return functions, evaluate conditional expressions
+ * over functions, and so on. We handle this during constraint solving using the domain:
+ * \code
+ *   D  ::= <specific device type>   -- first-order
+ *        | fn(D,...,D):D            -- higher-order
+ * \endcode
+ * In this way we can determine the device for all function parameters and results. E.g. for
+ * \code
+ *   let f = fn(x, y) { ... }
+ *   let g = fn(f, z) { f(z, z) }
+ *   g(f, on_device(..., device_type=CPU))
+ * \endcode
+ * the parameters \p x and \p y will be on the CPU.
+ *
+ * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 must evaluate to a
+ * function. Our analysis must guarantee that the function's parameters and result devices are
+ * consistent for \p e2, \p e3, and the context of the call. But:
+ *  - Which device holds the closure result of evaluating \p e1 ?
+ *  - If \p e2 is of function type, what does that mean when we say every function parameter
+ *    is on a device?
+ *  - If \p e1 returns a function, what does that mean when we say every function result is
+ *    on a device?
+ *
+ * Since higher-order aspects are later compiled away (by 'defunctionalization'
+ * aka 'firstification') we'd prefer not to have to answer any of those questions. In particular,
+ * we really don't want our domain \p D to allow for yet another device for the function closure.
+ * So we'll just force the 'device for a function' to be the same as the device for the function's
+ * result using the notion of the 'result domain' for a domain:
+ * \code
+ *   result_domain(<specific device type>) = <specific device type>
+ *   result_domain(fn(D1,...,Dn):Dr)       = result_domain(Dr)
+ * \endcode
+ *
+ * Similarly the domain does not have entries for tuples, references, or ADTs. Whenever the
+ * analysis encounters a function inside one of those it simply forces all argument and result
+ * devices for the function to match the device for the first-order expression. For example,
+ * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the inner function
+ * parameters and result must similarly be on the GPU.
+ *
+ * -------
+ * | AOR |  This pass supports all of Relay.
+ * -------
+ *    ^
+ *    |
+ *    `-- Mark's stamp of completeness :-)

Review comment:
       😂  will this become a new lint check?

##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -0,0 +1,1986 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/device_planner.cc
+ * \brief Determines a unique device to hold the result of every Relay sub-expression.
+ *
+ * We say a Relay expression E is 'on device D' if the result of executing E is stored on D.
+ * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the
+ * specific target associated with D (this is recovered independently via a TargetMap), and we
+ * do not track the storage scope within D (this is yet to be implemented).
+ *
+ * Note that 'stored on device D' is almost but not quite the same as 'executes on device D',
+ * see below.
+ *
+ * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes:
+ *  - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and
+ *    'dst_dev_type' device type, which constrain the argument and context of the call
+ *     respectively. It is ok if the source and destination devices are the same; such no-op copies
+ *     will be removed after accounting for the device preference.
+ *  - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which
+ *    constrains the argument of the call, but (usually, see below) leaves the context
+ *    unconstrained. These are called 'annotations' in the rest of the code; they have no operational
+ *    significance by themselves, but may trigger the insertion of a new "device_copy".
+ *  - In two situations the result of an "on_device" CallNode may also be constrained to the
+ *    given device:
+ *     - The "on_device" call occurs at the top-level of a function body, or occurs as an
+ *       immediately let-bound expression. In this situation the extra degree of freedom in
+ *       the function result and let-binding leads to surprising device copies, so we simply
+ *       force the function result or let-bound variable to the given device.
+ *     - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted
+ *       it ourselves during an earlier invocation of this pass. This helps make this pass
+ *       idempotent.
+ *
+ * We proceed in four phases:
+ *
+ * Phase 0
+ * -------
+ * We rewrite the programs to handle some special cases:
+ *  - "on_device" calls at the top level of a function or immediately let-bound are rewritten
+ *    to have \code is_fixed=true \endcode.
+ *  - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written
+ *    \code on_device(expr.0, device_type=d) \endcode. I.e. we prefer to copy the projection from
+ *    the tuple rather than project from a copy of the tuple. We'll do this by rewriting.
+ *
+ * Phase 1
+ * -------
+ * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see
+ * below) to all other Relay sub-expressions. (For idempotence we also respect any existing
+ * "on_device" function attributes we introduce below.)
+ *
+ * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the
+ * same device. However each call site can use a different device. In other words primitives are
+ * 'device polymorphic' since we compile and execute them for each required device.
+ *
+ * For most Relay expressions the device for the overall expression is the same as the device
+ * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple
+ * itself, the condition and arms of an \p if must all be on the same device as the overall if,
+ * and so on.
+ *
+ * Some special ops (or 'dialects') are handled:
+ *  - Relay supports computing the shape of tensors and operators at runtime using "shape_of",
+ *    "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors
+ *    they describe may reside on any device.
+ *  - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor". Again
+ *    shapes reside on the CPU, but the allocated tensors may reside on any device.
+ *
+ * Two Relay expressions have special handling:
+ *  - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the
+ *    overall let. However the result of \p e1 may be on a different device.
+ *  - For a function \code fn(x, y) { body } \endcode the result of the function must be on the
+ * same device as \p body. However parameters \p x and \p y may be on different devices, even
+ *    different from each other. Every call to the function must use the same choice of parameter
+ *    and result devices -- there is no 'device polymorphism' for Relay functions.
+ *
+ * Phase 2
+ * -------
+ * After flowing constraints we apply some defaulting heuristics (using a global default device)
+ * to fix the device for any as-yet unconstrained sub-expressions.
+ *  - Unconstrained function result devices default to the global default device.
+ *  - Unconstrained function parameter devices default to the device for the function result.
+ *  - Unconstrained let-bound expression devices default to the device for the overall let.
+ * TODO(mbs): I may have over-innovated here and we simply want to bind all free domains to
+ * the global default device. Worth a design doc with motivating examples I think.
+ *
+ * Phase 3
+ * -------
+ * Finally, the result of this analysis is reified back into the program as:
+ *  - Additional "on_device" attributes (an Attrs resolving to a \p FunctionOnDeviceAttrs) for
+ *    every function (both top-level and local). These describe the devices for the function's
+ *    parameters and the result.
+ *  - Additional "device_copy" CallNodes where a copy is required in order to respect the
+ *    intent of the original "on_device" CallNodes.
+ *  - Additional "on_device" CallNodes where the device type of an expression does not match
+ *    that of the lexically enclosing "on_device" CallNode or function attribute. In practice
+ *    this means "on_device" CallNodes may appear in two places:
+ *     - On a let-bound expression if its device differs from the overall let expression.
+ *     - On a call argument if its device differs from the call result. In particular, the
+ *       argument to a "device_copy" call will always be wrapped in an "on_device". (That may
+ *       seem pedantic but simplifies downstream handling.)
+ *    However since we make it easy to track devices for variables we never wrap an "on_device"
+ *    around a var or global var. These uses of "on_device" imply both the argument and result are
+ *    on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs field to true,
+ *    which helps make this pass idempotent.
+ *
+ * A helper \p LexicalOnDeviceMixin class can be used by downstream transforms to recover the device
+ * for any expression for their own use, e.g. during memory planning. All downstream passes must
+ * preserve the lexical scoping of the "on_device" CallNodes. In particular conversion to ANF
+ * must respect the lexical scoping convention:
+ * \code
+ * f(on_device(g(h(a, b), c), device_type=CPU))
+ * ==>
+ * let %x0 = on_device(h(a, b), device_type=CPU)
+ * let %x1 = on_device(g(%x0), device_type=CPU)
+ * f(on_device(%x1, device_type=CPU))
+ * \endcode
+ *
+ * This pass should be run before FuseOps so that FuseOps can use device-specific fusion rules.
+ *
+ * 'Stored on' vs 'Executes on'
+ * ----------------------------
+ * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the
+ * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for
+ * primitives.
+ *
+ * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are
+ * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific
+ * devices, thus the notion of 'executes on' is moot. AOT backends on the other hand need to
+ * know exactly which device (possibly one of a number of available 'CPU'-like devices) is
+ * responsible for execution. Currently that's handled independently by the \p AnnotateTargets
+ * pass, but we'd like to fold that into device planning here to ensure everything is consistent.
+ *
+ * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay
+ * expression (eg an if expression) on one device even though the tensor data resides on
+ * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on'
+ * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just
+ * compile the function body for the function's result device.
+ *
+ * This works after conversion to ANF provided the compilation for a let expression is prepared
+ * to make a cross-device call. However we leave it to a downstream transformation to heuristically
+ * minimize cross-device calls by moving device copies out of functions. E.g.:
+ * \code
+ *   def @f() {  // execute on CPU
+ *     let x = on_device(...GPU computation..., device_type=GPU);
+ *     device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU)
+ *   }
+ *   def @main() {
+ *     ... call @f() on CPU ...
+ *   }
+ * \endcode
+ * could be rewritten to:
+ * \code
+ *   def @f() {  // execute on GPU
+ *     let x = ...GPU computation...;
+ *     ...GPU computation...
+ *   }
+ *   def @main() {
+ *     let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU)
+ *     ... use x on CPU ...
+ *   }
+ * \endcode
+ *
+ * Higher-order shenanigans
+ * ------------------------
+ * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions
+ * as arguments (even anonymous functions), return functions, evaluate conditional expressions
+ * over functions, and so on. We handle this during constraint solving using the domain:
+ * \code
+ *   D  ::= <specific device type>   -- first-order
+ *        | fn(D,...,D):D            -- higher-order
+ * \endcode
+ * In this way we can determine the device for all function parameters and results. E.g. for
+ * \code
+ *   let f = fn(x, y) { ... }
+ *   let g = fn(f, z) { f(z, z) }
+ *   g(f, on_device(..., device_type=CPU))
+ * \endcode
+ * the parameters \p x and \p y will be on the CPU.
+ *
+ * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 must evaluate to a
+ * function. Our analysis must guarantee that the function's parameters and result devices are
+ * consistent for \p e2, \p e3, and the context of the call. But:
+ *  - Which device holds the closure result of evaluating \p e1 ?
+ *  - If \p e2 is of function type, what does that mean when we say every function parameter
+ *    is on a device?
+ *  - If \p e1 returns a function, what does that mean when we say every function result is
+ *    on a device?
+ *
+ * Since higher-order aspects are later compiled away (by 'defunctionalization'
+ * aka 'firstification') we'd prefer not to have to answer any of those questions. In particular,
+ * we really don't want our domain \p D to allow for yet another device for the function closure.
+ * So we'll just force the 'device for a function' to be the same as the device for the function's

Review comment:
       yeah this seems like a generally good design practice. people (or passes) can explicitly add copy nodes if that's what they want/need.
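
       to make that concrete, here's a minimal sketch (mine, not from the PR) of a user pinning placement
       explicitly with the python helper this PR adds; the planner then decides where a "device_copy" is actually
       needed. assumes relay.annotation.on_device accepts a device string as in the diff above:

           from tvm import relay

           x = relay.var("x", shape=(10,))
           y = relay.var("y", shape=(10,))

           # pin the intermediate result to the GPU; per the docs above this constrains the
           # argument of on_device but leaves the surrounding context unconstrained
           gpu_add = relay.annotation.on_device(relay.add(x, y), "cuda")

           # the consumer stays unconstrained, so its device comes from the defaulting
           # heuristics (e.g. the global default device); if that differs from the GPU the
           # planner inserts an explicit device_copy rather than silently moving either side
           result = relay.multiply(gpu_add, relay.const(2.0))
           func = relay.Function([x, y], result)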

##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -0,0 +1,1986 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/analysis/device_planner.cc
+ * \brief Determines a unique device to hold the result of every Relay sub-expression.
+ *
+ * We say a Relay expression E is 'on device D' if the result of executing E is stored on D.
+ * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the
+ * specific target associated with D (this is recovered independently via a TargetMap), and we
+ * do not track the storage scope within D (this is yet to be implemented).
+ *
+ * Note that 'stored on device D' is almost but not quite the same as 'executes on device D',
+ * see below.
+ *
+ * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes:
+ *  - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and
+ *    'dst_dev_type' device type, which constrain the argument and context of the call
+ *     respectively. It is ok if source and destination devices are the same; such no-op copies
+ *     will be removed after accounting for the device preference.
+ *  - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which
+ *    constrains the argument of the call, but (usually, see below) leaves the context
+ *    unconstrained. These are called 'annotations' in the rest of the code, have no operational
+ *    significance by themselves, but may trigger the insertion of a new "device_copy".
+ *  - In two situations the result of an "on_device" CallNode may also be constrained to the
+ *    given device:
+ *     - The "on_device" call occurs at the top-level of a function body, or occurs as an
+ *       immediately let-bound expression. In this situation the extra degree of freedom in
+ *       the function result and let-binding leads to surprising device copies, so we simply
+ *       force the function result or let-bound variable to the given device.
+ *     - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted
+ *       it ourselves during an earlier invocation of this pass. This helps make this pass
+ *       idempotent.
+ *
+ * We proceed in four phases:
+ *
+ * Phase 0
+ * -------
+ * We rewrite the programs to handle some special cases:
+ *  - "on_device" calls at the top-level of a function body or immediately let-bound are rewritten
+ *    to have \code is_fixed=true \endcode.
+ *  - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written
+ *    \code on_device(expr.0, device_type=d) \endcode. I.e. we prefer to copy the projection from
+ *    the tuple rather than project from a copy of the tuple. We'll do this by rewriting.
+ *
+ * Phase 1
+ * -------
+ * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see
+ * below) to all other Relay sub-expressions. (For idempotence we also respect any existing
+ * "on_device" function attributes we introduce below.)
+ *
+ * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the
+ * same device. However each call site can use a different device. In other words primitives are
+ * 'device polymorphic' since we compile and execute them for each required device.
+ *
+ * For most Relay expressions the device for the overall expression is the same as the device
+ * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple
+ * itself, the condition and arms of an \p if must all be on the same device as the overall if,
+ * and so on.
+ *
+ * Some special ops (or 'dialects') are handled:
+ *  - Relay supports computing the shape of tensors and operators at runtime using "shape_of",
+ *    "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors
+ *    they describe may reside on any device.
+ *  - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor" operators. Again
+ *    shapes reside on the CPU, but the allocated tensors may reside on any device.
+ *
+ * Two Relay expressions have special handling:
+ *  - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the
+ *    overall let. However the result of \p e1 may be on a different device.
+ *  - For a function \code fn(x, y) { body } \endcode the result of the function must be on the
+ *    same device as \p body. However parameters \p x and \p y may be on different devices, even
+ *    different from each other. Every call to the function must use the same choice of parameter
+ *    and result devices -- there is no 'device polymorphism' for Relay functions.
+ *
+ * Phase 2
+ * -------
+ * After flowing constraints we apply some defaulting heuristics (using a global default device)
+ * to fix the device for any as-yet unconstrained sub-expressions.
+ *  - Unconstrained function result devices default to the global default device.
+ *  - Unconstrained function parameter devices default to the device for the function result.
+ *  - Unconstrained let-bound expression devices default to the device for the overall let.
+ * TODO(mbs): I may have over-innovated here and we simply want to bind all free domains to

Review comment:
       actually my feeling is that there may (later) be some desire for user control over this. like a fuzzy device constraint or optimization target. e.g. "Optimize for fewer copies at the expense of performance" because we might not quite get the memory bus modelling right for quite some time yet (or ever, depending on which SoC is being worked with).
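
       sketching how that control might eventually be surfaced (hypothetical -- the commented-out key below is
       invented; "relay.fallback_device_type" is, as far as i know, the existing PassContext knob for the global
       default device):

           import tvm

           with tvm.transform.PassContext(
               opt_level=3,
               config={
                   # existing knob: device type used as the global default when defaulting
                   "relay.fallback_device_type": 2,  # kDLCUDA
                   # a fuzzy preference could become another option, e.g.:
                   # "relay.device_planning.minimize_copies": True,  # hypothetical
               },
           ):
               pass  # build / run device planning under this context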

##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -0,0 +1,1986 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/analysis/device_planner.cc
+ * \brief Determines a unique device to hold the result of every Relay sub-expression.
+ *
+ * We say a Relay expression E is 'on device D' if the result of executing E is stored on D.
+ * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the
+ * specific target associated with D (this is recovered independently via a TargetMap), and we
+ * do not track the storage scope within D (this is yet to be implemented).
+ *
+ * Note that 'stored on device D' is almost but not quite the same as 'executes on device D',
+ * see below.
+ *
+ * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes:
+ *  - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and
+ *    'dst_dev_type' device type, which constrain the argument and context of the call
+ *     respectively. It is ok if source and destination devices are the same; such no-op copies
+ *     will be removed after accounting for the device preference.
+ *  - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which
+ *    constrains the argument of the call, but (usually, see below) leaves the context
+ *    unconstrained. These are called 'annotations' in the rest of the code, have no operational
+ *    significance by themselves, but may trigger the insertion of a new "device_copy".
+ *  - In two situations the result of an "on_device" CallNode may also be constrained to the
+ *    given device:
+ *     - The "on_device" call occurs at the top-level of a function body, or occurs as an
+ *       immediately let-bound expression. In this situation the extra degree of freedom in
+ *       the function result and let-binding leads to surprising device copies, so we simply
+ *       force the function result or let-bound variable to the given device.
+ *     - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted
+ *       it ourselves during an earlier invocation of this pass. This helps make this pass
+ *       idempotent.
+ *
+ * We proceed in four phases:
+ *
+ * Phase 0
+ * -------
+ * We rewrite the programs to handle some special cases:
+ *  - "on_device" calls at the top-level of a function body or immediately let-bound are rewritten
+ *    to have \code is_fixed=true \endcode.
+ *  - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written
+ *    \code on_device(expr.0, device_type=d) \endcode. I.e. we prefer to copy the projection from
+ *    the tuple rather than project from a copy of the tuple. We'll do this by rewriting.
+ *
+ * Phase 1
+ * -------
+ * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see
+ * below) to all other Relay sub-expressions. (For idempotence we also respect any existing
+ * "on_device" function attributes we introduce below.)
+ *
+ * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the
+ * same device. However each call site can use a different device. In other words primitives are
+ * 'device polymorphic' since we compile and execute them for each required device.
+ *
+ * For most Relay expressions the device for the overall expression is the same as the device
+ * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple
+ * itself, the condition and arms of an \p if must all be on the same device as the overall if,
+ * and so on.
+ *
+ * Some special ops (or 'dialects') are handled:
+ *  - Relay supports computing the shape of tensors and operators at runtime using "shape_of",
+ *    "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors
+ *    they describe may reside on any device.
+ *  - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor" operators. Again
+ *    shapes reside on the CPU, but the allocated tensors may reside on any device.
+ *
+ * Two Relay expressions have special handling:
+ *  - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the
+ *    overall let. However the result of \p e1 may be on a different device.
+ *  - For a function \code fn(x, y) { body } \endcode the result of the function must be on the
+ *    same device as \p body. However parameters \p x and \p y may be on different devices, even
+ *    different from each other. Every call to the function must use the same choice of parameter
+ *    and result devices -- there is no 'device polymorphism' for Relay functions.
+ *
+ * Phase 2
+ * -------
+ * After flowing constraints we apply some defaulting heuristics (using a global default device)
+ * to fix the device for any as-yet unconstrained sub-expressions.
+ *  - Unconstrained function result devices default to the global default device.

Review comment:
       should we instead default them to the most convenient device given the args?
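
       a small example (mine, for illustration) of where the two defaults diverge, using the python helper from
       this PR; the "cuda"/CPU split is an assumption:

           from tvm import relay

           x = relay.var("x", shape=(4,))
           # exp(x) is pinned to the GPU, which also puts x on the GPU since primitives keep
           # their arguments and result on one device
           gpu_val = relay.annotation.on_device(relay.exp(x), "cuda")
           # nothing constrains log's device or the function result: under the current
           # heuristic they default to the global default device (say CPU), so a device_copy
           # of exp's output gets inserted; defaulting to "the most convenient device given
           # the args" would pick the GPU here and avoid that copy
           body = relay.log(gpu_val)
           func = relay.Function([x], body)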

##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -0,0 +1,1986 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/analysis/device_planner.cc
+ * \brief Determines a unique device to hold the result of every Relay sub-expression.
+ *
+ * We say a Relay expression E is 'on device D' if the result of executing E is stored on D.
+ * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the
+ * specific target associated with D (this is recovered independently via a TargetMap), and we
+ * do not track the storage scope within D (this is yet to be implemented).
+ *
+ * Note that 'stored on device D' is almost but not quite the same as 'executes on device D',
+ * see below.
+ *
+ * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes:
+ *  - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and
+ *    'dst_dev_type' device type, which constrain the argument and context of the call
+ *     respectively. It is ok if source and destination devices are the same; such no-op copies
+ *     will be removed after accounting for the device preference.
+ *  - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which
+ *    constrains the argument of the call, but (usually, see below) leaves the context
+ *    unconstrained. These are called 'annotations' in the rest of the code, have no operational
+ *    significance by themselves, but may trigger the insertion of a new "device_copy".
+ *  - In two situations the result of an "on_device" CallNode may also be constrained to the
+ *    given device:
+ *     - The "on_device" call occurs at the top-level of a function body, or occurs as an
+ *       immediately let-bound expression. In this situation the extra degree of freedom in
+ *       the function result and let-binding leads to surprising device copies, so we simply
+ *       force the function result or let-bound variable to the given device.
+ *     - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted
+ *       it ourselves during an earlier invocation of this pass. This helps make this pass
+ *       idempotent.
+ *
+ * We proceed in four phases:
+ *
+ * Phase 0
+ * -------
+ * We rewrite the programs to handle some special cases:
+ *  - "on_device" calls at the top-level of a function body or immediately let-bound are rewritten
+ *    to have \code is_fixed=true \endcode.
+ *  - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written
+ *    \code on_device(expr.0, device_type=d) \endcode. I.e. we prefer to copy the projection from
+ *    the tuple rather than project from a copy of the tuple. We'll do this by rewriting.
+ *
+ * Phase 1
+ * -------
+ * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see
+ * below) to all other Relay sub-expressions. (For idempotence we also respect any existing
+ * "on_device" function attributes we introduce below.)
+ *
+ * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the
+ * same device. However each call site can use a different device. In other words primitives are
+ * 'device polymorphic' since we compile and execute them for each required device.
+ *
+ * For most Relay expressions the device for the overall expression is the same as the device
+ * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple
+ * itself, the condition and arms of an \p if must all be on the same device as the overall if,
+ * and so on.
+ *
+ * Some special ops (or 'dialects') are handled:
+ *  - Relay supports computing the shape of tensors and operators at runtime using "shape_of",
+ *    "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors
+ *    they describe may reside on any device.
+ *  - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor" operators. Again
+ *    shapes reside on the CPU, but the allocated tensors may reside on any device.
+ *
+ * Two Relay expressions have special handling:
+ *  - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the
+ *    overall let. However the result of \p e1 may be on a different device.
+ *  - For a function \code fn(x, y) { body } \endcode the result of the function must be on the
+ *    same device as \p body. However parameters \p x and \p y may be on different devices, even
+ *    different from each other. Every call to the function must use the same choice of parameter
+ *    and result devices -- there is no 'device polymorphism' for Relay functions.
+ *
+ * Phase 2
+ * -------
+ * After flowing constraints we apply some defaulting heuristics (using a global default device)
+ * to fix the device for any as-yet unconstrained sub-expressions.
+ *  - Unconstrained function result devices default to the global default device.
+ *  - Unconstrained function parameter devices default to the device for the function result.
+ *  - Unconstrained let-bound expression devices default to the device for the overall let.
+ * TODO(mbs): I may have over-innovated here and we simply want to bind all free domains to
+ * the global default device. Worth a design doc with motivating examples I think.
+ *
+ * Phase 3
+ * -------
+ * Finally, the result of this analysis is reified back into the program as:
+ *  - Additional "on_device" attributes (an Attrs resolving to a \p FunctionOnDeviceAttrs) for
+ *    every function (both top-level and local). These describe the devices for the function's
+ *    parameters and the result.
+ *  - Additional "device_copy" CallNodes where a copy is required in order to respect the
+ *    intent of the original "on_device" CallNodes.
+ *  - Additional "on_device" CallNodes where the device type of an expression does not match
+ *    that of the lexically enclosing "on_device" CallNode or function attribute. In practice
+ *    this means "on_device" CallNodes may appear in two places:
+ *     - On a let-bound expression if its device differs from the overall let expression.
+ *     - On a call argument if its device differs from the call result. In particular, the
+ *       argument to a "device_copy" call will always be wrapped in an "on_device". (That may
+ *       seem pedantic but simplifies downstream handling.)
+ *    However since we make it easy to track devices for variables we never wrap an "on_device"
+ *    around a var or global var. These uses of "on_device" imply both the argument and result are
+ *    on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs field to true,
+ *    which helps make this pass idempotent.
+ *
+ * A helper \p LexicalOnDeviceMixin class can be used by downstream transforms to recover the device
+ * for any expression for their own use, e.g. during memory planning. All downstream passes must
+ * preserve the lexical scoping of the "on_device" CallNodes. In particular conversion to ANF
+ * must respect the lexical scoping convention:
+ * \code
+ * f(on_device(g(h(a, b), c), device_type=CPU))
+ * ==>
+ * let %x0 = on_device(h(a, b), device_type=CPU)
+ * let %x1 = on_device(g(%x0), device_type=CPU)
+ * f(on_device(%x1, device_type=CPU))
+ * \endcode
+ *
+ * This pass should be run before FuseOps so that FuseOps can use device-specific fusion rules.
+ *
+ * 'Stored on' vs 'Executes on'
+ * ----------------------------
+ * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the
+ * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for
+ * primitives.
+ *
+ * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are
+ * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific
+ * devices, thus the notion of 'executes on' is moot. AOT backends on the other hand need to
+ * know exactly which device (possibly one of a number of available 'CPU'-like devices) is
+ * responsible for execution. Currently that's handled independently by the \p AnnotateTargets
+ * pass, but we'd like to fold that into device planning here to ensure everything is consistent.
+ *
+ * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay
+ * expression (eg an if expression) on one device even though the tensor data resides on
+ * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on'
+ * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just
+ * compile the function body for the function's result device.
+ *
+ * This works after conversion to ANF provided the compilation for a let expression is prepared
+ * to make a cross-device call. However we leave it to a downstream transformation to heuristically
+ * minimize cross-device calls by moving device copies out of functions. E.g.:
+ * \code
+ *   def @f() {  // execute on CPU
+ *     let x = on_device(...GPU computation..., device_type=GPU);
+ *     device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU)
+ *   }
+ *   def @main() {
+ *     ... call @f() on CPU ...
+ *   }
+ * \endcode
+ * could be rewritten to:
+ * \code
+ *   def @f() {  // execute on GPU
+ *     let x = ...GPU computation...;
+ *     ...GPU computation...
+ *   }
+ *   def @main() {
+ *     let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU)
+ *     ... use x on CPU ...
+ *   }
+ * \endcode
+ *
+ * Higher-order shenanigans
+ * ------------------------
+ * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions
+ * as arguments (even anonymous functions), return functions, evaluate conditional expressions
+ * over functions, and so on. We handle this during constraint solving using the domain:
+ * \code
+ *   D  ::= <specific device type>   -- first-order
+ *        | fn(D,...,D):D            -- higher-order
+ * \endcode
+ * In this way we can determine the device for all function parameters and results. E.g. for
+ * \code
+ *   let f = fn(x, y) { ... }
+ *   let g = fn(f, z) { f(z, z) }
+ *   g(f, on_device(..., device_type=CPU))
+ * \endcode
+ * the parameters \p x and \p y will be on the CPU.
+ *
+ * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 must evaluate to a
+ * function. Our analysis must guarantee that the function's parameters and result devices are
+ * consistent for \p e2, \p e3, and the context of the call. But:
+ *  - Which device holds the closure result of evaluating \p e1 ?
+ *  - If \p e2 is of function type, what does that mean when we say every function parameter
+ *    is on a device?
+ *  - If \p e1 returns a function, what does that mean when we say every function result is
+ *    on a device?
+ *
+ * Since higher-order aspects are later compiled away (by 'defunctionalization'
+ * aka 'firstification') we'd prefer not to have to answer any of those questions. In particular,
+ * we really don't want our domain \p D to allow for yet another device for the function closure.
+ * So we'll just force the 'device for a function' to be the same as the device for the function's
+ * result using the notion of the 'result domain' for a domain:
+ * \code
+ *   result_domain(<specific device type>) = <specific device type>
+ *   result_domain(fn(D1,...,Dn):Dr)       = result_domain(Dr)
+ * \endcode
+ *
+ * Similarly the domain does not have entries for tuples, references, or ADTs. Whenever the
+ * analysis encounters a function inside one of those it simply forces all argument and result
+ * devices for the function to match the device for the first-order expression. For example,
+ * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the inner function
+ * parameters and result must similarly be on the GPU.
+ *
+ * -------
+ * | AOR |  This pass supports all of Relay.
+ * -------
+ *    ^
+ *    |
+ *    `-- Mark's stamp of completeness :-)
+ *
+ * TODO(mbs):
+ *  * Though on_device is the identity for all types we can't wrap it around functions/constructors
+ *    taking type args (or at least not without changing type_infer.cc to see through them).
+ *    This is not currently handled generally.
+ *  * Proper diagnostics for unification failure using spans.
+ *  * Make sure the pass is idempotent even after FuseOps etc.
+ *  * Support application of constructors properly. Are they device polymorphic?
+ *  * Replace DLDeviceType with TargetDevice, and unify 'target annotation' with 'device planning'.
+ *  * Support running the pass post FuseOps (so need to understand primitive functions, both
+ *    outlined and inlined) and post the VM transforms (probably need to support more intrinsic
+ *    forms?).
+ *  * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default
+ *    device for primitives vs the default device for the rest of Relay.
+ *  * We'll probably need some support for partial 'device polymorphism' for functions once we
+ *    incorporate targets and memory scopes into the domain. For example it's ok for the function
+ *    body to be executed on different device ids provided they have the same target and memory
+ *    scope.
+ *  * Might be simpler to just let every type have a device annotation rather than work in
+ *    a separate domain?
+ *  * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies.
+ *  * The original device_annotation.cc RewriteAnnotatedOps removed all "on_device" calls
+ *    in tuples at the top level of function bodies or main expression, irrespective of the
+ *    "on_device" body. What's up with that?
+ */
+
+#include "./device_planner.h"
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/pattern_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/object.h>
+
+#include <unordered_map>
+
+#include "../op/annotation/annotation.h"
+#include "../op/memory/device_copy.h"
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+namespace {
+
+/*!
+ * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather
+ * than the original "device_copy" operator.
+ *
+ * See te_compiler.cc for where this rewriting occurs.
+ */
+DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) {
+  auto tir_call_attrs = call_node->attrs.as<TIRCallAttrs>();
+  if (tir_call_attrs == nullptr) {
+    return {};
+  }
+  if (tir_call_attrs->metadata.count("source_device") != 1 ||
+      tir_call_attrs->metadata.count("dst_device") != 1) {
+    return {};
+  }
+  ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1";
+  return {
+      call_node->args[0],
+      static_cast<DLDeviceType>(
+          Downcast<Integer>(tir_call_attrs->metadata["source_device"])->value),
+      static_cast<DLDeviceType>(Downcast<Integer>(tir_call_attrs->metadata["dst_device"])->value)};
+}
+
+class DeviceDomain;
+using DeviceDomainPtr = std::shared_ptr<DeviceDomain>;
+
+/******
+****** Domains
+******/
+
+/*!
+ * \brief Represents the domain over which we collect equality constraints.
+ *
+ * \code
+ *   D ::= ?x?                  -- first order, free
+ *       | <device_type>        -- first order, bound
+ *       | fn(D1, ..., Dn):Dr   -- higher order
+ * \endcode
+ *
+ * We require a function value to be on the same device as its result. To support that we need
+ * a notion of the 'result domain' of a domain:
+ * \code
+ *   result_domain(?x?)                = ?x?
+ *   result_domain(<device_type>)      = <device_type>
+ *   result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr)
+ * \endcode
+ */
+class DeviceDomain {
+ public:
+  /*!
+   * \brief Constructs a first-order domain of \p device_type, which may be
+   * \p kInvalidDeviceType to indicate the domain is free.
+   */
+  explicit DeviceDomain(DLDeviceType device_type) : device_type_(device_type) {}
+
+  /*!
+   * \brief Constructs a higher-order domain, where \p args_and_result contain the
+   * function argument and result domains in order.
+   */
+  explicit DeviceDomain(std::vector<DeviceDomainPtr> args_and_result)
+      : device_type_(kInvalidDeviceType), args_and_result_(std::move(args_and_result)) {}
+
+  /*! \brief Returns true if domain is first-order and free. */
+  bool is_free() const { return device_type_ == kInvalidDeviceType && args_and_result_.empty(); }
+
+  /*! \brief Returns true if domain is higher-order. */
+  bool is_higher_order() const { return !args_and_result_.empty(); }
+
+  DLDeviceType first_order_device_type() const {
+    ICHECK(args_and_result_.empty());
+    return device_type_;
+  }
+
+  size_t function_arity() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.size() - 1UL;
+  }
+
+  DeviceDomainPtr function_param(size_t i) const {
+    ICHECK(!args_and_result_.empty());
+    ICHECK_LT(i + 1, args_and_result_.size());
+    return args_and_result_[i];
+  }
+
+  DeviceDomainPtr function_result() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.back();
+  }
+
+ private:
+  /*!
+   * \brief If this is a function domain then always kInvalidDeviceType. Otherwise will be
+   * kInvalidDeviceType if the domain is still free, or the specific concrete device type if the domain is
+   * bound.
+   */
+  const DLDeviceType device_type_;
+
+  /*!
+   * \brief If this is a function domain then the sub-domains for each of the function's
+   * arguments, and the domain for its result. Otherwise empty.
+   */
+  const std::vector<DeviceDomainPtr> args_and_result_;
+
+  friend struct DeviceDomainHash;
+  friend struct DeviceDomainEqual;
+  friend class DeviceDomains;
+};
+
+// Ye olde boost hash mixer.
+constexpr size_t mix(size_t h1, size_t h2) {
+  return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+}
+
+// The following hash and equality helpers give each free first-order domain pointer its own
+// distinct identity.
+struct DeviceDomainHash {
+  size_t operator()(const DeviceDomainPtr& domain) const {
+    if (domain->is_free()) {
+      // Give each free first-order domain its own identity.
+      return static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get()));
+    } else {
+      size_t h = domain->args_and_result_.size();
+      h = mix(h, std::hash<int>()(static_cast<int>(domain->device_type_)));
+      for (const auto& sub_domain_ptr : domain->args_and_result_) {
+        h = mix(h, DeviceDomainHash()(sub_domain_ptr));
+      }
+      return h;
+    }
+  }
+};
+
+struct DeviceDomainEqual {
+ public:
+  bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const {
+    if (lhs->args_and_result_.size() != rhs->args_and_result_.size()) {
+      // Mismatched arities are never equal.
+      // (Though we'll never ask to do such a comparison explicitly, the hash map
+      // may do so implicitly due to hash collisions.)
+      return false;
+    }
+    if (lhs->is_free() && rhs->is_free()) {
+      // Compare first-order free domains by their address.
+      return lhs.get() == rhs.get();
+    }
+    if (lhs->args_and_result_.empty()) {
+      // Compare first-order domains by their device type -- free vs bound will compare as false.
+      return lhs->device_type_ == rhs->device_type_;
+    } else {
+      // Compare higher-order domains pointwise.
+      for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
+        if (!(*this)(lhs->args_and_result_[i], rhs->args_and_result_[i])) {
+          return false;
+        }
+      }
+      return true;
+    }
+  }
+};
+
+/*!
+ * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation
+ * built up by calls to \p Unify.
+ */
+class DeviceDomains {
+ public:
+  DeviceDomains() = default;
+
+  /*!
+   * \brief Returns a domain appropriate for \p type whose result domain is bound
+   * to \p device_type. If \p device_type is \p kInvalidDeviceType then the entire domain
+   * will be free.
+   */
+  static DeviceDomainPtr MakeDomain(const Type& type, DLDeviceType device_type) {
+    if (const auto* func_type_node = type.as<FuncTypeNode>()) {
+      std::vector<DeviceDomainPtr> args_and_result;
+      args_and_result.reserve(func_type_node->arg_types.size() + 1);
+      for (const auto& arg_type : func_type_node->arg_types) {
+        args_and_result.emplace_back(MakeDomain(arg_type, kInvalidDeviceType));
+      }
+      args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, device_type));
+      return std::make_shared<DeviceDomain>(std::move(args_and_result));
+    } else {
+      return std::make_shared<DeviceDomain>(device_type);
+    }
+  }
+
+  /*!
+   * \brief Returns a higher-order domain with \p args_and_results.
+   */
+  static DeviceDomainPtr MakeDomain(std::vector<DeviceDomainPtr> arg_and_results) {
+    return std::make_shared<DeviceDomain>(std::move(arg_and_results));
+  }
+
+  /*! \brief Returns a domain appropriate for \p type whose result is bound to \p device_type. */
+  static DeviceDomainPtr ForDeviceType(const Type& type, DLDeviceType device_type) {
+    ICHECK_NE(device_type, kInvalidDeviceType);
+    return MakeDomain(type, device_type);
+  }
+
+  /*! \brief Returns a free domain appropriate for \p type. */
+  static DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, kInvalidDeviceType); }
+
+  /*! \brief Returns the domain representing the equivalence class containing \p domain. */
+  DeviceDomainPtr Lookup(DeviceDomainPtr domain) {
+    DeviceDomainPtr root = domain;
+    while (true) {
+      auto itr = domain_to_equiv_.find(root);
+      if (itr == domain_to_equiv_.end()) {
+        break;
+      }
+      ICHECK_NE(itr->second, root);
+      root = itr->second;
+      ICHECK_NOTNULL(root);
+    }
+    // Path compression.
+    while (domain != root) {
+      auto itr = domain_to_equiv_.find(domain);
+      ICHECK(itr != domain_to_equiv_.end());
+      domain = itr->second;
+      ICHECK_NOTNULL(domain);
+      itr->second = root;
+    }
+    return root;
+  }
+
+  /*!
+   * \brief Returns the domain accounting for all bound devices in \p lhs and \p rhs.
+   *
+   * Throws \p Error on failure.
+   */
+  DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+    // TODO(mbs): Proper diagnostics.
+    ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size())
+        << "Device domains:" << std::endl
+        << ToString(lhs) << std::endl
+        << "and" << std::endl
+        << ToString(rhs) << std::endl
+        << "do not have the same kind and can't be unified.";
+    if (rhs->is_free()) {
+      return lhs;
+    } else if (lhs->is_free()) {
+      return rhs;
+    } else if (lhs->args_and_result_.empty()) {
+      // Must have consistent device types for first order domains.
+      if (lhs->device_type_ != rhs->device_type_) {
+        // TODO(mbs): Proper diagnostics.
+        std::ostringstream os;
+        os << "Inconsistent device types " << lhs->device_type_ << " and " << rhs->device_type_;
+        throw Error(os.str());
+      }
+      return lhs;
+    } else {
+      // Recurse for higher-order.
+      std::vector<DeviceDomainPtr> args_and_result;
+      args_and_result.reserve(lhs->args_and_result_.size());
+      for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
+        args_and_result.emplace_back(Unify(lhs->args_and_result_[i], rhs->args_and_result_[i]));
+      }
+      return MakeDomain(std::move(args_and_result));
+    }
+  }
+
+  /*!
+   * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Fails if \p lhs and \p
+   * rhs disagree on bound device type.
+   *
+   * Throws \p Error on failure.
+   */
+  // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but
+  // given we have refs to functions I'm prepared to be surprised.
+  DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) {
+    lhs = Lookup(lhs);
+    rhs = Lookup(rhs);
+    auto joined_domain = Join(lhs, rhs);
+    if (!DeviceDomainEqual()(lhs, joined_domain)) {
+      domain_to_equiv_.emplace(lhs, joined_domain);
+    }
+    if (!DeviceDomainEqual()(rhs, joined_domain)) {
+      domain_to_equiv_.emplace(rhs, joined_domain);
+    }
+    return joined_domain;
+  }
+
+  /*!
+   * \brief Unifies \p lhs and \p rhs. If \p lhs is first-order and \p rhs is higher-order,
+   * require all arguments and result of \p rhs to unify with \p lhs. Otherwise same as
+   * \p Unify.
+   *
+   * Throws \p Error on failure.
+   */
+  void UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+    if (!lhs->is_higher_order() && rhs->is_higher_order()) {
+      Collapse(lhs, rhs);
+    } else {
+      Unify(lhs, rhs);
+    }
+  }
+
+  /*! \brief Returns true if a domain is known for \p expr. */
+  bool contains(const Expr& expr) const { return expr_to_domain_.count(expr.get()); }
+
+  /*! \brief Returns the domain representing \p expr. */
+  DeviceDomainPtr DomainFor(const Expr& expr) {
+    ICHECK(expr.defined());
+    auto itr = expr_to_domain_.find(expr.get());
+    if (itr != expr_to_domain_.end()) {
+      return Lookup(itr->second);
+    }
+    auto domain = Free(expr->checked_type());
+    expr_to_domain_.emplace(expr.get(), domain);
+    return domain;
+  }
+
+  /*!
+   * \brief Returns the domain representing the callee (ie 'op') in \p call expression. If the
+   * callee is a primitive or special operation we handle it specially. Otherwise defers to \p
+   * DomainFor(call->op).
+   *
+   * This special handling is needed:

Review comment:
       this smells a bit like an overridable Pass...wdyt?
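
       if it did become overridable, one shape for the hook (purely hypothetical -- the attribute name is
       invented, only the register_op_attr mechanism is real) would be per-op attributes instead of hard-coded
       cases:

           import tvm

           # hypothetical: let ops like "shape_of" / "alloc_tensor" declare their own device
           # constraints rather than special-casing them inside the planner
           def _shape_of_constraint(call, domains):
               # e.g. "result lives on the CPU, the described tensor may live anywhere"
               return ...

           tvm.ir.register_op_attr("shape_of", "FTVMDevicePlanningConstraint", _shape_of_constraint)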

##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -0,0 +1,1986 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/analysis/device_planner.cc
+ * \brief Determines a unique device to hold the result of every Relay sub-expression.
+ *
+ * We say a Relay expression E is 'on device D' if the result of executing E is stored on D.
+ * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the
+ * specific target associated with D (this is recovered independently via a TargetMap), and we
+ * do not track the storage scope within D (this is yet to be implemented).
+ *
+ * Note that 'stored on device D' is almost but not quite the same as 'executes on device D',
+ * see below.
+ *
+ * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes:
+ *  - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and
+ *    'dst_dev_type' device type, which constrain the argument and context of the call
+ *     respectively. It is ok if source and destination devices are the same; such no-op copies
+ *     will be removed after accounting for the device preference.
+ *  - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which
+ *    constrains the argument of the call, but (usually, see below) leaves the context
+ *    unconstrained. These are called 'annotations' in the rest of the code, have no operational
+ *    significance by themselves, but may trigger the insertion of a new "device_copy".
+ *  - In two situations the result of an "on_device" CallNode may also be constrained to the
+ *    given device:
+ *     - The "on_device" call occurs at the top-level of a function body, or occurs as an
+ *       immediately let-bound expression. In this situation the extra degree of freedom in
+ *       the function result and let-binding leads to surprising device copies, so we simply
+ *       force the function result or let-bound variable to the given device.
+ *     - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted
+ *       it ourselves during an earlier invocation of this pass. This helps make this pass
+ *       idempotent.
+ *
+ * We proceed in four phases:
+ *
+ * Phase 0
+ * -------
+ * We rewrite the programs to handle some special cases:
+ *  - "on_device" calls at the top-level of a function body or immediately let-bound are rewritten
+ *    to have \code is_fixed=true \endcode.
+ *  - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written
+ *    \code on_device(expr.0, device_type=d) \endcode. I.e. we prefer to copy the projection from
+ *    the tuple rather than project from a copy of the tuple. We'll do this by rewriting.
+ *
+ * Phase 1
+ * -------
+ * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see
+ * below) to all other Relay sub-expressions. (For idempotence we also respect any existing
+ * "on_device" function attributes we introduce below.)
+ *
+ * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the
+ * same device. However each call site can use a different device. In other words primitives are
+ * 'device polymorphic' since we compile and execute them for each required device.
+ *
+ * For most Relay expressions the device for the overall expression is the same as the device
+ * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple
+ * itself, the condition and arms of an \p if must all be on the same device as the overall if,
+ * and so on.
+ *
+ * Some special ops (or 'dialects') are handled:
+ *  - Relay supports computing the shape of tensors and operators at runtime using "shape_of",
+ *    "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors
+ *    they describe may reside on any device.
+ *  - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor" operators. Again
+ *    shapes reside on the CPU, but the allocated tensors may reside on any device.
+ *
+ * Two Relay expressions have special handling:
+ *  - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the
+ *    overall let. However the result of \p e1 may be on a different device.
+ *  - For a function \code fn(x, y) { body } \endcode the result of the function must be on the
+ *    same device as \p body. However parameters \p x and \p y may be on different devices, even
+ *    different from each other. Every call to the function must use the same choice of parameter
+ *    and result devices -- there is no 'device polymorphism' for Relay functions.
+ *
+ * Phase 2
+ * -------
+ * After flowing constraints we apply some defaulting heuristics (using a global default device)
+ * to fix the device for any as-yet unconstrained sub-expressions.
+ *  - Unconstrained function result devices default to the global default device.
+ *  - Unconstrained function parameter devices default to the device for the function result.
+ *  - Unconstrained let-bound expression devices default to the device for the overall let.
+ * TODO(mbs): I may have over-innovated here and we simply want to bind all free domains to
+ * the global default device. Worth a design doc with motivating examples I think.
+ *
+ * Phase 3
+ * -------
+ * Finally, the result of this analysis is reified back into the program as:
+ *  - Additional "on_device" attributes (an Attrs resolving to a \p FunctionOnDeviceAttrs) for
+ *    every function (both top-level and local). These describe the devices for the function's
+ *    parameters and the result.
+ *  - Additional "device_copy" CallNodes where a copy is required in order to respect the
+ *    intent of the original "on_device" CallNodes.
+ *  - Additional "on_device" CallNodes where the device type of an expression does not match
+ *    that of the lexically enclosing "on_device" CallNode or function attribute. In practice
+ *    this means "on_device" CallNodes may appear in two places:
+ *     - On a let-bound expression if its device differs from the overall let expression.
+ *     - On a call argument if its device differs from the call result. In particular, the
+ *       argument to a "device_copy" call will always be wrapped in an "on_device". (That may
+ *       seem pedantic but simplifies downstream handling.)
+ *    However since we make it easy to track devices for variables we never wrap an "on_device"
+ *    around a var or global var. These uses of "on_device" imply both the argument and result are
+ *    on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs field to true,
+ *    which helps make this pass idempotent.
+ *
+ * A helper \p LexicalOnDeviceMixin class can be used by downstream transforms to recover the device
+ * for any expression for their own use, e.g. during memory planning. All downstream passes must
+ * preserve the lexical scoping of the "on_device" CallNodes. In particular conversion to ANF
+ * must respect the lexical scoping convention:
+ * \code
+ * f(on_device(g(h(a, b), c), device_type=CPU))
+ * ==>
+ * let %x0 = on_device(h(a, b), device_type=CPU)
+ * let %x1 = on_device(g(%x0), device_type=CPU)
+ * f(on_device(%x1, device_type=CPU))
+ * \endcode
+ *
+ * This pass should be run before FuseOps so that FuseOps can use device-specific fusion rules.
+ *
+ * 'Stored on' vs 'Executes on'
+ * ----------------------------
+ * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the
+ * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for
+ * primitives.
+ *
+ * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are
+ * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific
+ * devices, thus the notion of 'executes on' is moot. AOT backends on the other hand need to
+ * know exactly which device (possibly one of a number of available 'CPU'-like devices) is
+ * responsible for execution. Currently that's handled independently by the \p AnnotateTargets
+ * pass, but we'd like to fold that into device planning here to ensure everything is consistent.
+ *
+ * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay
+ * expression (eg an if expression) on one device even though the tensor data resides on
+ * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on'
+ * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just
+ * compile the function body for the function's result device.
+ *
+ * This works after conversion to ANF provided the compilation for a let expression is prepared
+ * to make a cross-device call. However we leave it to a downstream transformation to heuristically
+ * minimize cross-device calls by moving device copies out of functions. E.g.:
+ * \code
+ *   def @f() {  // execute on CPU
+ *     let x = on_device(...GPU computation..., device_type=GPU);
+ *     device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU)
+ *   }
+ *   def @main() {
+ *     ... call @f() on CPU ...
+ *   }
+ * \endcode
+ * could be rewritten to:
+ * \code
+ *   def @f() {  // execute on GPU
+ *     let x = ...GPU computation...;
+ *     ...GPU computation...
+ *   }
+ *   def @main() {
+ *     let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU)
+ *     ... use x on CPU ...
+ *   }
+ * \endcode
+ *
+ * Higher-order shenanigans
+ * ------------------------
+ * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions
+ * as arguments (even anonymous functions), return functions, evaluate conditional expressions
+ * over functions, and so on. We handle this during constraint solving using the domain:
+ * \code
+ *   D  ::= <specific device type>   -- first-order
+ *        | fn(D,...,D):D            -- higher-order
+ * \endcode
+ * In this way we can determine the device for all function parameters and results. E.g. for
+ * \code
+ *   let f = fn(x, y) { ... }
+ *   let g = fn(f, z) { f(z, z) }
+ *   g(f, on_device(..., device_type=CPU))
+ * \endcode
+ * the parameters \p x and \p y will be on the CPU.
+ *
+ * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 must evaluate to a
+ * function. Our analysis must guarantee that the function's parameters and result devices are
+ * consistent for \p e2, \p e3, and the context of the call. But:
+ *  - Which device holds the closure result of evaluating \p e1 ?
+ *  - If \p e2 is of function type, what does that mean when we say every function parameter
+ *    is on a device?
+ *  - If \p e1 returns a function, what does that mean when we say every function result is
+ *    on a device?
+ *
+ * Since higher-order aspects are later compiled away (by 'defunctionalization'
+ * aka 'firstification') we'd prefer not to have to answer any of those questions. In particular,
+ * we really don't want our domain \p D to allow for yet another device for the function closure.
+ * So we'll just force the 'device for a function' to be the same as the device for the function's
+ * result using the notion of the 'result domain' for a domain:
+ * \code
+ *   result_domain(<specific device type>) = <specific device type>
+ *   result_domain(fn(D1,...,Dn):Dr)       = result_domain(Dr)
+ * \endcode
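+ *
+ * For example (a sketch of how the rule applies), a closure whose nested result is on the GPU
+ * is itself treated as being 'on' the GPU:
+ * \code
+ *   result_domain(fn(<CPU>):fn(<CPU>):<GPU>) = <GPU>
+ * \endcode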
+ *
+ * Similarly the domain does not have entries for tuples, references, or ADTs. Whenever the
+ * analysis encounters a function inside one of those it simply forces all argument and result
+ * devices for the function to match the device for the first-order expression. For example,
+ * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the inner function
+ * parameters and result must similarly be on the GPU.
+ *
+ * -------
+ * | AOR |  This pass supports all of Relay.
+ * -------
+ *    ^
+ *    |
+ *    `-- Mark's stamp of completeness :-)
+ *
+ * TODO(mbs):
+ *  * Though on_device is the identity for all types, we can't wrap it around functions/constructors
+ *    taking type args (or at least not without changing type_infer.cc to see through them).
+ *    This is not currently handled generally.
+ *  * Proper diagnostics for unification failure using spans.
+ *  * Make sure the pass is idempotent even after FuseOps etc.
+ *  * Support application of constructors properly. Are they device polymorphic?
+ *  * Replace DLDeviceType with TargetDevice, and unify 'target annotation' with 'device planning'.
+ *  * Support running the pass post FuseOps (so need to understand primitive functions, both
+ *    outlined and inlined) and post the VM transforms (probably need to support more intrinsic
+ *    forms?).
+ *  * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default
+ *    device for primitives vs the default device for the rest of Relay.
+ *  * We'll probably need some support for partial 'device polymorphism' for functions once we
+ *    incorporate targets and memory scopes into the domain. For example it's ok for the function
+ *    body to be executed on different device ids provided they have the same target and memory
+ *    scope.
+ *  * Might be simpler to just let every type have a device annotation rather than work in
+ *    a separate domain?
+ *  * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies.
+ *  * The original device_annotation.cc RewriteAnnotatedOps removed all "on_device" calls
+ *    in tuples at the top level of function bodies or main expression, irrespective of the
+ *    "on_device" body. What's up with that?
+ */
+
+#include "./device_planner.h"
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/pattern_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/object.h>
+
+#include <unordered_map>
+
+#include "../op/annotation/annotation.h"
+#include "../op/memory/device_copy.h"
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+namespace {
+
+/*!
+ * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather
+ * than the original "device_copy" operator.
+ *
+ * See te_compiler.cc for where this rewriting occurs.
+ */
+DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) {
+  auto tir_call_attrs = call_node->attrs.as<TIRCallAttrs>();
+  if (tir_call_attrs == nullptr) {
+    return {};
+  }
+  if (tir_call_attrs->metadata.count("source_device") != 1 ||
+      tir_call_attrs->metadata.count("dst_device") != 1) {
+    return {};
+  }
+  ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1";
+  return {
+      call_node->args[0],
+      static_cast<DLDeviceType>(
+          Downcast<Integer>(tir_call_attrs->metadata["source_device"])->value),
+      static_cast<DLDeviceType>(Downcast<Integer>(tir_call_attrs->metadata["dst_device"])->value)};
+}
+
+class DeviceDomain;
+using DeviceDomainPtr = std::shared_ptr<DeviceDomain>;
+
+/******
+****** Domains
+******/
+
+/*!
+ * \brief Represents the domain over which we collect equality constraints.
+ *
+ * \code
+ *   D ::= ?x?                  -- first order, free
+ *       | <device_type>        -- first order, bound
+ *       | fn(D1, ..., Dn):Dr   -- higher order
+ * \endcode
+ *
+ * We require a function value to be on the same device as its result. To support that we need
+ * a notion of the 'result domain' of a domain:
+ * \code
+ *   result_domain(?x?)                = ?x?
+ *   result_domain(<device_type>)      = <device_type>
+ *   result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr)
+ * \endcode
+ */
+class DeviceDomain {
+ public:
+  /*!
+   * \brief Constructs a first-order domain of \p device_type, which may be
+   * \p kInvalidDeviceType to indicate the domain is free.
+   */
+  explicit DeviceDomain(DLDeviceType device_type) : device_type_(device_type) {}
+
+  /*!
+   * \brief Constructs a higher-order domain, where \p args_and_result contain the
+   * function argument and result domains in order.
+   */
+  explicit DeviceDomain(std::vector<DeviceDomainPtr> args_and_result)
+      : device_type_(kInvalidDeviceType), args_and_result_(std::move(args_and_result)) {}
+
+  /*! \brief Returns true if domain is first-order and free. */
+  bool is_free() const { return device_type_ == kInvalidDeviceType && args_and_result_.empty(); }
+
+  /*! \brief Returns true if domain is higher-order. */
+  bool is_higher_order() const { return !args_and_result_.empty(); }
+
+  DLDeviceType first_order_device_type() const {
+    ICHECK(args_and_result_.empty());
+    return device_type_;
+  }
+
+  size_t function_arity() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.size() - 1UL;
+  }
+
+  DeviceDomainPtr function_param(size_t i) const {
+    ICHECK(!args_and_result_.empty());
+    ICHECK_LT(i + 1, args_and_result_.size());
+    return args_and_result_[i];
+  }
+
+  DeviceDomainPtr function_result() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.back();
+  }
+
+ private:
+  /*!
+   * \brief If this is a function domain then always kInvalidDevice. Otherwise will be
+   * kInvalidDevice if the domain is still free, or the specific concrete device if the domain is
+   * bound.
+   */
+  const DLDeviceType device_type_;
+
+  /*!
+   * \brief If this is a function domain then the sub-domains for each of the function's
+   * arguments, and the domain for its result. Otherwise empty.
+   */
+  const std::vector<DeviceDomainPtr> args_and_result_;
+
+  friend struct DeviceDomainHash;
+  friend struct DeviceDomainEqual;
+  friend class DeviceDomains;
+};
+
+// Ye olde boost hash mixer.
+constexpr size_t mix(size_t h1, size_t h2) {
+  return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+}
+
+// The following hash and equality helpers give each free first-order domain pointer its own
+// distinct identity.
+struct DeviceDomainHash {
+  size_t operator()(const DeviceDomainPtr& domain) const {
+    if (domain->is_free()) {
+      // Give each free first-order domain its own identity.
+      return static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get()));
+    } else {
+      size_t h = domain->args_and_result_.size();
+      h = mix(h, std::hash<int>()(static_cast<int>(domain->device_type_)));
+      for (const auto& sub_domain_ptr : domain->args_and_result_) {
+        h = mix(h, DeviceDomainHash()(sub_domain_ptr));
+      }
+      return h;
+    }
+  }
+};
+
+struct DeviceDomainEqual {
+ public:
+  bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const {
+    if (lhs->args_and_result_.size() != rhs->args_and_result_.size()) {
+      // Mismatched arities are never equal.
+      // (Though we'll never ask to do such a comparison explicitly, the hash map
+      // may do so implicitly due to hash collisions.)
+      return false;
+    }
+    if (lhs->is_free() && rhs->is_free()) {
+      // Compare first-order free domains by their address.
+      return lhs.get() == rhs.get();
+    }
+    if (lhs->args_and_result_.empty()) {
+      // Compare first-order domains by their device type -- free vs bound will compare as false.
+      return lhs->device_type_ == rhs->device_type_;
+    } else {
+      // Compare higher-order domains pointwise.
+      for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
+        if (!(*this)(lhs->args_and_result_[i], rhs->args_and_result_[i])) {
+          return false;
+        }
+      }
+      return true;
+    }
+  }
+};
+
+/*!
+ * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation
+ * built up by calls to \p Unify.
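+ *
+ * A rough usage sketch (illustrative only; 'expr' stands for an arbitrary type-checked
+ * sub-expression):
+ * \code
+ *   DeviceDomains domains;
+ *   DeviceDomainPtr actual = domains.DomainFor(expr);  // free until constrained
+ *   DeviceDomainPtr expected = DeviceDomains::ForDeviceType(expr->checked_type(), kDLCPU);
+ *   domains.Unify(actual, expected);  // may throw if devices conflict
+ *   DLDeviceType device_type = domains.ResultDeviceType(domains.DomainFor(expr));  // kDLCPU
+ * \endcode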
+ */
+class DeviceDomains {
+ public:
+  DeviceDomains() = default;
+
+  /*!
+   * \brief Returns a domain appropriate for \p type whose result domain is bound
+   * to \p device_type. If \p device_type is \p kInvalidDeviceType then the entire domain
+   * will be free.
+   */
+  static DeviceDomainPtr MakeDomain(const Type& type, DLDeviceType device_type) {
+    if (const auto* func_type_node = type.as<FuncTypeNode>()) {
+      std::vector<DeviceDomainPtr> args_and_result;
+      args_and_result.reserve(func_type_node->arg_types.size() + 1);
+      for (const auto& arg_type : func_type_node->arg_types) {
+        args_and_result.emplace_back(MakeDomain(arg_type, kInvalidDeviceType));
+      }
+      args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, device_type));
+      return std::make_shared<DeviceDomain>(std::move(args_and_result));
+    } else {
+      return std::make_shared<DeviceDomain>(device_type);
+    }
+  }
+
+  /*!
+   * \brief Returns a higher-order domain with \p args_and_result.
+   */
+  static DeviceDomainPtr MakeDomain(std::vector<DeviceDomainPtr> args_and_result) {
+    return std::make_shared<DeviceDomain>(std::move(args_and_result));
+  }
+
+  /*! \brief Returns a domain appropriate for \p type whose result domain is bound to the given \p device_type. */
+  static DeviceDomainPtr ForDeviceType(const Type& type, DLDeviceType device_type) {
+    ICHECK_NE(device_type, kInvalidDeviceType);
+    return MakeDomain(type, device_type);
+  }
+
+  /*! \brief Returns a free domain appropriate for \p type. */
+  static DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, kInvalidDeviceType); }
+
+  /*! \brief Returns the domain representing the equivalence class containing \p domain. */
+  DeviceDomainPtr Lookup(DeviceDomainPtr domain) {
+    DeviceDomainPtr root = domain;
+    while (true) {
+      auto itr = domain_to_equiv_.find(root);
+      if (itr == domain_to_equiv_.end()) {
+        break;
+      }
+      ICHECK_NE(itr->second, root);
+      root = itr->second;
+      ICHECK_NOTNULL(root);
+    }
+    // Path compression.
+    while (domain != root) {
+      auto itr = domain_to_equiv_.find(domain);
+      ICHECK(itr != domain_to_equiv_.end());
+      domain = itr->second;
+      ICHECK_NOTNULL(domain);
+      itr->second = root;
+    }
+    return root;
+  }
+
+  /*!
+   * \brief Returns the domain accounting for all bound devices in \p lhs and \p rhs.
+   *
+   * Throws \p Error on failure.
+   */
+  DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+    // TODO(mbs): Proper diagnostics.
+    ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size())
+        << "Device domains:" << std::endl
+        << ToString(lhs) << std::endl
+        << "and" << std::endl
+        << ToString(rhs) << std::endl
+        << "do not have the same kind and can't be unified.";
+    if (rhs->is_free()) {
+      return lhs;
+    } else if (lhs->is_free()) {
+      return rhs;
+    } else if (lhs->args_and_result_.empty()) {
+      // Must have consistent device types for first order domains.
+      if (lhs->device_type_ != rhs->device_type_) {
+        // TODO(mbs): Proper diagnostics.
+        std::ostringstream os;
+        os << "Inconsistent device types " << lhs->device_type_ << " and " << rhs->device_type_;
+        throw Error(os.str());
+      }
+      return lhs;
+    } else {
+      // Recurse for higher-order.
+      std::vector<DeviceDomainPtr> args_and_result;
+      args_and_result.reserve(lhs->args_and_result_.size());
+      for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
+        args_and_result.emplace_back(Unify(lhs->args_and_result_[i], rhs->args_and_result_[i]));
+      }
+      return MakeDomain(std::move(args_and_result));
+    }
+  }
+
+  /*!
+   * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Fails if \p lhs and \p
+   * rhs disagree on bound device type.
+   *
+   * Throws \p Error on failure.
+   */
+  // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but
+  // given we have refs to functions I'm prepared to be surprised.
+  DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) {
+    lhs = Lookup(lhs);
+    rhs = Lookup(rhs);
+    auto joined_domain = Join(lhs, rhs);
+    if (!DeviceDomainEqual()(lhs, joined_domain)) {
+      domain_to_equiv_.emplace(lhs, joined_domain);
+    }
+    if (!DeviceDomainEqual()(rhs, joined_domain)) {
+      domain_to_equiv_.emplace(rhs, joined_domain);
+    }
+    return joined_domain;
+  }
+
+  /*!
+   * \brief Unifies \p lhs and \p rhs. If \p lhs is first-order and \p rhs is higher-order,
+   * require all arguments and result of \p rhs to unify with \p lhs. Otherwise same as
+   * \p Unify.
+   *
+   * Throws \p Error on failure.
+   */
+  void UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+    if (!lhs->is_higher_order() && rhs->is_higher_order()) {
+      Collapse(lhs, rhs);
+    } else {
+      Unify(lhs, rhs);
+    }
+  }
+
+  /*! \brief Returns true if a domain is known for \p expr. */
+  bool contains(const Expr& expr) const { return expr_to_domain_.count(expr.get()); }
+
+  /*! \brief Returns the domain representing \p expr. */
+  DeviceDomainPtr DomainFor(const Expr& expr) {
+    ICHECK(expr.defined());
+    auto itr = expr_to_domain_.find(expr.get());
+    if (itr != expr_to_domain_.end()) {
+      return Lookup(itr->second);
+    }
+    auto domain = Free(expr->checked_type());
+    expr_to_domain_.emplace(expr.get(), domain);
+    return domain;
+  }
+
+  /*!
+   * \brief Returns the domain representing the callee (i.e. 'op') in the \p call expression. If
+   * the callee is a primitive or special operation we handle it specially; otherwise defers to
+   * \p DomainFor(call->op).
+   *
+   * This special handling is needed:
+   * - To handle the "on_device" and "device_copy" ops, which constrain devices to those given in
+   *   their attributes.
+   * - To handle some special ops which constrain devices to the CPU.
+   * - To allow the same primitive to be called on different devices at different call sites.
+   * Since each call to the op can have a different domain we index by the call expression
+   * rather than the op itself.
+   */
+  DeviceDomainPtr DomainForCallee(const Call& call) {
+    auto itr = call_to_callee_domain_.find(call.get());
+    if (itr != call_to_callee_domain_.end()) {
+      return Lookup(itr->second);
+    }
+    std::vector<DeviceDomainPtr> args_and_result;
+
+    auto on_device_props = GetOnDeviceProps(call.get());
+    auto device_copy_props = GetDeviceCopyProps(call.get());
+    if (!device_copy_props.body.defined()) {
+      device_copy_props = GetPrimitiveDeviceCopyProps(call.get());
+    }
+
+    if (on_device_props.body.defined()) {
+      // on_device(expr, device_type=<t>, is_fixed=false)
+      // on_device : fn(<t>):?x?
+      //
+      // on_device(expr, device_type=<t>, is_fixed=true)
+      // on_device: fn(<t>):<t>
+      args_and_result.emplace_back(
+          ForDeviceType(on_device_props.body->checked_type(), on_device_props.device_type));
+      if (on_device_props.is_fixed) {
+        args_and_result.emplace_back(args_and_result.front());
+      } else {
+        args_and_result.emplace_back(Free(on_device_props.body->checked_type()));
+      }
+    } else if (device_copy_props.body.defined()) {
+      // device_copy(expr, src_dev_type=<s>, dst_dev_type=<d>)
+      // device_copy: fn(<s>):<d>
+      args_and_result.emplace_back(
+          ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.src_dev_type));
+      args_and_result.emplace_back(
+          ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.dst_dev_type));
+    } else if (call->op == alloc_storage_op) {
+      ICHECK_EQ(call->args.size(), 2U);
+      // alloc_storage(size, alignment, device_type=<t>)
+      // alloc_storage: fn(<cpu>, <cpu>):<t>
+      const auto* attrs = call->attrs.as<AllocStorageAttrs>();
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(
+          ForDeviceType(call->checked_type(), static_cast<DLDeviceType>(attrs->device_type)));
+    } else if (call->op == alloc_tensor_op) {
+      ICHECK_EQ(call->args.size(), 3U);
+      // alloc_tensor(storage, offset, shape)
+      // alloc_tensor: fn(?x?, <cpu>, <cpu>):?x?
+      auto free_domain = Free(call->checked_type());
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(free_domain);
+    } else if (call->op == shape_func_op) {
+      ICHECK_EQ(call->args.size(), 3U);
+      // shape_func(func, inputs, outputs, is_inputs=[...])
+      // shape_func: fn(..., <cpu>, <cpu>):<cpu>
+      // where ... is a free domain appropriate for func's type
+      args_and_result.emplace_back(Free(call->args[0]->checked_type()));
+      // TODO(mbs): I think this should be on the cpu only when is_input = [false], but
+      // what do we do when we have multiple arguments with different is_input values?
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+    } else if (call->op == shape_of_op) {
+      ICHECK_EQ(call->args.size(), 1U);
+      // shape_of(tensor)
+      // shape_of: fn(?x?):<cpu>
+      args_and_result.emplace_back(Free(call->args[0]->checked_type()));
+      args_and_result.emplace_back(cpu_domain_);
+    } else if (call->op == invoke_tvm_op) {
+      ICHECK_EQ(call->args.size(), 3U);
+      // invoke_tvm_op(op, inputs, outputs)
+      // invoke_tvm_op: fn(..., ?x?, ?x?):?x?
+      // where ... is a free domain appropriate for op's type
+      auto free_domain = Free(call->checked_type());
+      args_and_result.emplace_back(Free(call->args[0]->checked_type()));
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(free_domain);
+    } else if (call->op == reshape_tensor_op) {
+      ICHECK_EQ(call->args.size(), 2U);
+      // reshape_tensor(data, shape)
+      // reshape_tensor: fn(?x?, <cpu>):?x?
+      auto free_domain = Free(call->checked_type());
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(free_domain);
+    } else if (call->op->IsInstance<OpNode>()) {
+      // <primitive>(arg1, ..., argn)
+      // <primitive>: fn(?x?, ..., ?x?):?x?
+      // (all args and result must be first-order).
+      auto free_domain = Free(arb_);
+      for (size_t i = 0; i < call->args.size(); ++i) {
+        args_and_result.emplace_back(free_domain);
+      }
+      args_and_result.emplace_back(free_domain);
+    } else {
+      // Defer to normal case where op can be an arbitrary expression.
+      return DomainFor(call->op);
+    }
+    auto domain = MakeDomain(std::move(args_and_result));
+    call_to_callee_domain_.emplace(call.get(), domain);
+    return domain;
+  }
+
+  /*! \brief Unifies the domains for expressions \p lhs and \p rhs. */
+  void UnifyExprExact(const Expr& lhs, const Expr& rhs) {
+    auto lhs_domain = DomainFor(lhs);
+    auto rhs_domain = DomainFor(rhs);
+    try {
+      Unify(lhs_domain, rhs_domain);
+    } catch (const Error& e) {
+      // TODO(mbs): Proper diagnostics.
+      LOG(FATAL) << "Incompatible devices for expressions:" << std::endl
+                 << PrettyPrint(lhs) << std::endl
+                 << "with device:" << std::endl
+                 << ToString(lhs_domain) << "and:" << std::endl
+                 << PrettyPrint(rhs) << std::endl
+                 << "with device:" << std::endl
+                 << ToString(rhs_domain) << std::endl
+                 << e.what();
+    }
+  }
+
+  /*!
+   * \brief Unifies the domain for \p expr with \p expected_domain.
+   */
+  void UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain) {
+    auto actual_domain = DomainFor(expr);
+    try {
+      Unify(actual_domain, expected_domain);
+    } catch (const Error& e) {
+      // TODO(mbs): Proper diagnostics.
+      LOG(FATAL) << "Incompatible devices for expression:" << std::endl
+                 << PrettyPrint(expr) << std::endl
+                 << "with actual device:" << std::endl
+                 << ToString(actual_domain) << std::endl
+                 << "and expected device:" << std::endl
+                 << ToString(expected_domain) << std::endl
+                 << e.what();
+    }
+  }
+
+  /*!
+   * \brief Unifies the domain for \p expr with \p expected_domain.
+   * If \p expected_domain is higher-order but \p expr is first-order, require all arguments
+   * and the result of \p expected_domain to have the same domain as for \p expr.
+   */
+  void UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain) {
+    auto actual_domain = DomainFor(expr);
+    try {
+      UnifyCollapsed(actual_domain, expected_domain);
+    } catch (const Error& e) {
+      // TODO(mbs): Proper diagnostics.
+      LOG(FATAL) << "Incompatible devices for expression:" << std::endl
+                 << PrettyPrint(expr) << std::endl
+                 << "with actual device:" << std::endl
+                 << ToString(actual_domain) << std::endl
+                 << "and expected device:" << std::endl
+                 << ToString(expected_domain) << std::endl
+                 << e.what();
+    }
+  }
+
+  /*! \brief Returns true if \p domain contains any free sub-domains. */
+  bool AnyFree(DeviceDomainPtr domain) {
+    domain = Lookup(domain);
+    if (domain->is_free()) {
+      return true;
+    }
+    for (const auto& sub_domain : domain->args_and_result_) {
+      if (AnyFree(sub_domain)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  /*!
+   * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain.
+   * This can be used to handle functions within tuples, references and ADTs since we don't
+   * attempt to track anything beyond 'the device' for expressions of those first-order types.
+   *
+   * Throws \p Error on failure.
+   */
+  void Collapse(const DeviceDomainPtr& first_order_domain,
+                const DeviceDomainPtr& higher_order_domain) {
+    for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) {
+      Unify(higher_order_domain->function_param(i), first_order_domain);
+    }
+    Unify(higher_order_domain->function_result(), first_order_domain);
+  }
+
+  /*! \brief Force all free domains in \p domain to default to \p default_device_type. */
+  void SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type) {
+    ICHECK_NE(default_device_type, kInvalidDeviceType);
+    domain = Lookup(domain);
+    if (domain->is_free()) {
+      // Will never throw since lhs is free.
+      Unify(domain, std::make_shared<DeviceDomain>(default_device_type));
+    } else if (!domain->args_and_result_.empty()) {
+      for (const auto& sub_domain : domain->args_and_result_) {
+        SetDefault(sub_domain, default_device_type);
+      }
+    }
+  }
+
+  /*!
+   * \brief If \p domain is higher-order and its result domain is free, force it to
+   * \p default_device_type. Then force any remaining free domains to the result domain
+   * (freshly defaulted or original). If \p domain is first-order this is the same as \p SetDefault.
+   */
+  void SetResultDefaultThenParams(const DeviceDomainPtr& domain, DLDeviceType default_device_type) {
+    if (!domain->is_higher_order()) {
+      SetDefault(domain, default_device_type);
+      return;
+    }
+    DLDeviceType result_device_type = ResultDeviceType(domain);
+    if (result_device_type == kInvalidDeviceType) {
+      // If the function result device is still free use the given default.
+      result_device_type = default_device_type;
+    }
+    // Default any remaining free parameters to the function result device.
+    SetDefault(domain, result_device_type);
+  }
+
+  /*! \brief Returns one-line description of \p domain for debugging. */
+  std::string ToString(DeviceDomainPtr domain) {
+    domain = Lookup(domain);
+    std::ostringstream os;
+    if (domain->is_free()) {
+      // first-order free
+      os << "?" << static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get())) << "?";
+    } else if (domain->args_and_result_.empty()) {
+      // first-order bound
+      os << "<" << domain->device_type_ << ">";
+    } else {
+      // higher-order
+      os << "fn(";
+      for (size_t i = 0; i + 1 < domain->args_and_result_.size(); ++i) {
+        if (i > 0) {
+          os << ",";
+        }
+        os << ToString(domain->args_and_result_[i]);
+      }
+      os << "):" << ToString(domain->args_and_result_.back());
+    }
+    return os.str();
+  }
+
+  /*! \brief Returns description of entire system of constraints for debugging */
+  std::string ToString() {
+    std::ostringstream os;
+    for (const auto& pair : expr_to_domain_) {
+      os << "expression:" << std::endl
+         << PrettyPrint(GetRef<Expr>(pair.first)) << std::endl
+         << "domain:" << std::endl
+         << ToString(pair.second) << std::endl
+         << std::endl;
+    }
+    for (const auto& pair : call_to_callee_domain_) {
+      os << "call:" << std::endl
+         << PrettyPrint(GetRef<Call>(pair.first)) << std::endl
+         << "callee domain:" << std::endl
+         << ToString(pair.second) << std::endl
+         << std::endl;
+    }
+    return os.str();
+  }
+
+  /*!
+   * \brief Returns the result domain for \p domain (see defn in DeviceDomain comment).
+   */
+  DeviceDomainPtr ResultDomain(DeviceDomainPtr domain) {
+    domain = Lookup(domain);
+    while (!domain->args_and_result_.empty()) {
+      domain = Lookup(domain->args_and_result_.back());
+    }
+    return domain;
+  }
+
+  /*!
+   * \brief Returns the result (possibly free) device type for \p domain (see defn in DeviceDomain
+   * comment).
+   */
+  DLDeviceType ResultDeviceType(const DeviceDomainPtr& domain) {
+    return ResultDomain(domain)->first_order_device_type();
+  }
+
+ private:
+  /*! \brief Intrinsics we need to handle specially. */
+  const Op& alloc_storage_op = Op::Get("memory.alloc_storage");
+  const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor");
+  const Op& shape_of_op = Op::Get("vm.shape_of");
+  const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op");
+  const Op& shape_func_op = Op::Get("vm.shape_func");
+  const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor");
+  /*! \brief The CPU device type for special operators such as dynamic shape functions. */
+  const DLDeviceType cpu_device_type_ = kDLCPU;
+  /*! \brief Placeholder for any first-order type. */
+  Type arb_ = TupleType();
+  /*! \brief The domain for first-order expressions on the CPU. */
+  DeviceDomainPtr cpu_domain_ = ForDeviceType(arb_, cpu_device_type_);
+
+  /*! \brief Maps expressions to their domains as determined during analysis. */
+  std::unordered_map<const ExprNode*, DeviceDomainPtr> expr_to_domain_;
+
+  /*!
+   * \brief Maps call expressions to the domains for their callee where the callee is a primitive.
+   */
+  std::unordered_map<const CallNode*, DeviceDomainPtr> call_to_callee_domain_;
+
+  /*! \brief Maps device domains to their equivalent domains as determined during unification. */
+  std::unordered_map<DeviceDomainPtr, DeviceDomainPtr, DeviceDomainHash, DeviceDomainEqual>
+      domain_to_equiv_;
+};
+
+/******
+****** Phase 0
+******/
+
+/*!
+ * \brief Rewrites "on_device" calls to handle some special cases.
+ */
+class RewriteOnDevices : public ExprMutator {
+ public:
+  RewriteOnDevices() = default;
+
+ private:
+  Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
+    Expr tuple = VisitExpr(tuple_get_item_node->tuple);
+    // TODO(mbs): Avoid copy.
+    Expr tuple_get_item =
+        TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span);
+    auto props = GetOnDeviceProps(tuple);
+    if (props.body.defined() && !props.is_fixed) {
+      VLOG(1) << "wrapping tuple get item:" << std::endl
+              << PrettyPrint(GetRef<TupleGetItem>(tuple_get_item_node)) << std::endl
+              << "with \"on_device\" for device " << props.device_type;
+      return OnDevice(tuple_get_item, props.device_type, /*is_fixed=*/false);
+    } else {
+      return tuple_get_item;
+    }
+  }
+
+  Expr VisitExpr_(const LetNode* let_node) final {
+    auto expr = GetRef<Expr>(let_node);
+    std::vector<std::tuple<Var, Expr, Span>> bindings;
+    while (const auto* inner_let_node = expr.as<LetNode>()) {
+      Expr inner_let = GetRef<Let>(inner_let_node);
+      Expr value = VisitExpr(inner_let_node->value);
+      auto props = GetOnDeviceProps(value);
+      if (props.body.defined() && !props.is_fixed) {
+        VLOG(1) << "revising let-bound expression of let:" << std::endl
+                << PrettyPrint(expr) << std::endl
+                << "to be fixed to device " << props.device_type;
+        value = OnDevice(props.body, props.device_type, /*is_fixed=*/true);
+      }
+      bindings.emplace_back(inner_let_node->var, value, inner_let_node->span);
+      expr = inner_let_node->body;
+    }
+    expr = VisitExpr(expr);
+    // TODO(mbs): Avoid copy.
+    for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) {
+      expr = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), expr,
+                 /*span=*/std::get<2>(*itr));
+    }
+    return expr;
+  }
+
+  Expr VisitExpr_(const FunctionNode* function_node) final {
+    Expr body = VisitExpr(function_node->body);
+    auto props = GetOnDeviceProps(body);
+    if (props.body.defined() && !props.is_fixed) {
+      VLOG(1) << "revising body of function:" << std::endl
+              << PrettyPrint(GetRef<Function>(function_node)) << std::endl
+              << "to be fixed to device " << props.device_type;
+      body = OnDevice(props.body, props.device_type, /*is_fixed=*/true);
+    }
+    // TODO(mbs): Avoid copy
+    return Function(function_node->params, body, function_node->ret_type,
+                    function_node->type_params, function_node->attrs, function_node->span);
+  }
+};
+
+/******
+****** Phase 1
+******/
+
+/*!
+ * \brief Collects the system of device constraints for all sub-expressions in a module.
+ * It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter.
+ */
+class DeviceAnalyzer : public ExprVisitor {
+ public:
+  explicit DeviceAnalyzer(IRModule mod)
+      : mod_(std::move(mod)), domains_(std::make_unique<DeviceDomains>()) {}
+
+  /*!
+   * \brief Returns the expression-to-device-domain map for all expressions in all the global
+   * function definitions in the module. Expressions may have free domains, these will be resolved
+   * by \p DeviceDefaulter below.
+   */
+  std::unique_ptr<DeviceDomains> Analyze() {
+    VLOG_CONTEXT << "DeviceAnalyzer";
+    for (const auto& pair : mod_->functions) {
+      VLOG(1) << "collecting constraints for '" << PrettyPrint(pair.first) << "'";
+      domains_->UnifyExprExact(pair.first, pair.second);
+      VisitExpr(pair.second);
+    }
+    return std::move(domains_);
+  }
+
+ private:
+  void VisitExpr_(const CallNode* call_node) final {
+    auto call = GetRef<Call>(call_node);
+
+    // Find the higher-order domain for the callee. See DomainForCallee for the special rules
+    // for primitives.
+    VisitExpr(call_node->op);
+    auto func_domain = domains_->DomainForCallee(call);  // higher-order
+
+    // Build the domain for the function implied by its arguments and call context.
+    ICHECK_EQ(func_domain->function_arity(), call_node->args.size());
+    std::vector<DeviceDomainPtr> args_and_result_domains;
+    args_and_result_domains.reserve(call_node->args.size() + 1);
+    for (const auto& arg : call_node->args) {
+      args_and_result_domains.emplace_back(domains_->DomainFor(arg));
+      VisitExpr(arg);
+    }
+    args_and_result_domains.emplace_back(domains_->DomainFor(call));
+    auto implied_domain =
+        DeviceDomains::MakeDomain(std::move(args_and_result_domains));  // higher-order
+
+    VLOG(1) << "initial call function domain:" << std::endl
+            << domains_->ToString(func_domain) << std::endl
+            << "and implied domain:" << std::endl
+            << domains_->ToString(implied_domain) << "for call:" << std::endl
+            << PrettyPrint(call);
+
+    // The above must match.
+    try {
+      domains_->Unify(func_domain, implied_domain);  // higher-order
+    } catch (const Error& e) {
+      // TODO(mbs): Proper diagnostics.
+      LOG(FATAL) << "Function parameters and result devices do not match those of call. Call:"

Review comment:
       !! fatal! probably we need to address the TODO before merging...i think?

##########
File path: tests/python/relay/test_pass_plan_devices.py
##########
@@ -0,0 +1,1405 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License
+
+
+"""Unit tests for the PlanDevices pass. We check:
+    - The pass rewrites the input module to the expected module (InferType is run manually on
+      both before comparing).
+    - The pass is idempotent."""
+
+# TODO(mbs): All the input/expected programs should be directly quoted using @script
+# TODO(mbs): Not testing Match and Constructor since not supported by Python bindings?
+# TODO(mbs): Add back reference implementation tests once VM is ready.
+
+import tvm
+from tvm import relay
+import tvm.testing
+import numpy as np
+
+N = 5
+M = 7
+CPU = tvm.device("cpu")  # device_type=1
+GPU = tvm.device("cuda")  # device_type=2
+DEFAULT = GPU
+
+
+def rewrite_and_assert(in_mod, expected_mod):
+    """Manually run the pass and assert it's structurally equals to the expected."""
+    actual_mod = relay.transform.InferType()(in_mod)
+    actual_mod = relay.transform.PlanDevices(DEFAULT)(actual_mod)
+    actual_mod = relay.transform.InferType()(actual_mod)
+    expected_mod = relay.transform.InferType()(expected_mod)
+    if not tvm.ir.structural_equal(actual_mod, expected_mod):
+        # Print everything in full so we can see what's going on when things fail.
+        print("Input module:")
+        print(in_mod)
+        print("Expected module:")
+        print(expected_mod)
+        print("Actual module:")
+        print(actual_mod)
+        # Assert again so as to see the actual disagreeing sub-expressions.
+        tvm.ir.assert_structural_equal(actual_mod, expected_mod)
+
+
+def rand(shape):
+    return np.random.rand(*shape).astype("float32")
+
+
+def rands(shape, n):
+    return [rand(shape) for i in range(n)]
+
+
+def exercise(in_mod: tvm.IRModule, expected_mod: tvm.IRModule, reference_func, args):
+    """Test in_mod against expected_mod and reference_func using args."""
+    # Correctness
+    rewrite_and_assert(in_mod, expected_mod)
+    # Idempotence
+    rewrite_and_assert(expected_mod, expected_mod)
+    # TODO(mbs): Add back compiling and comparing to reference implementation once VM is ready.
+
+
+#
+# Annotation shorthands
+#
+
+
+def on_cpu(expr: relay.Expr):
+    return relay.annotation.on_device(expr, CPU)
+
+
+def on_gpu(expr: relay.Expr):
+    return relay.annotation.on_device(expr, GPU)
+
+
+def cpu_to_gpu(expr: relay.Expr):
+    return relay.op.device_copy(expr, CPU, GPU)
+
+
+def gpu_to_cpu(expr: relay.Expr):
+    return relay.op.device_copy(expr, GPU, CPU)
+
+
+def fixed_cpu(expr: relay.Expr):
+    return relay.annotation.on_device(expr, CPU, True)
+
+
+def fixed_gpu(expr: relay.Expr):
+    return relay.annotation.on_device(expr, GPU, True)
+
+
+def test_plain():
+    shape = (N, M)
+    a = relay.var("a", shape=shape)
+    b = relay.var("b", shape=shape)
+    c = relay.var("c", shape=shape)
+    d = relay.var("d", shape=shape)
+
+    # def @main(a, b, c, d) { subtract(add(a, b), add(c, d)) }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function([a, b, c, d], relay.subtract(relay.add(a, b), relay.add(c, d)))
+        )
+
+    # def @main(a, b, c, d, on_device={param_device_types=[2,2,2,2], result_device_type=2}) {
+    #   subtract(add(a, b), add(c, d))
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function([a, b, c, d], relay.subtract(relay.add(a, b), relay.add(c, d))),
+                [GPU, GPU, GPU, GPU],
+                GPU,
+            )
+        )
+
+    def ref(a, b, c, d):
+        return np.subtract(np.add(a, b), np.add(c, d))
+
+    exercise(input(), expected(), ref, rands(shape, 4))
+
+
+def test_left_add_on_cpu():
+    shape = (N, M)
+    a = relay.var("a", shape=shape)
+    b = relay.var("b", shape=shape)
+    c = relay.var("c", shape=shape)
+    d = relay.var("d", shape=shape)
+
+    # def @main(a, b, c, d) { subtract(on_cpu(add(a, b)), add(c, d)) }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function([a, b, c, d], relay.subtract(on_cpu(relay.add(a, b)), relay.add(c, d)))
+        )
+
+    # def @main(a, b, c, d, on_device={param_device_types=[1,1,2,2], result_device_type=2}) {
+    #    subtract(cpu_to_gpu(fixed_cpu(add(a, b))), add(c, d))
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [a, b, c, d],
+                    relay.subtract(cpu_to_gpu(fixed_cpu(relay.add(a, b))), relay.add(c, d)),
+                ),
+                [CPU, CPU, GPU, GPU],
+                GPU,
+            )
+        )
+
+    def ref(a, b, c, d):
+        return np.subtract(np.add(a, b), np.add(c, d))
+
+    exercise(input(), expected(), ref, rands(shape, 4))
+
+
+def test_left_add_on_cpu_via_copy():
+    shape = (N, M)
+    a = relay.var("a", shape=shape)
+    b = relay.var("b", shape=shape)
+    c = relay.var("c", shape=shape)
+    d = relay.var("d", shape=shape)
+
+    # def @main(a, b, c, d) { subtract(cpu_to_gpu(add(a, b)), add(c, d)) }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [a, b, c, d], relay.subtract(cpu_to_gpu(relay.add(a, b)), relay.add(c, d))
+            )
+        )
+
+    # def @main(a, b, c, d, on_device={param_device_types=[1,1,2,2], result_device_type=2}) {
+    #    subtract(cpu_to_gpu(fixed_cpu(add(a, b))), add(c, d))
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [a, b, c, d],
+                    relay.subtract(cpu_to_gpu(fixed_cpu(relay.add(a, b))), relay.add(c, d)),
+                ),
+                [CPU, CPU, GPU, GPU],
+                GPU,
+            )
+        )
+
+    def ref(a, b, c, d):
+        return np.subtract(np.add(a, b), np.add(c, d))
+
+    exercise(input(), expected(), ref, rands(shape, 4))
+
+
+def test_both_adds_on_cpu():
+    shape = (N, M)
+    a = relay.var("a", shape=shape)
+    b = relay.var("b", shape=shape)
+    c = relay.var("c", shape=shape)
+    d = relay.var("d", shape=shape)
+
+    # def @main(a, b, c, d) { subtract(on_cpu(add(a, b)), on_cpu(add(c, d))) }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [a, b, c, d], relay.subtract(on_cpu(relay.add(a, b)), on_cpu(relay.add(c, d)))
+            )
+        )
+
+    # def @main(a, b, c, d, on_device={param_device_types=[1,1,1,1], result_device_type=2}) {
+    #    subtract(cpu_to_gpu(fixed_cpu(add(a, b))), cpu_to_gpu(fixed_cpu(add(c, d))))
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [a, b, c, d],
+                    relay.subtract(
+                        cpu_to_gpu(fixed_cpu(relay.add(a, b))),
+                        cpu_to_gpu(fixed_cpu(relay.add(c, d))),
+                    ),
+                ),
+                [CPU, CPU, CPU, CPU],
+                GPU,
+            )
+        )
+
+    def ref(a, b, c, d):
+        return np.subtract(np.add(a, b), np.add(c, d))
+
+    exercise(input(), expected(), ref, rands(shape, 4))
+
+
+def test_sharing():
+    shape = (N, M)
+    a = relay.var("a", shape=shape)
+    b = relay.var("b", shape=shape)
+
+    # def @main(a, b) {
+    #   %0 = add(a, b)
+    #   subtract(on_cpu(%0), on_cpu(%0)) }
+    def input():
+        add = relay.add(a, b)
+        return tvm.IRModule.from_expr(
+            relay.Function([a, b], relay.subtract(on_cpu(add), on_cpu(add)))
+        )
+
+    # def @main(a, b, on_device={param_device_types=[1,1], result_device_type=2}) {
+    #    %0 = add(a, b)
+    #    subtract(cpu_to_gpu(fixed_cpu(%0)), cpu_to_gpu(fixed_cpu(%0)))
+    def expected():
+        add = relay.add(a, b)
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [a, b], relay.subtract(cpu_to_gpu(fixed_cpu(add)), cpu_to_gpu(fixed_cpu(add)))
+                ),
+                [CPU, CPU],
+                GPU,
+            )
+        )
+
+    def ref(a, b):
+        x = np.add(a, b)
+        return np.subtract(x, x)
+
+    exercise(input(), expected(), ref, rands(shape, 2))
+
+
+def test_let_on_cpu():
+    shape = (N, M)
+    a = relay.var("a", shape=shape)
+    b = relay.var("b", shape=shape)
+    c = relay.var("c", shape=shape)
+    d = relay.var("d", shape=shape)
+    l = relay.Var("l")
+    r = relay.Var("r")
+
+    # def @main(a, b, c, d) {
+    #   let l = add(a, b);
+    #   let r = add(c, d);
+    #   subtract(on_cpu(l), r)
+    # }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [a, b, c, d],
+                relay.Let(
+                    l, relay.add(a, b), relay.Let(r, relay.add(c, d), relay.subtract(on_cpu(l), r))
+                ),
+            )
+        )
+
+    # def @main(a, b, c, d, on_device={param_device_types=[1,1,2,2], result_device_type=2}) {
+    #    let l = fixed_cpu(add(a, b));
+    #    let r = add(c, d);
+    #    subtract(cpu_to_gpu(l), r)
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [a, b, c, d],
+                    relay.Let(
+                        l,
+                        fixed_cpu(relay.add(a, b)),
+                        relay.Let(r, relay.add(c, d), relay.subtract(cpu_to_gpu(l), r)),
+                    ),
+                ),
+                [CPU, CPU, GPU, GPU],
+                GPU,
+            )
+        )
+
+    def ref(a, b, c, d):
+        return np.subtract(np.add(a, b), np.add(c, d))
+
+    exercise(input(), expected(), ref, rands(shape, 4))
+
+
+def test_func_param_on_cpu():
+    shape = (N, M)
+    a = relay.var("a", shape=shape)
+    b = relay.var("b", shape=shape)
+    c = relay.var("c", shape=shape)
+    d = relay.var("d", shape=shape)
+    f = relay.Var("f")
+    x = relay.Var("x")
+    y = relay.Var("y")
+
+    # def @main(a, b, c, d) {
+    #   let f = fn(x, y) { on_cpu(add(x, y)) }   -- forces both body and result on CPU
+    #   subtract(f(a, b), add(c, d))
+    # }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [a, b, c, d],
+                relay.Let(
+                    f,
+                    relay.Function([x, y], on_cpu(relay.add(x, y))),
+                    relay.subtract(relay.Call(f, [a, b]), relay.add(c, d)),
+                ),
+            )
+        )
+
+    # def @main(a, b, c, d, on_device={param_device_types=[1,1,1,1], result_device_type=1}) {
+    #   let f = fn(x, y, on_device={param_device_types[1,1], result_device_type=1}) {
+    #     add(x, y)
+    #   };
+    #   subtract(f(a, b), add(c, d))
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [a, b, c, d],
+                    relay.Let(
+                        f,
+                        relay.annotation.function_on_device(
+                            relay.Function([x, y], relay.add(x, y)), [CPU, CPU], CPU
+                        ),
+                        relay.subtract(relay.Call(f, [a, b]), relay.add(c, d)),
+                    ),
+                ),
+                [CPU, CPU, CPU, CPU],
+                CPU,
+            )
+        )
+
+    def ref(a, b, c, d):
+        return np.subtract(np.add(a, b), np.add(c, d))
+
+    exercise(input(), expected(), ref, rands(shape, 4))
+
+
+def test_func_result_on_cpu():
+    shape = (N, M)
+    a = relay.var("a", shape=shape)
+    b = relay.var("b", shape=shape)
+    c = relay.var("c", shape=shape)
+    d = relay.var("d", shape=shape)
+    f = relay.Var("f")
+    x = relay.Var("x")
+    y = relay.Var("y")
+
+    # def @main(a, b, c, d) {
+    #   let f = fn(x, y) { add(x, y) }
+    #   subtract(on_cpu(f(a, b)), add(c, d))
+    # }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [a, b, c, d],
+                relay.Let(
+                    f,
+                    relay.Function([x, y], relay.add(x, y)),
+                    relay.subtract(on_cpu(relay.Call(f, [a, b])), relay.add(c, d)),
+                ),
+            )
+        )
+
+    # def @main(a, b, c, d, on_device={param_device_types=[1,1,2,2], result_device_type=2}) {
+    #   let f = fixed_cpu(fn(x, y, on_device={param_device_types=[1,1], result_device_type=1}) {
+    #     add(x, y)
+    #   });
+    #   subtract(cpu_to_gpu(fixed_cpu(f(a, b))), add(c, d))
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [a, b, c, d],
+                    relay.Let(
+                        f,
+                        fixed_cpu(
+                            relay.annotation.function_on_device(
+                                relay.Function([x, y], relay.add(x, y)), [CPU, CPU], CPU
+                            )
+                        ),
+                        relay.subtract(
+                            cpu_to_gpu(fixed_cpu(relay.Call(f, [a, b]))), relay.add(c, d)
+                        ),
+                    ),
+                ),
+                [CPU, CPU, GPU, GPU],
+                GPU,
+            )
+        )
+
+    def ref(a, b, c, d):
+        return np.subtract(np.add(a, b), np.add(c, d))
+
+    exercise(input(), expected(), ref, rands(shape, 4))
+
+
+def test_higher_order():
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
+    f = relay.Var("f")
+    g = relay.Var("g")
+    a = relay.Var("a")
+    h = relay.Var("h")
+    b = relay.Var("b")
+
+    # The constraint on a flows back to y via f and h
+    # def @main(x, y) {
+    #   let f = fn(g) { fn(a) { add(g(on_cpu(a)), x) } }
+    #   let h = fn(b) { negative(b) }
+    #   subtract(x, f(h)(y))
+    # }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [x, y],
+                relay.Let(
+                    f,
+                    relay.Function(
+                        [g], relay.Function([a], relay.add(relay.Call(g, [on_cpu(a)]), x))
+                    ),
+                    relay.Let(
+                        h,
+                        relay.Function([b], relay.negative(b)),
+                        relay.subtract(x, relay.Call(relay.Call(f, [h]), [y])),
+                    ),
+                ),
+            )
+        )
+
+    # def @main(x, y, on_device={param_device_types=[GPU, CPU], result_device_type=GPU}) {
+    #   let f = fn(g, on_device={param_device_types=[GPU], result_device_type=GPU}) {
+    #     fn(a, on_device={param_device_types=[CPU], result_device_type=GPU}) {
+    #       add(g(cpu_to_gpu(a)), x)
+    #     }
+    #   }
+    #   let h = fn(b, on_device={param_device_types=[GPU], result_device_type=GPU}) { negative(b) }
+    #   subtract(x, f(h)(y))
+    # }
+    def expected():
+        # Yeah, this is illegible.
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [x, y],
+                    relay.Let(
+                        f,
+                        relay.annotation.function_on_device(
+                            relay.Function(
+                                [g],
+                                relay.annotation.function_on_device(
+                                    relay.Function(
+                                        [a], relay.add(relay.Call(g, [cpu_to_gpu(a)]), x)
+                                    ),
+                                    [CPU],
+                                    GPU,
+                                ),
+                            ),
+                            [GPU],
+                            GPU,
+                        ),
+                        relay.Let(
+                            h,
+                            relay.annotation.function_on_device(
+                                relay.Function([b], relay.negative(b)), [GPU], GPU
+                            ),
+                            relay.subtract(x, relay.Call(relay.Call(f, [h]), [y])),
+                        ),
+                    ),
+                ),
+                [GPU, CPU],
+                GPU,
+            )
+        )
+
+    def ref(x, y):
+        def f(g):
+            return lambda a: np.add(g(a), x)
+
+        def h(b):
+            return np.negative(b)
+
+        return np.subtract(x, f(h)(y))
+
+    exercise(input(), expected(), ref, rands(shape, 2))
+
+
+def test_function_in_tuple():
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
+    a = relay.var("a", shape=shape)
+    b = relay.var("b", shape=shape)
+    y = relay.var("y", shape=shape)
+    f = relay.Var("f")
+    t = relay.Var("t")
+
+    # Since f ends up in a tuple, its arguments and result are forced to be on the CPU
+    # def @main(x, y) {
+    #   let f = fn(a, b) { add(a, on_cpu(b)) }
+    #   let t = (f, x)
+    #   t.0(t.1, y)
+    # }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [x, y],
+                relay.Let(
+                    f,
+                    relay.Function([a, b], relay.add(a, on_cpu(b))),
+                    relay.Let(
+                        t,
+                        relay.Tuple([f, x]),
+                        relay.Call(relay.TupleGetItem(t, 0), [relay.TupleGetItem(t, 1), y]),
+                    ),
+                ),
+            )
+        )
+
+    # def @main(x, y, on_device={param_device_types=[1,1], result_device_type=1}) {
+    #   let f = fn(a, b, on_device={param_device_types=[1,1], result_device_type=1}) { add(a, b) }
+    #   let t = (f, x)
+    #   t.0(t.1, y)
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [x, y],
+                    relay.Let(
+                        f,
+                        relay.annotation.function_on_device(
+                            relay.Function([a, b], relay.add(a, b)), [CPU, CPU], CPU
+                        ),
+                        relay.Let(
+                            t,
+                            relay.Tuple([f, x]),
+                            relay.Call(relay.TupleGetItem(t, 0), [relay.TupleGetItem(t, 1), y]),
+                        ),
+                    ),
+                ),
+                [CPU, CPU],
+                CPU,
+            )
+        )
+
+    def ref(x, y):
+        return np.add(x, y)
+
+    exercise(input(), expected(), ref, rands(shape, 2))
+
+
+def test_device_copy():
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+    const = relay.const(rand(shape))
+
+    # def @main(x) { add(cpu_to_gpu(x), const) }
+    def input():
+        return tvm.IRModule.from_expr(relay.Function([x], relay.add(cpu_to_gpu(x), const)))
+
+    # def @main(x, on_device={param_device_types=[1], result_device_type=2}) {
+    #   add(cpu_to_gpu(x), constant)
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function([x], relay.add(cpu_to_gpu(x), const)), [CPU], GPU
+            )
+        )
+
+    def ref(x):
+        return np.add(x, const.data.numpy())
+
+    exercise(input(), expected(), ref, rands(shape, 1))
+
+
+def test_shape_func():
+    p = relay.var("p")
+    data_shape = (relay.Any(),)
+    x = relay.var("x", shape=data_shape)
+    y = relay.var("y", shape=data_shape)
+    s = relay.var("s", shape=(1,), dtype="int64")
+
+    # def @main(x, s) {
+    #   let p = fixed_gpu(fn(y) { relu(y) })    -- simulates a primitive post FuseOps
+    #   shape_func(p,
+    #              (shape_of(fixed_gpu(x)),),   -- shape of primitive input tensor
+    #              (s,),                        -- space for output shape
+    #              [False])                     -- calling with input shapes not tensors
+    # }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [x, s],
+                relay.Let(
+                    p,
+                    fixed_gpu(relay.Function([y], relay.nn.relu(y))),
+                    relay.op.vm.shape_func(
+                        p,
+                        relay.Tuple([relay.op.vm.shape_of(fixed_gpu(x))]),
+                        relay.Tuple([s]),
+                        [False],
+                    ),
+                ),
+            )
+        )
+
+    # def @main(x, s, on_device={param_device_types=[2,1], result_device_type=1}) {
+    #   let p = fixed_gpu(fn(y, param_device_types=[2], result_device_type=2) { relu(y) })
+    #   shape_func(p,
+    #              (shape_of(x),),
+    #              (s,),
+    #              [False])
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [x, s],
+                    relay.Let(
+                        p,
+                        fixed_gpu(
+                            relay.annotation.function_on_device(
+                                relay.Function([y], relay.nn.relu(y)), [GPU], GPU
+                            )
+                        ),
+                        relay.op.vm.shape_func(
+                            p, relay.Tuple([relay.op.vm.shape_of(x)]), relay.Tuple([s]), [False]
+                        ),
+                    ),
+                ),
+                [GPU, CPU],
+                CPU,
+            )
+        )
+
+    # Don't try to execute; too fiddly to set up.
+    exercise(input(), expected(), None, None)
+
+
+def test_shape_of():
+    compiletime_shape = (relay.Any(), relay.Any())
+    runtime_shape = (N, M)
+    x = relay.var("x", shape=compiletime_shape)
+
+    # We need fixed_gpu since the result of on_gpu would otherwise default to the result device
+    # of @main (cpu), which would then force a copy.
+    # TODO(mbs): Perhaps the defaulting heuristics are being too clever?
+    # def @main(x) { shape_of(fixed_gpu(x)) }
+    def input():
+        return tvm.IRModule.from_expr(relay.Function([x], relay.op.vm.shape_of(fixed_gpu(x))))
+
+    # def @main(x, on_device={param_device_types=[2], result_device_type=1}) {
+    #   shape_of(x)
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function([x], relay.op.vm.shape_of(x)), [GPU], CPU
+            )
+        )
+
+    def ref(x):
+        return x.shape
+
+    exercise(input(), expected(), ref, rands(runtime_shape, 1))
+
+
+def test_alloc_storage():
+    size = relay.Var("size", relay.scalar_type("int64"))
+    alignment = relay.Var("alignment", relay.scalar_type("int64"))
+    main = relay.GlobalVar("main")
+    stdlib = tvm.IRModule()
+    stdlib.import_from_std("core.rly")
+
+    # def @main(size, alignment) { alloc_storage(size, alignment, GPU) }
+    def input():
+        mod = tvm.IRModule()
+        mod.update(stdlib)
+        mod[main] = relay.Function(
+            [size, alignment], relay.op.memory.alloc_storage(size, alignment, GPU)
+        )
+        return mod
+
+    # def @main(size, alignment, on_device={param_device_types=[1,1], result_device_type=2}) {
+    #   alloc_storage(size, alignment, GPU)
+    # }
+    def expected():
+        mod = tvm.IRModule()
+        mod.update(stdlib)
+        mod[main] = relay.annotation.function_on_device(
+            relay.Function([size, alignment], relay.op.memory.alloc_storage(size, alignment, GPU)),
+            [CPU, CPU],
+            GPU,
+        )
+        return mod
+
+    # Don't try to execute; too fiddly to set up.
+    exercise(input(), expected(), None, None)
+
+
+def test_alloc_tensor():
+    stdlib = tvm.IRModule()
+    stdlib.import_from_std("core.rly")
+    sto_type = relay.TypeCall(stdlib.get_global_type_var("Storage"), [])
+    sto = relay.Var("sto", sto_type)
+    main = relay.GlobalVar("main")
+    shape = relay.const(np.array([3, 2]), dtype="int64")
+
+    # def @main(sto) { alloc_tensor(sto, 0, [3, 2]) }
+    def input():
+        mod = tvm.IRModule()
+        mod.update(stdlib)
+        mod[main] = relay.Function(
+            [sto], relay.op.memory.alloc_tensor(sto, relay.const(0, dtype="int64"), shape)
+        )
+        return mod
+
+    # def @main(sto, on_device={param_device_types=[2], result_device_type=2}) {
+    #   alloc_tensor(sto, fixed_cpu(0), fixed_cpu([3, 2]))
+    # }
+    def expected():
+        mod = tvm.IRModule()
+        mod.update(stdlib)
+        mod[main] = relay.annotation.function_on_device(
+            relay.Function(
+                [sto],
+                relay.op.memory.alloc_tensor(
+                    sto, fixed_cpu(relay.const(0, dtype="int64")), fixed_cpu(shape)
+                ),
+            ),
+            [GPU],
+            GPU,
+        )
+        return mod
+
+    # Don't try to execute; too fiddly to set up.
+    exercise(input(), expected(), None, None)
+
+
+def test_reshape_tensor():
+    shape = (2, 8)
+    x = relay.var("x", shape=shape, dtype="float32")
+    newshape_expr = relay.const([2, 4, 2], dtype="int64")
+    newshape_prim = [2, 4, 2]
+
+    # def @main(x) { reshape_tensor(x, shape, newshape=[2,4,2]) }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function([x], relay.op.vm.reshape_tensor(x, newshape_expr, newshape_prim))
+        )
+
+    # def @main(x, on_device={param_device_types=[2], result_device_type=2}) {
+    #   reshape_tensor(x, fixed_cpu(shape), newshape=[2,4,2])
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [x], relay.op.vm.reshape_tensor(x, fixed_cpu(newshape_expr), newshape_prim)
+                ),
+                [GPU],
+                GPU,
+            )
+        )
+
+    def ref(x):
+        return np.reshape(x, newshape_prim)
+
+    exercise(input(), expected(), ref, rands(shape, 1))
+
+
+def test_dynamic_input():
+    compiletime_shape = (relay.Any(), relay.Any())
+    runtime_shape = (N, M)
+    x0 = relay.var("x0", shape=compiletime_shape)
+    x1 = relay.var("x1", shape=compiletime_shape)
+
+    # def @main(x0, x1) { add(x0, x1) }
+    def input():
+        return tvm.IRModule.from_expr(relay.Function([x0, x1], relay.add(x0, x1)))
+
+    # def @main(x0, x1, on_device={param_device_types=[2,2], result_device_type=2}) {
+    #   add(x0, x1)
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function([x0, x1], relay.add(x0, x1)), [GPU, GPU], GPU
+            )
+        )
+
+    def ref(x0, x1):
+        return np.add(x0, x1)
+
+    exercise(input(), expected(), ref, rands(runtime_shape, 2))
+
+
+def test_redundant_annotation():
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
+    z = relay.var("z", shape=shape)
+
+    # def @main(x, y, z) {
+    #   %0 = add(x, y)
+    #   add(subtract(on_cpu(%0), z), on_cpu(%0))
+    # }
+    def input():
+        a = relay.add(x, y)
+        return tvm.IRModule.from_expr(
+            relay.Function([x, y, z], relay.add(relay.subtract(on_cpu(a), z), on_cpu(a)))
+        )
+
+    # def @main(x, y, z, on_device={param_device_types=[1,1,2], result_device_type=2}) {
+    #   %0 = add(x, y)
+    #   add(subtract(cpu_to_gpu(fixed_cpu(%0)), z), cpu_to_gpu(fixed_cpu(%0)))
+    # }
+    def expected():
+        a = relay.add(x, y)
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [x, y, z],
+                    relay.add(
+                        relay.subtract(cpu_to_gpu(fixed_cpu(a)), z), cpu_to_gpu(fixed_cpu(a))
+                    ),
+                ),
+                [CPU, CPU, GPU],
+                GPU,
+            )
+        )
+
+    def ref(x, y, z):
+        a = np.add(x, y)
+        return np.add(np.subtract(a, z), a)
+
+    exercise(input(), expected(), ref, rands(shape, 3))
+
+
+def test_annotate_expr():
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
+    z = relay.var("z", shape=shape)
+
+    # def @main(x, y, z) { on_cpu(subtract(on_gpu(add(x, y)), z)) } -- forces function result also on cpu
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function([x, y, z], on_cpu(relay.subtract(on_gpu(relay.add(x, y)), z)))
+        )
+
+    # def @main(x, y, z, on_device={param_device_types=[2,2,1], result_device_type=1}) {
+    #   subtract(gpu_to_cpu(fixed_gpu(add(x, y))), z)
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [x, y, z], relay.subtract(gpu_to_cpu(fixed_gpu(relay.add(x, y))), z)
+                ),
+                [GPU, GPU, CPU],
+                CPU,
+            )
+        )
+
+    def ref(x, y, z):
+        return np.subtract(np.add(x, y), z)
+
+    exercise(input(), expected(), ref, rands(shape, 3))
+
+
+def test_annotate_all():
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
+    z = relay.var("z", shape=shape)
+
+    # def @main(x, y, z) { on_cpu(subtract(on_cpu(add(x, y)), z)) }  -- top-level also forces result to be CPU
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function([x, y, z], on_cpu(relay.subtract(on_cpu(relay.add(x, y)), z)))
+        )
+
+    # def @main(x, y, z, on_device={param_device_types=[CPU, CPU, CPU], result_device_type=CPU}) {
+    #   subtract(add(x, y), z)
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function([x, y, z], relay.subtract(relay.add(x, y), z)), [CPU, CPU, CPU], CPU
+            )
+        )
+
+    def ref(x, y, z):
+        return np.subtract(np.add(x, y), z)
+
+    exercise(input(), expected(), ref, rands(shape, 3))
+
+
+def test_conv_network():
+    r"""The network and devices are as follows:
+    data1     data2    <--- CPU
+      |         |
+    conv2d    conv2d   <--- CPU
+       \       /
+        \     /
+          add          <--- GPU
+           |
+         conv2d        <--- CPU
+           |
+        <result>       <--- CPU
+    """
+    batch_size = 1
+    dshape = (batch_size, 64, 56, 56)
+    wshape = (64, 64, 3, 3)
+    weight = relay.var("weight", shape=wshape)
+    data1 = relay.var("data1", shape=dshape)
+    data2 = relay.var("data2", shape=dshape)
+
+    def input():
+        conv2d_1 = relay.nn.conv2d(data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
+        conv2d_2 = relay.nn.conv2d(data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
+        add = relay.add(on_cpu(conv2d_1), on_cpu(conv2d_2))
+        conv2d_3 = relay.nn.conv2d(
+            on_gpu(add), weight, channels=64, kernel_size=(3, 3), padding=(1, 1)
+        )
+        return tvm.IRModule.from_expr(relay.Function([data1, data2, weight], on_cpu(conv2d_3)))
+
+    def expected():
+        conv2d_1 = relay.nn.conv2d(data1, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
+        conv2d_2 = relay.nn.conv2d(data2, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
+        add = relay.add(cpu_to_gpu(fixed_cpu(conv2d_1)), cpu_to_gpu(fixed_cpu(conv2d_2)))
+        conv2d_3 = relay.nn.conv2d(
+            gpu_to_cpu(fixed_gpu(add)), weight, channels=64, kernel_size=(3, 3), padding=(1, 1)
+        )
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function([data1, data2, weight], conv2d_3), [CPU, CPU, CPU], CPU
+            )
+        )
+
+    # Don't try to execute; we don't have a reference conv2d implementation.
+    exercise(input(), expected(), None, None)
+
+
+def test_tuple_get_item():
+    shape = (3, 3, 4)
+    x = relay.Var("x", relay.ty.TensorType(shape, "float32"))
+    t = relay.Var("t")
+
+    # We'll insert the device copy after the projection, not before.
+    # def @main(x) {
+    #   let t = split(x, 3);
+    #   on_gpu(subtract(on_cpu(t).0, on_cpu(t).1))
+    # }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [x],
+                relay.Let(
+                    t,
+                    relay.op.split(x, 3).astuple(),
+                    on_gpu(
+                        relay.subtract(
+                            relay.TupleGetItem(on_cpu(t), 0), relay.TupleGetItem(on_cpu(t), 1)
+                        )
+                    ),
+                ),
+            )
+        )
+
+    # def @main(x, on_device={param_device_types=[1], result_device_type=2}) {
+    #   let t = fixed_cpu(split(x, 3))
+    #   subtract(cpu_to_gpu(fixed_cpu(t.0)), cpu_to_gpu(fixed_cpu(t.1)))
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [x],
+                    relay.Let(
+                        t,
+                        fixed_cpu(relay.op.split(x, 3).astuple()),
+                        relay.subtract(
+                            cpu_to_gpu(fixed_cpu(relay.TupleGetItem(t, 0))),
+                            cpu_to_gpu(fixed_cpu(relay.TupleGetItem(t, 1))),
+                        ),
+                    ),
+                ),
+                [CPU],
+                GPU,
+            )
+        )
+
+    def ref(x):
+        t = np.split(x, 3)
+        return np.subtract(t[0], t[1])
+
+    exercise(input(), expected(), ref, rands(shape, 1))
+
+
+def test_propagation():
+    r"""The network and devices are as follows:
+                  x           <--- CPU
+                  |
+                 log          <--- CPU
+                /   \
+              log2 log10      <--- GPU
+                \   /
+                 add          <--- GPU
+                  |
+                 tan          <--- CPU
+                  |
+               <result>       <--- CPU
+    """
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+
+    def input():
+        log = relay.log(x)
+        log2 = relay.log2(on_cpu(log))
+        log10 = relay.log10(on_cpu(log))
+        add = relay.add(on_gpu(log2), on_gpu(log10))
+        tan = relay.tan(on_gpu(add))
+        return tvm.IRModule.from_expr(relay.Function([x], on_cpu(tan)))
+
+    def expected():
+        log = relay.log(x)
+        log2 = relay.log2(cpu_to_gpu(fixed_cpu(log)))
+        log10 = relay.log10(cpu_to_gpu(fixed_cpu(log)))
+        add = relay.add(log2, log10)
+        tan = relay.tan(gpu_to_cpu(fixed_gpu(add)))
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(relay.Function([x], tan), [CPU], CPU)
+        )
+
+    def ref(x):
+        y = np.log(x)
+        return np.tan(np.add(np.log2(y), np.log10(y)))
+
+    exercise(input(), expected(), ref, rands(shape, 1))
+
+
+def test_fusible_network():
+    R""" The network is as follows:
+               x     y      <--- GPU
+                \   /
+                 add        <--- GPU
+                /   \
+           negative  \      <--- CPU
+              \       \
+               \  negative  <--- GPU
+                \   /
+                 add        <--- GPU
+                  |
+               negative     <--- CPU
+                  |
+               <result>     <--- CPU
+    """
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
+
+    def input():
+        add1 = relay.add(x, y)
+        neg1 = relay.negative(on_gpu(add1))
+        neg2 = relay.negative(add1)
+        add2 = relay.add(on_cpu(neg1), neg2)
+        neg3 = relay.negative(on_gpu(add2))
+        return tvm.IRModule.from_expr(relay.Function([x, y], on_cpu(neg3)))
+
+    def expected():
+        add1 = relay.add(x, y)
+        neg1 = relay.negative(gpu_to_cpu(fixed_gpu(add1)))
+        neg2 = relay.negative(add1)
+        add2 = relay.add(cpu_to_gpu(fixed_cpu(neg1)), neg2)
+        neg3 = relay.negative(gpu_to_cpu(fixed_gpu(add2)))
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(relay.Function([x, y], neg3), [GPU, GPU], CPU)
+        )
+
+    def ref(x, y):
+        z = np.add(x, y)
+        return np.negative(np.add(np.negative(z), np.negative(z)))
+
+    exercise(input(), expected(), ref, rands(shape, 2))
+
+
+def test_unpropagatable_graph():
+    r"""The network is as follows:
+    a      b            <--- CPU
+    \     /
+     \   /   c     d    <--- GPU
+      \ /    \     /
+      add     \   /     <--- CPU
+       \       \ /
+        \    multiply   <--- GPU
+         \     /
+        subtract        <--- CPU
+           |
+        <result>        <--- CPU
+    """
+    shape = (N, M)
+    a = relay.var("a", shape=shape)
+    b = relay.var("b", shape=shape)
+    c = relay.var("c", shape=shape)
+    d = relay.var("d", shape=shape)
+
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [a, b, c, d],
+                on_cpu(relay.subtract(on_cpu(relay.add(a, b)), on_gpu(relay.multiply(c, d)))),
+            )
+        )
+
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [a, b, c, d],
+                    relay.subtract(relay.add(a, b), gpu_to_cpu(fixed_gpu(relay.multiply(c, d)))),
+                ),
+                [CPU, CPU, GPU, GPU],
+                CPU,
+            )
+        )
+
+    def ref(a, b, c, d):
+        return np.subtract(np.add(a, b), np.multiply(c, d))
+
+    exercise(input(), expected(), ref, rands(shape, 4))
+
+
+def test_conditional():
+    shape = (N, M)
+    x = relay.Var("x", relay.ty.scalar_type("bool"))
+    y = relay.var("y", shape=shape)
+    z = relay.var("z", shape=shape)
+    f = relay.Var("f")
+    g = relay.Var("g")
+    h = relay.Var("h")
+    a1 = relay.Var("a")
+    a2 = relay.Var("a")
+
+    # def @main(x, y, z) {
+    #   let f = fn(a) { add(a, fixed_cpu(y)) }
+    #   let g = fn(a) { subtract(a, y) }
+    #   let h = if (x) {
+    #     f
+    #   } else {
+    #     g
+    #   }
+    #   h(z)
+    # }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [x, y, z],
+                relay.Let(
+                    f,
+                    relay.Function([a1], relay.add(a1, fixed_cpu(y))),
+                    relay.Let(
+                        g,
+                        relay.Function([a2], relay.subtract(a2, y)),
+                        relay.Let(h, relay.If(x, f, g), relay.Call(h, [z])),
+                    ),
+                ),
+            )
+        )
+
+    # def @main(x, y, z, on_device={param_device_types=[1,1,1], result_device_type=1}) {
+    #   let f = fn(a, on_device={param_device_types=[1], result_device_type=1}) { add(a, y) }
+    #   let g = fn(a, on_device={param_device_types=[1], result_device_type=1}) { subtract(a, y) }
+    #   let h = if (x) {
+    #     f
+    #   } else {
+    #     g
+    #   }
+    #   h(z)
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [x, y, z],
+                    relay.Let(
+                        f,
+                        relay.annotation.function_on_device(
+                            relay.Function([a1], relay.add(a1, y)), [CPU], CPU
+                        ),
+                        relay.Let(
+                            g,
+                            relay.annotation.function_on_device(
+                                relay.Function([a2], relay.subtract(a2, y)), [CPU], CPU
+                            ),
+                            relay.Let(h, relay.If(x, f, g), relay.Call(h, [z])),
+                        ),
+                    ),
+                ),
+                [CPU, CPU, CPU],
+                CPU,
+            )
+        )
+
+    def ref(x, y, z):
+        def f(a):
+            return np.add(a, y)
+
+        def g(a):
+            return np.subtract(a, y)
+
+        h = f if x else g
+        return h(z)
+
+    exercise(input(), expected(), ref, [True, rand(shape), rand(shape)])
+
+
+def test_global():
+    shape = (N, M)
+    a = relay.var("a", shape=shape)
+    b = relay.var("b", shape=shape)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
+    f = relay.GlobalVar("f")
+    main = relay.GlobalVar("main")
+
+    # def @f(a, b) { add(a, on_cpu(b)) }
+    # def @main(x, y) { @f(y, x) }
+    def input():
+        mod = tvm.IRModule()
+        mod[f] = relay.Function(
+            [a, b], relay.add(a, on_cpu(b)), relay.ty.TensorType(shape, "float32")
+        )
+        mod[main] = relay.Function(
+            [x, y], relay.Call(f, [y, x]), relay.ty.TensorType(shape, "float32")
+        )
+        return mod
+
+    # def @f(a, b, on_device={param_device_types=[2,1], result_device_type=2}) { add(a, cpu_to_gpu(b)) }
+    # def @main(x, y, on_device={param_device_types=[1,2], result_device_type=2}) { @f(y, x) }
+    def expected():
+        mod = tvm.IRModule()
+        mod[f] = relay.annotation.function_on_device(
+            relay.Function(
+                [a, b], relay.add(a, cpu_to_gpu(b)), relay.ty.TensorType(shape, "float32")
+            ),
+            [GPU, CPU],
+            GPU,
+        )
+        mod[main] = relay.annotation.function_on_device(
+            relay.Function([x, y], relay.Call(f, [y, x]), relay.ty.TensorType(shape, "float32")),
+            [CPU, GPU],
+            GPU,
+        )
+        return mod
+
+    def ref(x, y):
+        def f(a, b):
+            return np.add(a, b)
+
+        return f(x, y)
+
+    exercise(input(), expected(), ref, rands(shape, 2))
+
+
+# Note that match and ADTs don't appear to be supported for direct AST
+# construction.
+
+
+def test_ref():
+    shape = (N, M)
+    x = relay.var("x", shape=shape)
+    y = relay.var("y", shape=shape)
+    r = relay.var("r")
+    dummy = relay.var("dummy")
+
+    # def @main(x, y) {
+    #   r = ref(x)
+    #   ref_write(r, on_cpu(y))
+    #   add(x, ref_read(r))
+    # }
+    def input():
+        return tvm.IRModule.from_expr(
+            relay.Function(
+                [x, y],
+                relay.Let(
+                    r,
+                    relay.RefCreate(x),
+                    relay.Let(dummy, relay.RefWrite(r, on_cpu(y)), relay.add(x, relay.RefRead(r))),
+                ),
+            )
+        )
+
+    # def @main(x, y, on_device={param_device_types=[GPU, CPU], result_device_type=GPU}) {
+    #   r = ref(x)
+    #   ref_write(r, cpu_to_gpu(y))
+    #   add(x, ref_read(r))
+    # }
+    def expected():
+        return tvm.IRModule.from_expr(
+            relay.annotation.function_on_device(
+                relay.Function(
+                    [x, y],
+                    relay.Let(
+                        r,
+                        relay.RefCreate(x),
+                        relay.Let(
+                            dummy, relay.RefWrite(r, cpu_to_gpu(y)), relay.add(x, relay.RefRead(r))
+                        ),
+                    ),
+                ),
+                [GPU, CPU],
+                GPU,
+            )
+        )
+
+    def ref(x, y):
+        r = {"value": x}
+        r["value"] = y
+        return np.add(x, r["value"])
+
+    # Don't try to execute; no backend currently supports both cross-device execution and references.
+    exercise(input(), expected(), None, None)
+
+
+if __name__ == "__main__":
+    test_plain()

Review comment:
       `sys.exit(pytest.main([__file__] + sys.argv[1:]))`
   
   one day we will write a helper :)
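   e.g. a minimal sketch of the whole `__main__` block (assuming `sys` and `pytest` are already imported at the top of the file):
   
       if __name__ == "__main__":
           sys.exit(pytest.main([__file__] + sys.argv[1:]))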

##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -0,0 +1,1986 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/device_planner.cc
+ * \brief Determines a unique device to hold the result of every Relay sub-expression.
+ *
+ * We say a Relay expression E is 'on device D' if the result of executing E is stored on D.
+ * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the
+ * specific target associated with D (this is recovered independently via a TargetMap), and we
+ * do not track the storage scope within D (this is yet to be implemented).
+ *
+ * Note that 'stored on device D' is almost but not quite the same as 'executes on device D',
+ * see below.
+ *
+ * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes:
+ *  - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and
+ *    'dst_dev_type' device type, which constrain the argument and context of the call
+ *    respectively. It is ok if the source and destination devices are the same; such no-op
+ *    copies will be removed after accounting for the device preference.
+ *  - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which
+ *    constrains the argument of the call, but (usually, see below) leaves the context
+ *    unconstrained. These are called 'annotations' in the rest of the code; they have no
+ *    operational significance by themselves, but may trigger the insertion of a new "device_copy".
+ *  - In two situations the result of an "on_device" CallNode may also be constrained to the
+ *    given device:
+ *     - The "on_device" call occurs at the top-level of a function body, or occurs as an
+ *       immediately let-bound expression. In this situation the extra degree of freedom in
+ *       the function result and let-binding leads to surprising device copies, so we simply
+ *       force the function result or let-bound variable to the given device.
+ *     - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted
+ *       it ourselves during an earlier invocation of this pass. This helps make this pass
+ *       idempotent.
+ *
+ * We proceed in four phases:
+ *
+ * Phase 0
+ * -------
+ * We rewrite the programs to handle some special cases:
+ *  - "on_device" calls at the top-level of function or immediately let-bound are rewritten
+ *    to have \code is_fixed=true \endcode.
+ *  - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written
+ *    \code on_device(expr.0, device_type=d) \endcode. I.e. we prefer to copy the projection from
+ *    the tuple rather than project from a copy of the tuple. We'll do this by rewriting.
+ *
+ * Phase 1
+ * -------
+ * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see
+ * below) to all other Relay sub-expressions. (For idempotence we also respect any existing
+ * "on_device" function attributes we introduce below.)
+ *
+ * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the
+ * same device. However each call site can use a different device. In other words primitives are
+ * 'device polymorphic' since we compile and execute them for each required device.
+ *
+ * For most Relay expressions the device for the overall expression is the same as the device
+ * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple
+ * itself, the condition and arms of an \p if must all be on the same device as the overall if,
+ * and so on.
+ *
+ * Some special ops (or 'dialects') are handled:
+ *  - Relay supports computing the shape of tensors and operators at runtime using "shape_of",
+ *    "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors
+ *    they describe may reside on any device.
+ *  - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor" ops. Again
+ *    shapes reside on the CPU, but the allocated tensors may reside on any device.
+ *
+ * Two Relay expressions have special handling:
+ *  - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the
+ *    overall let. However the result of \p e1 may be on a different device.
+ *  - For a function \code fn(x, y) { body } \endcode the result of the function must be on the
+ *    same device as \p body. However parameters \p x and \p y may be on different devices, even
+ *    different from each other. Every call to the function must use the same choice of parameter
+ *    and result devices -- there is no 'device polymorphism' for Relay functions.
+ *
+ * Phase 2
+ * -------
+ * After flowing constraints we apply some defaulting heuristics (using a global default device)
+ * to fix the device for any as-yet unconstrained sub-expressions.
+ *  - Unconstrained function result devices default to the global default device.
+ *  - Unconstrained function parameter devices default to the device for the function result.
+ *  - Unconstrained let-bound expression devices default to the device for the overall let.
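+ * For example (a sketch of these defaulting rules, not tied to any particular input): given
+ * \code fn(x) { negative(x) } \endcode with no annotations at all, the function result defaults
+ * to the global default device, and the otherwise unconstrained parameter \p x then follows the
+ * function result device.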
+ * TODO(mbs): I may have over-innovated here and we simply want to bind all free domains to
+ * the global default device. Worth a design doc with motivating examples I think.
+ *
+ * Phase 3
+ * -------
+ * Finally, the result of this analysis is reified back into the module as:
+ *  - Additional "on_device" attributes (an Attrs resolving to a \p FunctionOnDeviceAttrs) for
+ *    every function (both top-level and local). These describe the devices for the function's
+ *    parameters and the result.
+ *  - Additional "device_copy" CallNodes where a copy is required in order to respect the
+ *    intent of the original "on_device" CallNodes.
+ *  - Additional "on_device" CallNodes where the device type of an expression does not match
+ *    that of the lexically enclosing "on_device" CallNode or function attribute. In practice
+ *    this means "on_device" CallNodes may appear in two places:
+ *     - On a let-bound expression if its device differs from the overall let expression.
+ *     - On a call argument if its device differs from the call result. In particular, the
+ *       argument to a "device_copy" call will always be wrapped in an "on_device". (That may
+ *       seem pedantic but simplifies downstream handling.)
+ *    However since we make it easy to track devices for variables we never wrap an "on_device"
+ *    around a var or global var. These uses of "on_device" imply both the argument and result are
+ *    on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs field to true,
+ *    which helps make this pass idempotent.
+ *
+ * A helper \p LexicalOnDeviceMixin class can be used by downstream transforms to recover the device
+ * for any expression for their own use, e.g. during memory planning. All downstream passes must
+ * preserve the lexical scoping of the "on_device" CallNodes. In particular conversion to ANF
+ * must respect the lexical scoping convention:
+ * \code
+ * f(on_device(g(h(a, b), c), device_type=CPU))
+ * ==>
+ * let %x0 = on_device(h(a, b), device_type=CPU)
+ * let %x1 = on_device(g(%x0), device-type=CPU)
+ * f(on_device(%x1, device_type=CPU))
+ * \endcode
+ *
+ * This pass should be run before FuseOps so that FuseOps can use device-specific fusion rules.
+ *
+ * 'Stored on' vs 'Executes on'
+ * ----------------------------
+ * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the
+ * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for
+ * primitives.
+ *
+ * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are
+ * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific
+ * devices, thus the notion of 'executes on' is moot. AOT backends on the other hand need to
+ * know exactly which device (possibly one of a number of available 'CPU'-like devices) is
+ * responsible for execution. Currently that's handled independently by the \p AnnotateTargets
+ * pass, but we'd like to fold that into device planning here to ensure everything is consistent.
+ *
+ * Obviously since tensors are passed by pointer it's quite possible to execute a Relay
+ * expression (e.g. an if expression) on one device even though the tensor data resides on
+ * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on'
+ * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just
+ * compile the function body for the function's result device.
+ *
+ * This works after conversion to ANF provided the compilation for a let expression is prepared
+ * to make a cross-device call. However we leave it to a downstream transformation to heuristically
+ * minimize cross-device calls by moving device copies out of functions. E.g.:
+ * \code
+ *   def @f() {  // execute on CPU
+ *     let x = on_device(...GPU computation..., device_type=GPU);
+ *     device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU)
+ *   }
+ *   def @main() {
+ *     ... call @f() on CPU ...
+ *   }
+ * \endcode
+ * could be rewritten to:
+ * \code
+ *   def @f() {  // execute on GPU
+ *     let x = ...GPU computation...;
+ *     ...GPU computation...
+ *   }
+ *   def @main() {
+ *     let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU)
+ *     ... use x on CPU ...
+ *   }
+ * \endcode
+ *
+ * Higher-order shenanigans
+ * ------------------------
+ * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions
+ * as arguments (even anonymous functions), return functions, evaluate conditional expressions
+ * over functions, and so on. We handle this during constraint solving using the domain:
+ * \code
+ *   D  ::= <specific device type>   -- first-order
+ *        | fn(D,...,D):D            -- higher-order
+ * \endcode
+ * In this way we can determine the device for all function parameters and results. E.g. for
+ * \code
+ *   let f = fn(x, y) { ... }
+ *   let g = fn(f, z) { f(z, z) }
+ *   g(f, on_device(..., device_type=CPU))
+ * \endcode
+ * the parameters \p x and \p y will be on the CPU.
+ *
+ * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 must evaluate to a
+ * function. Our analysis must guarantee that the function's parameters and result devices are
+ * consistent for \p e2, \p e3, and the context of the call. But:
+ *  - Which device holds the closure result of evaluating \p e1 ?
+ *  - If \p e2 is of function type, what does that mean when we say every function parameter
+ *    is on a device?
+ *  - If \p e1 returns a function, what does that mean when we say every function result is
+ *    on a device?
+ *
+ * Since higher-order aspects are later compiled away (by 'defunctionalization'
+ * aka 'firstification') we'd prefer not to have to answer any of those questions. In particular,
+ * we really don't want our domain \p D to allow for yet another device for the function closure.
+ * So we'll just force the 'device for a function' to be the same as the device for the function's
+ * result using the notion of the 'result domain' for a domain:
+ * \code
+ *   result_domain(<specific device type>) = <specific device type>
+ *   result_domain(fn(D1,...,Dn):Dr)       = result_domain(Dr)
+ * \endcode
+ *
+ * Similarly the domain does not have entries for tuples, references, or ADTs. Whenever the
+ * analysis encounters a function inside one of those it simply forces all argument and result
+ * devices for the function to match the device for the first-order expression. For example,
+ * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the inner function
+ * parameters and result must similarly be on the GPU.
+ *
+ * -------
+ * | AOR |  This pass supports all of Relay.
+ * -------
+ *    ^
+ *    |
+ *    `-- Mark's stamp of completeness :-)
+ *
+ * TODO(mbs):
+ *  * Though on_device is the identity for all types we can't wrap it around functions/constructors
+ *    taking type args (or at least not without changing type_infer.cc to see through them).
+ *    This is not currently handled generally.
+ *  * Proper diagnostics for unification failure using spans.
+ *  * Make sure the pass is idempotent even after FuseOps etc.
+ *  * Support application of constructors properly. Are they device polymorphic?
+ *  * Replace DLDeviceType with TargetDevice, and unify 'target annotation' with 'device planning'.
+ *  * Support running the pass post FuseOps (so need to understand primitive functions, both
+ *    outlines and lined) and post the VM transforms (probably need to support more intrinsic
+ *    forms?).
+ *  * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default
+ *    device for primitives vs the default device for the rest of Relay.
+ *  * We'll probably need some support for partial 'device polymorphism' for functions once we
+ *    incorporate targets and memory scopes into the domain. For example it's ok for the function
+ *    body to be executed on different device ids provided they have the same target and memory
+ *    scope.
+ *  * Might be simpler to just let every type have a device annotation rather than work in
+ *    a separate domain?
+ *  * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies.
+ *  * The original device_annotation.cc RewriteAnnotatedOps removed all "on_device" calls
+ *    in tuples at the top level of function bodies or main expression, irrespective of the
+ *    "on_device" body. What's up with that?
+ */
+
+#include "./device_planner.h"
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/pattern_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/object.h>
+
+#include <unordered_map>
+
+#include "../op/annotation/annotation.h"
+#include "../op/memory/device_copy.h"
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+namespace {
+
+/*!
+ * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather
+ * than the original "device_copy" operator.
+ *
+ * See te_compiler.cc for where this rewriting occurs.
+ */
+DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) {
+  auto tir_call_attrs = call_node->attrs.as<TIRCallAttrs>();
+  if (tir_call_attrs == nullptr) {
+    return {};
+  }
+  if (tir_call_attrs->metadata.count("source_device") != 1 ||
+      tir_call_attrs->metadata.count("dst_device") != 1) {
+    return {};
+  }
+  ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1";
+  return {
+      call_node->args[0],
+      static_cast<DLDeviceType>(
+          Downcast<Integer>(tir_call_attrs->metadata["source_device"])->value),
+      static_cast<DLDeviceType>(Downcast<Integer>(tir_call_attrs->metadata["dst_device"])->value)};
+}
+
+class DeviceDomain;
+using DeviceDomainPtr = std::shared_ptr<DeviceDomain>;
+
+/******
+****** Domains
+******/
+
+/*!
+ * \brief Represents the domain over which we collect equality constraints.
+ *
+ * \code
+ *   D ::= ?x?                  -- first order, free
+ *       | <device_type>        -- first order, bound
+ *       | fn(D1, ..., Dn):Dr   -- higher order
+ * \endcode
+ *
+ * We require a function value to be on the same device as its result. To support that we need
+ * a notion of the 'result domain' of a domain:
+ * \code
+ *   result_domain(?x?)                = ?x?
+ *   result_domain(<device_type>)      = <device_type>
+ *   result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr)
+ * \endcode
+ */
+class DeviceDomain {
+ public:
+  /*!
+   * \brief Constructs a first-order domain of \p device_type, which may be
+   * \p kInvalidDeviceType to indicate the domain is free.
+   */
+  explicit DeviceDomain(DLDeviceType device_type) : device_type_(device_type) {}
+
+  /*!
+   * \brief Constructs a higher-order domain, where \p args_and_result contain the
+   * function argument and result domains in order.
+   */
+  explicit DeviceDomain(std::vector<DeviceDomainPtr> args_and_result)
+      : device_type_(kInvalidDeviceType), args_and_result_(std::move(args_and_result)) {}
+
+  /*! \brief Returns true if domain is first-order and free. */
+  bool is_free() const { return device_type_ == kInvalidDeviceType && args_and_result_.empty(); }
+
+  /*! \brief Returns true if domain is higher-order. */
+  bool is_higher_order() const { return !args_and_result_.empty(); }
+
+  DLDeviceType first_order_device_type() const {
+    ICHECK(args_and_result_.empty());
+    return device_type_;
+  }
+
+  size_t function_arity() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.size() - 1UL;
+  }
+
+  DeviceDomainPtr function_param(size_t i) const {
+    ICHECK(!args_and_result_.empty());
+    ICHECK_LT(i + 1, args_and_result_.size());
+    return args_and_result_[i];
+  }
+
+  DeviceDomainPtr function_result() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.back();
+  }
+
+ private:
+  /*!
+   * \brief If this is a function domain then always kInvalidDeviceType. Otherwise will be
+   * kInvalidDeviceType if the domain is still free, or the specific concrete device type if the
+   * domain is bound.
+   */
+  const DLDeviceType device_type_;
+
+  /*!
+   * \brief If this is a function domain then the sub-domains for each of the function's
+   * arguments, and the domain for its result. Otherwise empty.
+   */
+  const std::vector<DeviceDomainPtr> args_and_result_;
+
+  friend struct DeviceDomainHash;
+  friend struct DeviceDomainEqual;
+  friend class DeviceDomains;
+};
+
+// Ye olde boost hash mixer.
+constexpr size_t mix(size_t h1, size_t h2) {
+  return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+}
+
+// The following hash and equality helpers give each free first-order domain pointer its own
+// distinct identity.
+struct DeviceDomainHash {
+  size_t operator()(const DeviceDomainPtr& domain) const {
+    if (domain->is_free()) {
+      // Give each free first-order domain its own identity.
+      return static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get()));
+    } else {
+      size_t h = domain->args_and_result_.size();
+      h = mix(h, std::hash<int>()(static_cast<int>(domain->device_type_)));
+      for (const auto& sub_domain_ptr : domain->args_and_result_) {
+        h = mix(h, DeviceDomainHash()(sub_domain_ptr));
+      }
+      return h;
+    }
+  }
+};
+
+struct DeviceDomainEqual {
+ public:
+  bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const {
+    if (lhs->args_and_result_.size() != rhs->args_and_result_.size()) {
+      // Mismatched arities are never equal.
+      // (Though we'll never ask to do such a comparison explicitly, the hash map
+      // may do so implicitly due to hash collisions.)
+      return false;
+    }
+    if (lhs->is_free() && rhs->is_free()) {
+      // Compare first-order free domains by their address.
+      return lhs.get() == rhs.get();
+    }
+    if (lhs->args_and_result_.empty()) {
+      // Compare first-order domains by their device type -- free vs bound will compare as false.
+      return lhs->device_type_ == rhs->device_type_;
+    } else {
+      // Compare higher-order domains pointwise.
+      for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
+        if (!(*this)(lhs->args_and_result_[i], rhs->args_and_result_[i])) {
+          return false;
+        }
+      }
+      return true;
+    }
+  }
+};
+
+/*!
+ * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation
+ * built up by calls to \p Unify.
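+ *
+ * An illustrative sketch (not part of the pass itself): unifying the free domain for some
+ * expression with a domain bound to \p kDLCPU pins that expression's whole equivalence class to
+ * the CPU; a later attempt to unify the same class with a domain bound to \p kDLCUDA cannot
+ * succeed, and \p Join / \p Unify below signal this by throwing \p Error.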
+ */
+class DeviceDomains {
+ public:
+  DeviceDomains() = default;
+
+  /*!
+   * \brief Returns a domain appropriate for \p type whose result domain is bound
+   * to \p device_type. If \p device_type is \p kInvalidDeviceType then the entire domain
+   * will be free.
+   */
+  static DeviceDomainPtr MakeDomain(const Type& type, DLDeviceType device_type) {
+    if (const auto* func_type_node = type.as<FuncTypeNode>()) {
+      std::vector<DeviceDomainPtr> args_and_result;
+      args_and_result.reserve(func_type_node->arg_types.size() + 1);
+      for (const auto& arg_type : func_type_node->arg_types) {
+        args_and_result.emplace_back(MakeDomain(arg_type, kInvalidDeviceType));
+      }
+      args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, device_type));
+      return std::make_shared<DeviceDomain>(std::move(args_and_result));
+    } else {
+      return std::make_shared<DeviceDomain>(device_type);
+    }
+  }
+
+  /*!
+   * \brief Returns a higher-order domain with \p args_and_result.
+   */
+  static DeviceDomainPtr MakeDomain(std::vector<DeviceDomainPtr> args_and_result) {
+    return std::make_shared<DeviceDomain>(std::move(args_and_result));
+  }
+
+  /*! \brief Returns a domain appropriate for \p type whose result is bound to \p device_type. */
+  static DeviceDomainPtr ForDeviceType(const Type& type, DLDeviceType device_type) {
+    ICHECK_NE(device_type, kInvalidDeviceType);
+    return MakeDomain(type, device_type);
+  }
+
+  /*! \brief Returns a free domain appropriate for \p type. */
+  static DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, kInvalidDeviceType); }
+
+  /*! \brief Returns the domain representing the equivalence class containing \p domain. */
+  DeviceDomainPtr Lookup(DeviceDomainPtr domain) {
+    DeviceDomainPtr root = domain;
+    while (true) {
+      auto itr = domain_to_equiv_.find(root);
+      if (itr == domain_to_equiv_.end()) {
+        break;
+      }
+      ICHECK_NE(itr->second, root);
+      root = itr->second;
+      ICHECK_NOTNULL(root);
+    }
+    // Path compression.
+    while (domain != root) {
+      auto itr = domain_to_equiv_.find(domain);
+      ICHECK(itr != domain_to_equiv_.end());
+      domain = itr->second;
+      ICHECK_NOTNULL(domain);
+      itr->second = root;
+    }
+    return root;
+  }
+
+  /*!
+   * \brief Returns the domain accounting for all bound devices in \p lhs and \p rhs.
+   *
+   * Throws \p Error on failure.
+   */
+  DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+    // TODO(mbs): Proper diagnostics.
+    ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size())
+        << "Device domains:" << std::endl
+        << ToString(lhs) << std::endl
+        << "and" << std::endl
+        << ToString(rhs) << std::endl
+        << "do not have the same kind and can't be unified.";
+    if (rhs->is_free()) {
+      return lhs;
+    } else if (lhs->is_free()) {
+      return rhs;
+    } else if (lhs->args_and_result_.empty()) {
+      // Must have consistent device types for first order domains.
+      if (lhs->device_type_ != rhs->device_type_) {
+        // TODO(mbs): Proper diagnostics.
+        std::ostringstream os;
+        os << "Inconsistent device types " << lhs->device_type_ << " and " << rhs->device_type_;
+        throw Error(os.str());
+      }
+      return lhs;
+    } else {
+      // Recurse for higher-order.
+      std::vector<DeviceDomainPtr> args_and_result;
+      args_and_result.reserve(lhs->args_and_result_.size());
+      for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
+        args_and_result.emplace_back(Unify(lhs->args_and_result_[i], rhs->args_and_result_[i]));
+      }
+      return MakeDomain(std::move(args_and_result));
+    }
+  }
+
+  /*!
+   * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Fails if \p lhs and \p
+   * rhs disagree on bound device type.
+   *
+   * Throws \p Error on failure.
+   */
+  // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but
+  // given we have refs to functions I'm prepared to be surprised.
+  DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) {
+    lhs = Lookup(lhs);
+    rhs = Lookup(rhs);
+    auto joined_domain = Join(lhs, rhs);
+    if (!DeviceDomainEqual()(lhs, joined_domain)) {
+      domain_to_equiv_.emplace(lhs, joined_domain);
+    }
+    if (!DeviceDomainEqual()(rhs, joined_domain)) {
+      domain_to_equiv_.emplace(rhs, joined_domain);
+    }
+    return joined_domain;
+  }
+
+  /*!
+   * \brief Unifies \p lhs and \p rhs. If \p lhs is first-order and \p rhs is higher-order,
+   * require all arguments and result of \p rhs to unify with \p lhs. Otherwise same as
+   * \p Unify.
+   *
+   * Throws \p Error on failure.
+   */
+  void UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+    if (!lhs->is_higher_order() && rhs->is_higher_order()) {
+      Collapse(lhs, rhs);
+    } else {
+      Unify(lhs, rhs);
+    }
+  }
+
+  /*! \brief Returns true if a domain is known for \p expr. */
+  bool contains(const Expr& expr) const { return expr_to_domain_.count(expr.get()); }
+
+  /*! \brief Returns the domain representing \p expr. */
+  DeviceDomainPtr DomainFor(const Expr& expr) {
+    ICHECK(expr.defined());
+    auto itr = expr_to_domain_.find(expr.get());
+    if (itr != expr_to_domain_.end()) {
+      return Lookup(itr->second);
+    }
+    auto domain = Free(expr->checked_type());
+    expr_to_domain_.emplace(expr.get(), domain);
+    return domain;
+  }
+
+  /*!
+   * \brief Returns the domain representing the callee (ie 'op') in \p call expression. If the
+   * callee is a primitive or special operation we handle it specially. Otherwise defers to \p
+   * DomainFor(call->op).
+   *
+   * This special handling is needed:
+   * - To handle the "on_device" and "device_copy" ops which constrain devices to the given devices.
+   * - To handle some special ops which constrain devices to the CPU.
+   * - To allow the same primitive to be called on different devices at different call sites.
+   * Since each call to the op can have a different domain we index the ops by the call expression
+   * rather than the op itself.
+   */
+  DeviceDomainPtr DomainForCallee(const Call& call) {
+    auto itr = call_to_callee_domain_.find(call.get());
+    if (itr != call_to_callee_domain_.end()) {
+      return Lookup(itr->second);
+    }
+    std::vector<DeviceDomainPtr> args_and_result;
+
+    auto on_device_props = GetOnDeviceProps(call.get());
+    auto device_copy_props = GetDeviceCopyProps(call.get());
+    if (!device_copy_props.body.defined()) {
+      device_copy_props = GetPrimitiveDeviceCopyProps(call.get());
+    }
+
+    if (on_device_props.body.defined()) {
+      // on_device(expr, device_type=<t>, is_fixed=false)
+      // on_device : fn(<t>):?x?
+      //
+      // on_device(expr, device_type=<t>, is_fixed=true)
+      // on_device: fn(<t>):<t>
+      args_and_result.emplace_back(
+          ForDeviceType(on_device_props.body->checked_type(), on_device_props.device_type));
+      if (on_device_props.is_fixed) {
+        args_and_result.emplace_back(args_and_result.front());
+      } else {
+        args_and_result.emplace_back(Free(on_device_props.body->checked_type()));
+      }
+    } else if (device_copy_props.body.defined()) {
+      // device_copy(expr, src_dev_type=<s>, dst_dev_type=<d>)
+      // device_copy: fn(<s>):<d>
+      args_and_result.emplace_back(
+          ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.src_dev_type));
+      args_and_result.emplace_back(
+          ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.dst_dev_type));
+    } else if (call->op == alloc_storage_op) {
+      ICHECK_EQ(call->args.size(), 2U);
+      // alloc_storage(size, alignment, device_type=<t>)
+      // alloc_storage: fn(<cpu>, <cpu>):<t>
+      const auto* attrs = call->attrs.as<AllocStorageAttrs>();
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(
+          ForDeviceType(call->checked_type(), static_cast<DLDeviceType>(attrs->device_type)));
+    } else if (call->op == alloc_tensor_op) {
+      ICHECK_EQ(call->args.size(), 3U);
+      // alloc_tensor(storage, offset, shape)
+      // alloc_tensor: fn(?x?, <cpu>, <cpu>):?x?
+      auto free_domain = Free(call->checked_type());
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(free_domain);
+    } else if (call->op == shape_func_op) {
+      ICHECK_EQ(call->args.size(), 3U);
+      // shape_func(func, inputs, outputs, is_inputs=[...])
+      // shape_func: fn(..., <cpu>, <cpu>):<cpu>
+      // where ... is a free domain appropriate for func's type
+      args_and_result.emplace_back(Free(call->args[0]->checked_type()));
+      // TODO(mbs): I think this should be on the cpu only when is_input = [false], but
+      // what do we do when we have multiple arguments with different is_input values?
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+    } else if (call->op == shape_of_op) {
+      ICHECK_EQ(call->args.size(), 1U);
+      // shape_of(tensor)
+      // shape_of: fn(?x?):<cpu>
+      args_and_result.emplace_back(Free(call->args[0]->checked_type()));
+      args_and_result.emplace_back(cpu_domain_);
+    } else if (call->op == invoke_tvm_op) {
+      ICHECK_EQ(call->args.size(), 3U);
+      // invoke_tvm_op(op, inputs, outputs)
+      // invoke_tvm_op: fn(..., ?x?, ?x?):?x?
+      // where ... is a free domain appropriate for op's type
+      auto free_domain = Free(call->checked_type());
+      args_and_result.emplace_back(Free(call->args[0]->checked_type()));
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(free_domain);
+    } else if (call->op == reshape_tensor_op) {
+      ICHECK_EQ(call->args.size(), 2U);
+      // reshape_tensor(data, shape)
+      // reshape_tensor: fn(?x?, <cpu>):?x?
+      auto free_domain = Free(call->checked_type());
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(free_domain);
+    } else if (call->op->IsInstance<OpNode>()) {
+      // <primitive>(arg1, ..., argn)
+      // <primitive>: fn(?x?, ..., ?x?):?x?
+      // (all args and result must be first-order).
+      auto free_domain = Free(arb_);
+      for (size_t i = 0; i < call->args.size(); ++i) {
+        args_and_result.emplace_back(free_domain);
+      }
+      args_and_result.emplace_back(free_domain);
+    } else {
+      // Defer to normal case where op can be an arbitrary expression.
+      return DomainFor(call->op);
+    }
+    auto domain = MakeDomain(std::move(args_and_result));
+    call_to_callee_domain_.emplace(call.get(), domain);
+    return domain;
+  }
+
+  /*! \brief Unifies the domains for expressions \p lhs and \p rhs. */
+  void UnifyExprExact(const Expr& lhs, const Expr& rhs) {
+    auto lhs_domain = DomainFor(lhs);
+    auto rhs_domain = DomainFor(rhs);
+    try {
+      Unify(lhs_domain, rhs_domain);
+    } catch (const Error& e) {
+      // TODO(mbs): Proper diagnostics.
+      LOG(FATAL) << "Incompatible devices for expressions:" << std::endl
+                 << PrettyPrint(lhs) << std::endl
+                 << "with device:" << std::endl
+                 << ToString(lhs_domain) << "and:" << std::endl
+                 << PrettyPrint(rhs) << std::endl
+                 << "with device:" << std::endl
+                 << ToString(rhs_domain) << std::endl
+                 << e.what();
+    }
+  }
+
+  /*!
+   * \brief Unifies the domain for \p expr with \p expected_domain.
+   */
+  void UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain) {
+    auto actual_domain = DomainFor(expr);
+    try {
+      Unify(actual_domain, expected_domain);
+    } catch (const Error& e) {
+      // TODO(mbs): Proper diagnostics.
+      LOG(FATAL) << "Incompatible devices for expression:" << std::endl
+                 << PrettyPrint(expr) << std::endl
+                 << "with actual device:" << std::endl
+                 << ToString(actual_domain) << std::endl
+                 << "and expected device:" << std::endl
+                 << ToString(expected_domain) << std::endl
+                 << e.what();
+    }
+  }
+
+  /*!
+   * \brief Unifies the domain for \p expr with \p expected_domain.
+   * If \p expected_domain is higher-order but \p expr is first-order, require all arguments
+   * and the result of \p expected_domain to have the same domain as for \p expr.
+   */
+  void UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain) {
+    auto actual_domain = DomainFor(expr);
+    try {
+      UnifyCollapsed(actual_domain, expected_domain);
+    } catch (const Error& e) {
+      // TODO(mbs): Proper diagnostics.
+      LOG(FATAL) << "Incompatible devices for expression:" << std::endl
+                 << PrettyPrint(expr) << std::endl
+                 << "with actual device:" << std::endl
+                 << ToString(actual_domain) << std::endl
+                 << "and expected device:" << std::endl
+                 << ToString(expected_domain) << std::endl
+                 << e.what();
+    }
+  }
+
+  /*! \brief Returns true if \p domain contains any free sub-domains. */
+  bool AnyFree(DeviceDomainPtr domain) {
+    domain = Lookup(domain);
+    if (domain->is_free()) {
+      return true;
+    }
+    for (const auto& sub_domain : domain->args_and_result_) {
+      if (AnyFree(sub_domain)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  /*!
+   * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain.
+   * This can be used to handle functions within tuples, references and ADTs since we don't
+   * attempt to track anything beyond 'the device' for expressions of those first-order types.
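+   *
+   * E.g. collapsing the higher-order domain fn(?a?,?b?):?c? against the first-order domain <2>
+   * forces ?a?, ?b? and ?c? all to <2>.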
+   *
+   * Throws \p Error on failure.
+   */
+  void Collapse(const DeviceDomainPtr& first_order_domain,
+                const DeviceDomainPtr& higher_order_domain) {
+    for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) {
+      Unify(higher_order_domain->function_param(i), first_order_domain);
+    }
+    Unify(higher_order_domain->function_result(), first_order_domain);
+  }
+
+  /*! \brief Force all free domains in \p domain to default to \p default_device_type. */
+  void SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type) {
+    ICHECK_NE(default_device_type, kInvalidDeviceType);
+    domain = Lookup(domain);
+    if (domain->is_free()) {
+      // Will never throw since lhs is free.
+      Unify(domain, std::make_shared<DeviceDomain>(default_device_type));
+    } else if (!domain->args_and_result_.empty()) {
+      for (const auto& sub_domain : domain->args_and_result_) {
+        SetDefault(sub_domain, default_device_type);
+      }
+    }
+  }
+
+  /*!
+   * \brief If \p domain is higher-order and its result domain is free, force it to
+   * \p default_device_type. Then force any remaining free domains to the result domain
+   * (freshly defaulted or original). If \p domain is first-order this is the same as \p SetDefault.
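+   *
+   * E.g. (illustrative): defaulting fn(?a?,<2>):?b? with default device 1 first binds the free
+   * result ?b? to <1>, then binds the remaining free parameter ?a? to <1>, yielding
+   * fn(<1>,<2>):<1>.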
+   */
+  void SetResultDefaultThenParams(const DeviceDomainPtr& domain, DLDeviceType default_device_type) {
+    if (!domain->is_higher_order()) {
+      SetDefault(domain, default_device_type);
+      return;
+    }
+    DLDeviceType result_device_type = ResultDeviceType(domain);
+    if (result_device_type == kInvalidDeviceType) {
+      // If the function result device is still free use the given default.
+      result_device_type = default_device_type;
+    }
+    // Default any remaining free parameters to the function result device.
+    SetDefault(domain, result_device_type);
+  }
+
+  /*! \brief Returns one-line description of \p domain for debugging. */
+  std::string ToString(DeviceDomainPtr domain) {
+    domain = Lookup(domain);
+    std::ostringstream os;
+    if (domain->is_free()) {
+      // first-order free
+      os << "?" << static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get())) << "?";
+    } else if (domain->args_and_result_.empty()) {
+      // first-order bound
+      os << "<" << domain->device_type_ << ">";
+    } else {
+      // higher-order
+      os << "fn(";
+      for (size_t i = 0; i + 1 < domain->args_and_result_.size(); ++i) {
+        if (i > 0) {
+          os << ",";
+        }
+        os << ToString(domain->args_and_result_[i]);
+      }
+      os << "):" << ToString(domain->args_and_result_.back());
+    }
+    return os.str();
+  }
+
+  /*! \brief Returns description of entire system of constraints for debugging */
+  std::string ToString() {
+    std::ostringstream os;
+    for (const auto& pair : expr_to_domain_) {
+      os << "expression:" << std::endl
+         << PrettyPrint(GetRef<Expr>(pair.first)) << std::endl
+         << "domain:" << std::endl
+         << ToString(pair.second) << std::endl
+         << std::endl;
+    }
+    for (const auto& pair : call_to_callee_domain_) {
+      os << "call:" << std::endl
+         << PrettyPrint(GetRef<Call>(pair.first)) << std::endl
+         << "callee domain:" << std::endl
+         << ToString(pair.second) << std::endl
+         << std::endl;
+    }
+    return os.str();
+  }
+
+  /*!
+   * \brief Returns the result domain for \p domain (see defn in DeviceDomain comment).
+   */
+  DeviceDomainPtr ResultDomain(DeviceDomainPtr domain) {
+    domain = Lookup(domain);
+    while (!domain->args_and_result_.empty()) {
+      domain = Lookup(domain->args_and_result_.back());
+    }
+    return domain;
+  }
+
+  /*!
+   * \brief Returns the result (possibly free) device type for \p domain (see defn in DeviceDomain
+   * comment).
+   */
+  DLDeviceType ResultDeviceType(const DeviceDomainPtr& domain) {
+    return ResultDomain(domain)->first_order_device_type();
+  }
+
+ private:
+  /*! \brief Intrinsics we need to handle specially. */
+  const Op& alloc_storage_op = Op::Get("memory.alloc_storage");
+  const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor");
+  const Op& shape_of_op = Op::Get("vm.shape_of");
+  const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op");
+  const Op& shape_func_op = Op::Get("vm.shape_func");
+  const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor");
+  /*! \brief The CPU device type for special operators such as dynamic shape functions. */
+  const DLDeviceType cpu_device_type_ = kDLCPU;
+  /*! \brief Placeholder for any first-order type. */
+  Type arb_ = TupleType();
+  /*! \brief The domain for first-order expressions on the CPU. */
+  DeviceDomainPtr cpu_domain_ = ForDeviceType(arb_, cpu_device_type_);
+
+  /*! \brief Maps expressions to their domains as determined during analysis. */
+  std::unordered_map<const ExprNode*, DeviceDomainPtr> expr_to_domain_;
+
+  /*!
+   * \brief Maps call expressions to the domains for their callee where the callee is a primitive.
+   */
+  std::unordered_map<const CallNode*, DeviceDomainPtr> call_to_callee_domain_;
+
+  /*! \brief Maps device domains to their equivalent domains as determined during unification. */
+  std::unordered_map<DeviceDomainPtr, DeviceDomainPtr, DeviceDomainHash, DeviceDomainEqual>
+      domain_to_equiv_;
+};
+
+/******
+****** Phase 0
+******/
+
+/*!
+ * \brief Rewrites "on_device" calls to handle some special cases.
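+ *
+ * A rough sketch of the rewrites (device types illustrative):
+ * \code
+ *   fn(...) { on_device(e, device_type=2) }   ==>  fn(...) { on_device(e, device_type=2, is_fixed=True) }
+ *   let x = on_device(e, device_type=2); b    ==>  let x = on_device(e, device_type=2, is_fixed=True); b
+ *   on_device(t, device_type=2).0             ==>  on_device(on_device(t, device_type=2).0, device_type=2)
+ * \endcode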
+ */
+class RewriteOnDevices : public ExprMutator {
+ public:
+  RewriteOnDevices() = default;
+
+ private:
+  Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
+    Expr tuple = VisitExpr(tuple_get_item_node->tuple);
+    // TODO(mbs): Avoid copy.
+    Expr tuple_get_item =
+        TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span);
+    auto props = GetOnDeviceProps(tuple);
+    if (props.body.defined() && !props.is_fixed) {
+      VLOG(1) << "wrapping tuple get item:" << std::endl
+              << PrettyPrint(GetRef<TupleGetItem>(tuple_get_item_node)) << std::endl
+              << "with \"on_device\" for device " << props.device_type;
+      return OnDevice(tuple_get_item, props.device_type, /*is_fixed=*/false);
+    } else {
+      return tuple_get_item;
+    }
+  }
+
+  Expr VisitExpr_(const LetNode* let_node) final {
+    auto expr = GetRef<Expr>(let_node);
+    std::vector<std::tuple<Var, Expr, Span>> bindings;
+    while (const auto* inner_let_node = expr.as<LetNode>()) {
+      Expr inner_let = GetRef<Let>(inner_let_node);
+      Expr value = VisitExpr(inner_let_node->value);
+      auto props = GetOnDeviceProps(value);
+      if (props.body.defined() && !props.is_fixed) {
+        VLOG(1) << "revising let-bound expression of let:" << std::endl
+                << PrettyPrint(expr) << std::endl
+                << "to be fixed to device " << props.device_type;
+        value = OnDevice(props.body, props.device_type, /*is_fixed=*/true);
+      }
+      bindings.emplace_back(inner_let_node->var, value, inner_let_node->span);
+      expr = inner_let_node->body;
+    }
+    expr = VisitExpr(expr);
+    // TODO(mbs): Avoid copy.
+    for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) {
+      expr = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), expr,
+                 /*span=*/std::get<2>(*itr));
+    }
+    return expr;
+  }
+
+  Expr VisitExpr_(const FunctionNode* function_node) final {
+    Expr body = VisitExpr(function_node->body);
+    auto props = GetOnDeviceProps(body);
+    if (props.body.defined() && !props.is_fixed) {
+      VLOG(1) << "revising body of function:" << std::endl
+              << PrettyPrint(GetRef<Function>(function_node)) << std::endl
+              << "to be fixed to device " << props.device_type;
+      body = OnDevice(props.body, props.device_type, /*is_fixed=*/true);
+    }
+    // TODO(mbs): Avoid copy
+    return Function(function_node->params, body, function_node->ret_type,
+                    function_node->type_params, function_node->attrs, function_node->span);
+  }
+};
+
+/******
+****** Phase 1
+******/
+
+/*!
+ * \brief Collects the system of device constraints for all sub-expressions in a module.
+ * It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter.
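+ *
+ * E.g. (illustrative) for:
+ * \code
+ *   def @main(%x, %y) { add(%x, %y) }
+ * \endcode
+ * the call to the device-polymorphic 'add' forces %x, %y and the call result into a single
+ * (here still free) device domain, so @main's domain is fn(?a?,?a?):?a?.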
+ */
+class DeviceAnalyzer : public ExprVisitor {
+ public:
+  explicit DeviceAnalyzer(IRModule mod)
+      : mod_(std::move(mod)), domains_(std::make_unique<DeviceDomains>()) {}
+
+  /*!
+   * \brief Returns the expression-to-device-domain map for all expressions in all the global
+   * function definitions in the module. Expressions may have free domains, these will be resolved
+   * by \p DeviceDefaulter below.
+   */
+  std::unique_ptr<DeviceDomains> Analyze() {
+    VLOG_CONTEXT << "DeviceAnalyzer";
+    for (const auto& pair : mod_->functions) {
+      VLOG(1) << "collecting constraints for '" << PrettyPrint(pair.first) << "'";
+      domains_->UnifyExprExact(pair.first, pair.second);
+      VisitExpr(pair.second);
+    }
+    return std::move(domains_);
+  }
+
+ private:
+  void VisitExpr_(const CallNode* call_node) final {
+    auto call = GetRef<Call>(call_node);
+
+    // Find the higher-order domain for the callee. See DomainForCallee for the special rules
+    // for primitives.
+    VisitExpr(call_node->op);
+    auto func_domain = domains_->DomainForCallee(call);  // higher-order
+
+    // Build the domain for the function implied by its arguments and call context.
+    ICHECK_EQ(func_domain->function_arity(), call_node->args.size());
+    std::vector<DeviceDomainPtr> args_and_result_domains;
+    args_and_result_domains.reserve(call_node->args.size() + 1);
+    for (const auto& arg : call_node->args) {
+      args_and_result_domains.emplace_back(domains_->DomainFor(arg));
+      VisitExpr(arg);
+    }
+    args_and_result_domains.emplace_back(domains_->DomainFor(call));
+    auto implied_domain =
+        DeviceDomains::MakeDomain(std::move(args_and_result_domains));  // higher-order
+
+    VLOG(1) << "initial call function domain:" << std::endl
+            << domains_->ToString(func_domain) << std::endl
+            << "and implied domain:" << std::endl
+            << domains_->ToString(implied_domain) << "for call:" << std::endl
+            << PrettyPrint(call);
+
+    // The above must match.
+    try {
+      domains_->Unify(func_domain, implied_domain);  // higher-order
+    } catch (const Error& e) {
+      // TODO(mbs): Proper diagnostics.
+      LOG(FATAL) << "Function parameters and result devices do not match those of call. Call:"
+                 << std::endl
+                 << PrettyPrint(call) << std::endl
+                 << "with function devices:" << std::endl
+                 << domains_->ToString(func_domain) << std::endl
+                 << "and implied call devices:" << std::endl
+                 << domains_->ToString(implied_domain) << std::endl
+                 << e.what();
+    }
+
+    VLOG(1) << "final call function domain:" << std::endl
+            << domains_->ToString(func_domain) << std::endl
+            << "for call:" << std::endl
+            << PrettyPrint(call);
+  }
+
+  void VisitExpr_(const LetNode* let_node) final {
+    Expr expr = GetRef<Let>(let_node);
+    // Iteratively visit let nodes to avoid stack overflow.
+    while (expr->IsInstance<LetNode>()) {
+      Let let = Downcast<Let>(expr);
+      // Let var must be same device as value it is bound to.
+      domains_->UnifyExprExact(let->var, let->value);  // may be higher-order
+      // Let body must be same device as overall let.
+      domains_->UnifyExprExact(let, let->body);  // may be higher-order
+
+      VisitExpr(let->var);
+      VisitExpr(let->value);
+
+      expr = let->body;
+    }
+
+    // Visit the last body
+    VisitExpr(expr);
+  }
+
+  void VisitExpr_(const FunctionNode* function_node) final {
+    // No need to step into fused primitive functions as they are lowered individually according
+    // to the devices of all their call sites.
+    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      return;
+    }
+
+    auto function = GetRef<Function>(function_node);
+    auto func_domain = domains_->DomainFor(function);  // higher-order
+
+    // The function body domain must match the function result domain.
+    domains_->UnifyExprExact(function_node->body,
+                             func_domain->function_result());  // may be higher-order
+
+    VLOG(1) << "initial function domain:" << std::endl
+            << domains_->ToString(func_domain) << std::endl
+            << "and function body domain:" << std::endl
+            << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl
+            << "for function:" << std::endl
+            << PrettyPrint(function);
+
+    ICHECK_EQ(func_domain->function_arity(), function_node->params.size());
+    for (size_t i = 0; i < function_node->params.size(); ++i) {
+      // The parameter domains must match the function argument domains.
+      domains_->UnifyExprExact(function_node->params[i],
+                               func_domain->function_param(i));  // may be higher-order
+      VisitExpr(function_node->params[i]);
+    }
+
+    // If the function already has an "on_device" attribute then we can further
+    // constrain the function's domain to match it.
+    Optional<Attrs> opt_attrs =
+        function_node->GetAttr<Attrs>(FunctionOnDeviceAttrs::kFunctionAttrsKey);
+    if (opt_attrs) {
+      std::vector<DeviceDomainPtr> args_and_result;
+      for (size_t i = 0; i < function_node->params.size(); ++i) {
+        args_and_result.emplace_back(
+            domains_->ForDeviceType(function_node->params[i]->checked_type(),
+                                    GetFunctionParamDeviceType(function_node, i)));
+      }
+      args_and_result.emplace_back(domains_->ForDeviceType(
+          function_node->body->checked_type(), GetFunctionResultDeviceType(function_node)));
+      auto annotation_domain = domains_->MakeDomain(std::move(args_and_result));
+      try {
+        domains_->Unify(func_domain, annotation_domain);  // higher-order
+      } catch (const Error& e) {
+        // TODO(mbs): Proper diagnostics.
+        LOG(FATAL)
+            << "Function devices are incompatible with its \"on_device\" annotation. Function:"
+            << std::endl
+            << PrettyPrint(function) << std::endl
+            << "with function devices:" << std::endl
+            << domains_->ToString(func_domain) << std::endl
+            << "and annotation devices:" << std::endl
+            << domains_->ToString(annotation_domain) << std::endl
+            << e.what();
+      }
+    }
+
+    VisitExpr(function_node->body);
+
+    VLOG(1) << "final function domain:" << std::endl
+            << domains_->ToString(func_domain) << std::endl
+            << "and function body domain:" << std::endl
+            << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl
+            << "for function:" << std::endl
+            << PrettyPrint(function);
+  }
+
+  void VisitExpr_(const TupleNode* tuple_node) final {
+    Tuple tuple = GetRef<Tuple>(tuple_node);
+    for (size_t i = 0; i < tuple->fields.size(); i++) {
+      auto domain = domains_->DomainFor(tuple->fields[i]);  // may be higher-order
+      domains_->UnifyExprCollapsed(tuple, domain);          // collapse to first-order if needed
+      VisitExpr(tuple->fields[i]);
+    }
+  }
+
+  void VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
+    TupleGetItem tuple_get_item = GetRef<TupleGetItem>(tuple_get_item_node);
+    auto domain = domains_->DomainFor(tuple_get_item);  // may be higher-order
+    domains_->UnifyExprCollapsed(tuple_get_item_node->tuple,
+                                 domain);  // collapse to first-order if needed
+    VisitExpr(tuple_get_item_node->tuple);
+  }
+
+  class DevicePatternAnalyzer : public PatternVisitor {
+   public:
+    DevicePatternAnalyzer(DeviceDomains* domains, const ExprNode* adt_node)
+        : domains_(domains), adt_node_(adt_node) {}
+
+   private:
+    void VisitPattern_(const PatternVarNode* pattern_var_node) final {
+      auto var_domain = domains_->DomainFor(pattern_var_node->var);  // may be higher order
+      domains_->UnifyExprCollapsed(GetRef<Expr>(adt_node_),
+                                   var_domain);  // collapse to first-order if needed
+    }
+
+    /*! \brief (Mutable borrow of) the domains for all expressions processed so far. */
+    DeviceDomains* domains_;
+    /*! \brief The expression for the ADT we are matching over. */
+    const ExprNode* adt_node_;
+  };
+
+  void VisitPattern(const Pattern& pattern) final {}
+
+  void VisitExpr_(const MatchNode* match_node) final {
+    // For match node, we unify the value and the rhs of each clause
+    Match match = GetRef<Match>(match_node);
+    auto match_domain = domains_->DomainFor(match);  // may be higher-order
+    DevicePatternAnalyzer pattern_analyzer(domains_.get(), match->data.get());
+    domains_->UnifyExprCollapsed(match->data, match_domain);  // collapse to first-order if needed
+    for (const auto& clause : match->clauses) {
+      pattern_analyzer.VisitPattern(clause->lhs);
+      domains_->UnifyExprExact(clause->rhs, match_domain);
+      VisitExpr(clause->rhs);
+    }
+    VisitExpr(match_node->data);
+  }
+
+  void VisitExpr_(const GlobalVarNode* global_var_node) final {
+    domains_->DomainFor(GetRef<GlobalVar>(global_var_node));
+  }
+
+  void VisitExpr_(const VarNode* var_node) final { domains_->DomainFor(GetRef<Var>(var_node)); }
+
+  void VisitExpr_(const ConstantNode* constant_node) final {
+    domains_->DomainFor(GetRef<Constant>(constant_node));
+  }
+
+  void VisitExpr_(const ConstructorNode* constructor_node) final {
+    // Probably needs to be device polymorphic.
+    domains_->DomainFor(GetRef<Constructor>(constructor_node));
+  }
+
+  void VisitExpr_(const IfNode* if_node) final {
+    auto ife = GetRef<If>(if_node);
+    auto domain = domains_->DomainFor(ife);               // may be higher-order
+    domains_->UnifyExprCollapsed(if_node->cond, domain);  // collapse to first-order if needed
+    domains_->UnifyExprExact(if_node->true_branch, domain);
+    domains_->UnifyExprExact(if_node->false_branch, domain);
+    VisitExpr(if_node->cond);
+    VisitExpr(if_node->true_branch);
+    VisitExpr(if_node->false_branch);
+  }
+
+  void VisitExpr_(const OpNode* op) final {
+    // no-op, primitive operators are handled at their call-sites.
+  }
+
+  void VisitExpr_(const RefCreateNode* ref_create_node) final {
+    auto ref_create = GetRef<RefCreate>(ref_create_node);
+    auto domain = domains_->DomainFor(ref_create_node->value);  // may be higher-order
+    domains_->UnifyExprCollapsed(ref_create, domain);           // collapse to first-order if needed
+    VisitExpr(ref_create_node->value);
+  }
+
+  void VisitExpr_(const RefReadNode* ref_read_node) final {
+    auto ref_read = GetRef<RefRead>(ref_read_node);
+    auto domain = domains_->DomainFor(ref_read);               // may be higher-order
+    domains_->UnifyExprCollapsed(ref_read_node->ref, domain);  // collapse to first-order if needed
+    VisitExpr(ref_read_node->ref);
+  }
+
+  void VisitExpr_(const RefWriteNode* ref_write_node) final {
+    auto ref_write = GetRef<RefWrite>(ref_write_node);
+    auto domain = domains_->DomainFor(ref_write->value);   // may be higher-order
+    domains_->UnifyExprCollapsed(ref_write->ref, domain);  // collapse to first-order if needed
+    domains_->UnifyExprCollapsed(ref_write, domain);       // collapse to first-order if needed
+    VisitExpr(ref_write_node->ref);
+    VisitExpr(ref_write_node->value);
+  }
+
+  /*! \brief The module we are analyzing. */
+  IRModule mod_;
+  /*! \brief The domains for all expressions processed so far. */
+  std::unique_ptr<DeviceDomains> domains_;
+};
+
+/******
+****** Phase 2
+******/
+
+/*!
+ * \brief Ensures every sub-expression in a module has a device type, using both the global
+ * default and some local heuristics to avoid unnecessary additional "device_copy" CallNodes.
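+ *
+ * E.g. (illustrative): if in \code let %x = %e; %b \endcode the device for %x is still free but
+ * the overall let has already been constrained to device 1, then %x (and hence %e) is defaulted
+ * to device 1 rather than requiring a "device_copy".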
+ *
+ * TODO(mbs): I think this is deterministic? We do however visit the top-level defs in hashmap
+ * order.
+ */
+class DeviceDefaulter : public ExprVisitor {
+ public:
+  DeviceDefaulter(IRModule mod, std::unique_ptr<DeviceDomains> domains,
+                  DLDeviceType default_device_type)
+      : mod_(std::move(mod)),
+        domains_(std::move(domains)),
+        default_device_type_(default_device_type) {}
+
+  std::unique_ptr<DeviceDomains> Default() {
+    VLOG_CONTEXT << "DeviceDefaulter";
+    for (const auto& pair : mod_->functions) {
+      VLOG(1) << "defaulting devices for '" << PrettyPrint(pair.first) << "'";
+      VisitExpr(pair.second);
+    }
+    return std::move(domains_);
+  }
+
+ private:
+  void VisitExpr_(const FunctionNode* function_node) final {
+    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      return;
+    }
+
+    auto function = GetRef<Function>(function_node);
+    auto func_domain = domains_->DomainFor(function);  // higher-order
+    ICHECK_EQ(func_domain->function_arity(), function_node->params.size());
+    if (domains_->AnyFree(func_domain)) {
+      VLOG(1) << "before defaulting function:" << std::endl << domains_->ToString(func_domain);
+      domains_->SetResultDefaultThenParams(func_domain, default_device_type_);
+      VLOG(1) << "after defaulting function:" << std::endl << domains_->ToString(func_domain);
+    }
+    VisitExpr(function_node->body);
+  }
+
+  void VisitExpr_(const CallNode* call_node) final {
+    auto call = GetRef<Call>(call_node);
+    auto func_domain = domains_->DomainForCallee(call);  // higher-order
+    ICHECK_EQ(func_domain->function_arity(), call_node->args.size());
+    if (domains_->AnyFree(func_domain)) {
+      // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*)
+      // above. But for calls to primitives we may still need to force free domains to be
+      // defaulted.
+      VLOG(1) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain);
+      domains_->SetResultDefaultThenParams(func_domain, default_device_type_);
+      VLOG(1) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain);
+    }
+    return ExprVisitor::VisitExpr_(call_node);
+  }
+
+  void VisitExpr_(const LetNode* let_node) final {
+    Expr expr = GetRef<Let>(let_node);
+    // Iteratively visit let nodes to avoid stack overflow.
+    while (expr->IsInstance<LetNode>()) {
+      Let let = Downcast<Let>(expr);
+      // If the let-var device is still free force it to match the overall let.
+      auto let_domain = domains_->DomainFor(let);  // may be higher-order
+      DLDeviceType let_device_type = domains_->ResultDeviceType(let_domain);
+      ICHECK_NE(let_device_type, kInvalidDeviceType);
+      auto let_var_domain = domains_->DomainFor(let->var);  // may be higher-order
+      if (domains_->AnyFree(let_var_domain)) {
+        VLOG(1) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain);
+        domains_->SetDefault(let_var_domain, let_device_type);
+        VLOG(1) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain);
+      }
+      VisitExpr(let->var);
+      VisitExpr(let->value);
+      expr = let->body;
+    }
+    VisitExpr(expr);
+  }
+
+  /*! \brief The module we are processing. */
+  IRModule mod_;
+  /*! \brief The domains for all expressions.  */
+  std::unique_ptr<DeviceDomains> domains_;
+  /*! \brief The default device type. */
+  DLDeviceType default_device_type_;
+};
+
+/******
+****** Phase 3
+******/
+
+/*!
+ * \brief Inserts missing "device_copy" CallNodes, and ensures the device type of every
+ * sub-expression in a module can be easily recovered by a later transformation using simple
+ * lexical scoping rules (e.g. for memory planning).
+ *
+ * - Discard any existing "on_device" CallNodes since their job is done. Similarly, discard
+ *   any existing "device_copy" CallNodes which are no-ops.
+ *
+ * - Functions are given an "on_device" attribute bound to a FunctionOnDeviceAttrs to capture
+ *   the device type for its parameters and result.
+ *
+ * - Additional "device_copy" CallNodes are inserted wherever there's a transition between
+ *   storage device types. Since the DeviceAnalyzer phase succeeded this can only happen
+ *   where the original program explicitly allowed a transition using an "on_device" CallNode.
+ *   That is, we do not try to 'fix' a program with inconsistent devices.
+ *
+ * - Additional "on_device" CallNodes are inserted so that a later transform can discover
+ *   the device for an arbitrary sub-expression by looking only for the lexically enclosing
+ *   "on_device" CallNode or "on_device" function attribute. In particular, since function
+ *   arguments and let-bound expressions can be on a device different from the function
+ *   or let body itself we will insert "on_device" CallNodes to spell out any differences. This
+ *   applies even to the argument to a "device_copy" CallNode, which may look pedantic but
+ *   keeps downstream processing simple. The "on_device" calls should be removed before code gen,
+ *   which is easily done on-the-fly.
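+ *
+ * A rough before/after sketch (device types and exact attribute syntax are illustrative only,
+ * assuming the global default device is 1):
+ * \code
+ *   def @main(%x, %y, %z) {
+ *     multiply(on_device(add(%x, %y), device_type=2), %z)
+ *   }
+ *   ==>
+ *   def @main(%x, %y, %z, on_device={param_device_types=[2,2,1], result_device_type=1}) {
+ *     multiply(device_copy(on_device(add(%x, %y), device_type=2, is_fixed=True),
+ *                          src_dev_type=2, dst_dev_type=1),
+ *              %z)
+ *   }
+ * \endcode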
+ */
+class DeviceCapturer : public ExprMutator {
+ public:
+  DeviceCapturer(IRModule mod, std::unique_ptr<DeviceDomains> domains)
+      : mod_(std::move(mod)), domains_(std::move(domains)) {}
+
+  IRModule Capture() {
+    VLOG_CONTEXT << "CaptureDevices";
+    IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map);
+    for (const auto& pair : mod_->functions) {
+      VLOG(1) << "capturing devices for '" << PrettyPrint(pair.first) << "'";
+      result->Add(pair.first, Downcast<BaseFunc>(Mutate(pair.second)));
+    }
+    return result;
+  }
+
+ private:
+  // Nothing interesting for VarNode, ConstantNode, GlobalVarNode and OpNode.
+
+  Expr VisitExpr_(const TupleNode* tuple_node) final {
+    auto tuple = GetRef<Tuple>(tuple_node);
+    Array<Expr> fields;
+    fields.reserve(tuple_node->fields.size());
+    for (const auto& field : tuple_node->fields) {
+      fields.push_back(VisitChild(tuple, field));
+    }
+    // TODO(mbs): Avoid copy
+    return Tuple(std::move(fields), tuple_node->span);
+  }
+
+  Expr VisitExpr_(const FunctionNode* function_node) final {
+    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      return GetRef<Function>(function_node);
+    }
+
+    auto function = GetRef<Function>(function_node);
+    auto func_domain = domains_->DomainFor(function);  // higher-order
+    VLOG(1) << "capturing function:" << std::endl
+            << PrettyPrint(function) << std::endl
+            << "with domain:" << std::endl
+            << domains_->ToString(func_domain);
+
+    // Gather the parameter and result device types for the "on_device" function attribute.
+    ICHECK_EQ(func_domain->function_arity(), function_node->params.size());
+    DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain);
+    ICHECK_NE(result_device_type, kInvalidDeviceType);
+    Array<Integer> param_device_types;
+    param_device_types.reserve(function_node->params.size());
+    for (size_t i = 0; i < function_node->params.size(); ++i) {
+      DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i));
+      ICHECK_NE(param_device_type, kInvalidDeviceType);
+      param_device_types.push_back(param_device_type);
+    }
+
+    // Rewrite the body. Note that the body may have begun with an "on_device" so
+    // be prepared to insert a "device_copy".
+    Expr body = VisitChild(
+        /*lexical_device_type=*/result_device_type,
+        /*expected_device_type=*/result_device_type,
+        /*child_device_type=*/GetDeviceType(function_node->body), function_node->body);
+
+    // TODO(mbs): Avoid copy
+    Function func = Function(function_node->params, body, function_node->ret_type,
+                             function_node->type_params, function_node->attrs, function_node->span);
+    return FunctionOnDevice(func, param_device_types, result_device_type);
+  }
+
+  Expr VisitExpr_(const CallNode* call_node) final {

Review comment:
       i roughly understand what's happening, but i think example Relay snippets would help to better understand/maintain each Pass here.

##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -0,0 +1,1986 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/device_planner.cc
+ * \brief Determines a unique device to hold the result of every Relay sub-expression.
+ *
+ * We say a Relay expression E is 'on device D' if the result of executing E is stored on D.
+ * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the
+ * specific target associated with D (this is recovered independently via a TargetMap), and we
+ * do not track the storage scope within D (this is yet to be implemented).
+ *
+ * Note that 'stored on device D' is almost but not quite the same as 'executes on device D',
+ * see below.
+ *
+ * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes:
+ *  - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and
+ *    'dst_dev_type' device type, which constrain the argument and context of the call
+ *     respectively. It is ok if the source and destination devices are the same; such no-op copies
+ *     will be removed after accounting for the device preference.
+ *  - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which
+ *    constrains the argument of the call, but (usually, see below) leaves the context
+ *    unconstrained. These are called 'annotations' in the rest of the code; they have no operational
+ *    significance by themselves, but may trigger the insertion of a new "device_copy".
+ *  - In two situations the result of an "on_device" CallNode may also be constrained to the
+ *    given device:
+ *     - The "on_device" call occurs at the top-level of a function body, or occurs as an
+ *       immediately let-bound expression. In this situation the extra degree of freedom in
+ *       the function result and let-binding leads to surprising device copies, so we simply
+ *       force the function result or let-bound variable to the given device.
+ *     - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted
+ *       it ourselves during an earlier invocation of this pass. This helps make this pass
+ *       idempotent.
+ *
+ * We proceed in four phases:
+ *
+ * Phase 0
+ * -------
+ * We rewrite the programs to handle some special cases:
+ *  - "on_device" calls at the top-level of function or immediately let-bound are rewritten
+ *    to have \code is_fixed=true \endcode.
+ *  - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written
+ *    \code on_device(expr.0, device_type=d) \endcode. I.e. we prefer to copy the projection from
+ *    the tuple rather than project from a copy of the tuple. We'll do this by rewriting.
+ *
+ * Phase 1
+ * -------
+ * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see
+ * below) to all other Relay sub-expressions. (For idempotence we also respect any existing
+ * "on_device" function attributes we introduce below.)
+ *
+ * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the
+ * same device. However each call site can use a different device. In other words primitives are
+ * 'device polymorphic' since we compile and execute them for each required device.
+ *
+ * For most Relay expressions the device for the overall expression is the same as the device
+ * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple
+ * itself, the condition and arms of an \p if must all be on the same device as the overall if,
+ * and so on.
+ *
+ * Some special ops (or 'dialects') are handled:
+ *  - Relay supports computing the shape of tensors and operators at runtime using "shape_of",
+ *    "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors
+ *    they describe may reside on any device.
+ *  - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor" ops. Again
+ *    shapes reside on the CPU, but the allocated tensors may reside on any device.
+ *
+ * Two Relay expressions have special handling:
+ *  - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the
+ *    overall let. However the result of \p e1 may be on a different device.
+ *  - For a function \code fn(x, y) { body } \endcode the result of the function must be on the
+ *   same device as \p body. However parameters \p x and \p y may be on different devices, even
+ *    different from each other. Every call to the function must use the same choice of parameter
+ *    and result devices -- there is no 'device polymorphism' for Relay functions.
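+ *
+ * A small end-to-end example of the constraint flow (device types illustrative):
+ * \code
+ *   let %a = on_device(%x, device_type=2);
+ *   device_copy(%a, src_dev_type=2, dst_dev_type=1)
+ * \endcode
+ * Here %x and %a are constrained to device 2 (after Phase 0 the let-bound "on_device" is fixed,
+ * and a let-bound variable must match its value), while the overall let expression is constrained
+ * to device 1 (the "device_copy" result).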
+ *
+ * Phase 2
+ * -------
+ * After flowing constraints we apply some defaulting heuristics (using a global default device)
+ * to fix the device for any as-yet unconstrained sub-expressions.
+ *  - Unconstrained function result devices default to the global default device.
+ *  - Unconstrained function parameter devices default to the device for the function result.
+ *  - Unconstrained let-bound expression devices default to the device for the overall let.
+ * TODO(mbs): I may have over-innovated here and we simply want to bind all free domains to
+ * the global default device. Worth a design doc with motivating examples I think.
+ *
+ * Phase 3
+ * -------
+ * Finally, the result of this analysis is reified into the result as:
+ *  - Additional "on_device" attributes (an Attrs resolving to a \p FunctionOnDeviceAttrs) for
+ *    every function (both top-level and local). These describe the devices for the function's
+ *    parameters and the result.
+ *  - Additional "device_copy" CallNodes where a copy is required in order to respect the
+ *    intent of the original "on_device" CallNodes.
+ *  - Additional "on_device" CallNodes where the device type of an expression does not match
+ *    that of the lexically enclosing "on_device" CallNode or function attribute. In practice
+ *    this means "on_device" CallNodes may appear in two places:
+ *     - On a let-bound expression if its device differs from the overall let expression.
+ *     - On a call argument if its device differs from the call result. In particular, the
+ *       argument to a "device_copy" call will always be wrapped in an "on_device". (That may
+ *       seem pedantic but simplifies downstream handling.)
+ *    However since we make it easy to track devices for variables we never wrap an "on_device"
+ *    around a var or global var. These uses of "on_device" imply both the argument and result are
+ *    on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs field to true,
+ *    which helps make this pass idempotent.
+ *
+ * A helper \p LexicalOnDeviceMixin class can be used by downstream transforms to recover the device
+ * for any expression for their own use, e.g. during memory planning. All downstream passes must
+ * preserve the lexical scoping of the "on_device" CallNodes. In particular conversion to ANF
+ * must respect the lexical scoping convention:
+ * \code
+ * f(on_device(g(h(a, b), c), device_type=CPU))
+ * ==>
+ * let %x0 = on_device(h(a, b), device_type=CPU)
+ * let %x1 = on_device(g(%x0), device-type=CPU)
+ * f(on_device(%x1, device_type=CPU))
+ * \endcode
+ *
+ * This pass should be run before FuseOps so that FuseOps can apply device-specific fusion rules.
+ *
+ * 'Stored on' vs 'Executes on'
+ * ----------------------------
+ * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the
+ * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for
+ * primitives.
+ *
+ * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are
+ * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific
+ * devices, thus the notion of 'executes on' is moot. AOT backends on the other hand need to
+ * know exactly which device (possibly one of a number of available 'CPU'-like devices) is
+ * responsible for execution. Currently that's handled independently by the \p AnnotateTargets
+ * pass, but we'd like to fold that into device planning here to ensure everything is consistent.
+ *
+ * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay
+ * expression (eg an if expression) on one device even though the tensor data resides on
+ * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on'
+ * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just
+ * compile the function body for the function's result device.
+ *
+ * This works after conversion to ANF provided the compilation for a let expression is prepared
+ * to make a cross-device call. However we leave it to a downstream transformation to heuristically
+ * minimize cross-device calls by moving device copies out of functions. E.g.:
+ * \code
+ *   def @f() {  // execute on CPU
+ *     let x = on_device(...GPU computation..., device_type=GPU);
+ *     device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU)
+ *   }
+ *   def @main() {
+ *     ... call @f() on CPU ...
+ *   }
+ * \endcode
+ * could be rewritten to:
+ * \code
+ *   def @f() {  // execute on GPU
+ *     let x = ...GPU computation...;
+ *     ...GPU computation...
+ *   }
+ *   def @main() {
+ *     let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU)
+ *     ... use x on CPU ...
+ *   }
+ * \endcode
+ *
+ * Higher-order shenanigans
+ * ------------------------
+ * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions
+ * as arguments (even anonymous functions), return functions, evaluate conditional expressions
+ * over functions, and so on. We handle this during constraint solving using the domain:
+ * \code
+ *   D  ::= <specific device type>   -- first-order
+ *        | fn(D,...,D):D            -- higher-order
+ * \endcode
+ * In this way we can determine the device for all function parameters and results. E.g. for
+ * \code
+ *   let f = fn(x, y) { ... }
+ *   let g = fn(f, z) { f(z, z) }
+ *   g(f, on_device(..., device_type=CPU))
+ * \endcode
+ * the parameters \p x and \p y will be on the CPU.
+ *
+ * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 must evaluate to a
+ * function. Our analysis must guarantee that the function's parameters and result devices are
+ * consistent for \p e2, \p e3, and the context of the call. But:
+ *  - Which device holds the closure result of evaluating \p e1 ?
+ *  - If \p e2 is of function type, what does that mean when we say every function parameter
+ *    is on a device?
+ *  - If \p e1 returns a function, what does that mean when we say every function result is
+ *    on a device?
+ *
+ * Since higher-order aspects are later compiled away (by 'defunctionalization'
+ * aka 'firstification') we'd prefer not to have to answer any of those questions. In particular,
+ * we really don't want our domain \p D to allow for yet another device for the function closure.
+ * So we'll just force the 'device for a function' to be the same as the device for the function's
+ * result using the notion of the 'result domain' for a domain:
+ * \code
+ *   result_domain(<specific device type>) = <specific device type>
+ *   result_domain(fn(D1,...,Dn):Dr)       = result_domain(Dr)
+ * \endcode
+ *
+ * Similarly the domain does not have entries for tuples, references, or ADTs. Whenever the
+ * analysis encounters a function inside one of those it simply forces all argument and result
+ * devices for the function to match the device for the first-order expression. For example,
+ * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the inner function
+ * parameters and result must similarly be on the GPU.
+ *
+ * -------
+ * | AOR |  This pass supports all of Relay.
+ * -------
+ *    ^
+ *    |
+ *    `-- Mark's stamp of completeness :-)
+ *
+ * TODO(mbs):
+ *  * Though on_device is the identity for all types we can't wrap it around functions/constructors
+ *    taking type args (or at least not without changing type_infer.cc to see through them).
+ *    This is not currently handled generally.
+ *  * Proper diagnostics for unification failure using spans.
+ *  * Make sure the pass is idempotent even after FuseOps etc.
+ *  * Support application of constructors properly. Are they device polymorphic?
+ *  * Replace DLDeviceType with TargetDevice, and unify 'target annotation' with 'device planning'.
+ *  * Support running the pass post FuseOps (so need to understand primitive functions, both
+ *    outlined and inlined) and post the VM transforms (probably need to support more intrinsic
+ *    forms?).
+ *  * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default
+ *    device for primitives vs the default device for the rest of Relay.
+ *  * We'll probably need some support for partial 'device polymorphism' for functions once we
+ *    incorporate targets and memory scopes into the domain. For example it's ok for the function
+ *    body to be executed on different device ids provided they have the same target and memory
+ *    scope.
+ *  * Might be simpler to just let every type have a device annotation rather than work in
+ *    a separate domain?
+ *  * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies.
+ *  * The original device_annotation.cc RewriteAnnotatedOps removed all "on_device" calls
+ *    in tuples at the top level of function bodies or main expression, irrespective of the
+ *    "on_device" body. What's up with that?
+ */
+
+#include "./device_planner.h"
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/pattern_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/object.h>
+
+#include <unordered_map>
+
+#include "../op/annotation/annotation.h"
+#include "../op/memory/device_copy.h"
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+namespace {
+
+/*!
+ * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather
+ * than the original "device_copy" operator.
+ *
+ * See te_compiler.cc for where this rewriting occurs.
+ */
+DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) {
+  auto tir_call_attrs = call_node->attrs.as<TIRCallAttrs>();
+  if (tir_call_attrs == nullptr) {
+    return {};
+  }
+  if (tir_call_attrs->metadata.count("source_device") != 1 ||
+      tir_call_attrs->metadata.count("dst_device") != 1) {
+    return {};
+  }
+  ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1";
+  return {
+      call_node->args[0],
+      static_cast<DLDeviceType>(
+          Downcast<Integer>(tir_call_attrs->metadata["source_device"])->value),
+      static_cast<DLDeviceType>(Downcast<Integer>(tir_call_attrs->metadata["dst_device"])->value)};
+}
+
+class DeviceDomain;
+using DeviceDomainPtr = std::shared_ptr<DeviceDomain>;
+
+/******
+****** Domains
+******/
+
+/*!
+ * \brief Represents the domain over which we collect equality constraints.
+ *
+ * \code
+ *   D ::= ?x?                  -- first order, free
+ *       | <device_type>        -- first order, bound
+ *       | fn(D1, ..., Dn):Dr   -- higher order
+ * \endcode
+ *
+ * We require a function value to be on the same device as its result. To support that we need
+ * a notion of the 'result domain' of a domain:
+ * \code
+ *   result_domain(?x?)                = ?x?
+ *   result_domain(<device_type>)      = <device_type>
+ *   result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr)
+ * \endcode
+ */
+class DeviceDomain {
+ public:
+  /*!
+   * \brief Constructs a first-order domain of \p device_type, which may be
+   * \p kInvalidDeviceType to indicate the domain is free.
+   */
+  explicit DeviceDomain(DLDeviceType device_type) : device_type_(device_type) {}
+
+  /*!
+   * \brief Constructs a higher-order domain, where \p args_and_result contain the
+   * function argument and result domains in order.
+   */
+  explicit DeviceDomain(std::vector<DeviceDomainPtr> args_and_result)
+      : device_type_(kInvalidDeviceType), args_and_result_(std::move(args_and_result)) {}
+
+  /*! \brief Returns true if domain is first-order and free. */
+  bool is_free() const { return device_type_ == kInvalidDeviceType && args_and_result_.empty(); }
+
+  /*! \brief Returns true if domain is higher-order. */
+  bool is_higher_order() const { return !args_and_result_.empty(); }
+
+  DLDeviceType first_order_device_type() const {
+    ICHECK(args_and_result_.empty());
+    return device_type_;
+  }
+
+  size_t function_arity() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.size() - 1UL;
+  }
+
+  DeviceDomainPtr function_param(size_t i) const {
+    ICHECK(!args_and_result_.empty());
+    ICHECK_LT(i + 1, args_and_result_.size());
+    return args_and_result_[i];
+  }
+
+  DeviceDomainPtr function_result() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.back();
+  }
+
+ private:
+  /*!
+   * \brief If this is a function domain then always kInvalidDevice. Otherwise will be
+   * kInvalidDevice if the domain is still free, or the specific concrete device if the domain is
+   * bound.
+   */
+  const DLDeviceType device_type_;
+
+  /*!
+   * \brief If this is a function domain then the sub-domains for each of the function's
+   * arguments, and the domain for its result. Otherwise empty.
+   */
+  const std::vector<DeviceDomainPtr> args_and_result_;
+
+  friend struct DeviceDomainHash;
+  friend struct DeviceDomainEqual;
+  friend class DeviceDomains;
+};
+
+// Ye olde boost hash mixer.
+constexpr size_t mix(size_t h1, size_t h2) {
+  return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+}
+
+// The following hash and equality helpers give each free first-order domain pointer its own
+// distinct identity.
+struct DeviceDomainHash {
+  size_t operator()(const DeviceDomainPtr& domain) const {
+    if (domain->is_free()) {
+      // Give each free first-order domain its own identity.
+      return static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get()));
+    } else {
+      size_t h = domain->args_and_result_.size();
+      h = mix(h, std::hash<int>()(static_cast<int>(domain->device_type_)));
+      for (const auto& sub_domain_ptr : domain->args_and_result_) {
+        h = mix(h, DeviceDomainHash()(sub_domain_ptr));
+      }
+      return h;
+    }
+  }
+};
+
+struct DeviceDomainEqual {
+ public:
+  bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const {
+    if (lhs->args_and_result_.size() != rhs->args_and_result_.size()) {
+      // Mismatched arities are never equal.
+      // (Though we'll never ask to do such a comparison explicitly, the hash map
+      // may do so implicitly due to hash collisions.)
+      return false;
+    }
+    if (lhs->is_free() && rhs->is_free()) {
+      // Compare first-order free domains by their address.
+      return lhs.get() == rhs.get();
+    }
+    if (lhs->args_and_result_.empty()) {
+      // Compare first-order domains by their device type -- free vs bound will compare as false.
+      return lhs->device_type_ == rhs->device_type_;
+    } else {
+      // Compare higher-order domains pointwise.
+      for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
+        if (!(*this)(lhs->args_and_result_[i], rhs->args_and_result_[i])) {
+          return false;
+        }
+      }
+      return true;
+    }
+  }
+};
+
+/*!
+ * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation
+ * built up by calls to \p Unify.
+ */
+class DeviceDomains {
+ public:
+  DeviceDomains() = default;
+
+  /*!
+   * \brief Returns a domain appropriate for \p type whose result domain is bound
+   * to \p device_type. If \p device_type is \p kInvalidDeviceType then the entire domain
+   * will be free.
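+   *
+   * E.g. (illustrative) for \p type fn(Tensor, Tensor) -> Tensor and \p device_type 2 the
+   * returned domain is fn(?a?,?b?):<2> -- only the result is bound.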
+   */
+  static DeviceDomainPtr MakeDomain(const Type& type, DLDeviceType device_type) {
+    if (const auto* func_type_node = type.as<FuncTypeNode>()) {
+      std::vector<DeviceDomainPtr> args_and_result;
+      args_and_result.reserve(func_type_node->arg_types.size() + 1);
+      for (const auto& arg_type : func_type_node->arg_types) {
+        args_and_result.emplace_back(MakeDomain(arg_type, kInvalidDeviceType));
+      }
+      args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, device_type));
+      return std::make_shared<DeviceDomain>(std::move(args_and_result));
+    } else {
+      return std::make_shared<DeviceDomain>(device_type);
+    }
+  }
+
+  /*!
+   * \brief Returns a higher-order domain with \p arg_and_results.
+   */
+  static DeviceDomainPtr MakeDomain(std::vector<DeviceDomainPtr> arg_and_results) {
+    return std::make_shared<DeviceDomain>(std::move(arg_and_results));
+  }
+
+  /*! \brief Returns a domain appropriate for \p type whose result domain is bound to \p device_type. */
+  static DeviceDomainPtr ForDeviceType(const Type& type, DLDeviceType device_type) {
+    ICHECK_NE(device_type, kInvalidDeviceType);
+    return MakeDomain(type, device_type);
+  }
+
+  /*! \brief Returns a free domain appropriate for \p type. */
+  static DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, kInvalidDeviceType); }
+
+  /*! \brief Returns the domain representing the equivalence class containing \p domain. */
+  DeviceDomainPtr Lookup(DeviceDomainPtr domain) {
+    DeviceDomainPtr root = domain;
+    while (true) {
+      auto itr = domain_to_equiv_.find(root);
+      if (itr == domain_to_equiv_.end()) {
+        break;
+      }
+      ICHECK_NE(itr->second, root);
+      root = itr->second;
+      ICHECK_NOTNULL(root);
+    }
+    // Path compression.
+    while (domain != root) {
+      auto itr = domain_to_equiv_.find(domain);
+      ICHECK(itr != domain_to_equiv_.end());
+      domain = itr->second;
+      ICHECK_NOTNULL(domain);
+      itr->second = root;
+    }
+    return root;
+  }
+
+  /*!
+   * \brief Returns the domain accounting for all bound devices in \p lhs and \p rhs.
+   *
+   * Throws \p Error on failure.
+   */
+  DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+    // TODO(mbs): Proper diagnostics.
+    ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size())
+        << "Device domains:" << std::endl
+        << ToString(lhs) << std::endl
+        << "and" << std::endl
+        << ToString(rhs) << std::endl
+        << "do not have the same kind and can't be unified.";
+    if (rhs->is_free()) {
+      return lhs;
+    } else if (lhs->is_free()) {
+      return rhs;
+    } else if (lhs->args_and_result_.empty()) {
+      // Must have consistent device types for first order domains.
+      if (lhs->device_type_ != rhs->device_type_) {
+        // TODO(mbs): Proper diagnostics.
+        std::ostringstream os;
+        os << "Inconsistent device types " << lhs->device_type_ << " and " << rhs->device_type_;
+        throw Error(os.str());
+      }
+      return lhs;
+    } else {
+      // Recurse for higher-order.
+      std::vector<DeviceDomainPtr> args_and_result;
+      args_and_result.reserve(lhs->args_and_result_.size());
+      for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
+        args_and_result.emplace_back(Unify(lhs->args_and_result_[i], rhs->args_and_result_[i]));
+      }
+      return MakeDomain(std::move(args_and_result));
+    }
+  }
+
+  /*!
+   * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Fails if \p lhs and \p
+   * rhs disagree on bound device type.
+   *
+   * Throws \p Error on failure.
+   */
+  // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but
+  // given we have refs to functions I'm prepared to be surprised.
+  DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) {
+    lhs = Lookup(lhs);
+    rhs = Lookup(rhs);
+    auto joined_domain = Join(lhs, rhs);
+    if (!DeviceDomainEqual()(lhs, joined_domain)) {
+      domain_to_equiv_.emplace(lhs, joined_domain);
+    }
+    if (!DeviceDomainEqual()(rhs, joined_domain)) {
+      domain_to_equiv_.emplace(rhs, joined_domain);
+    }
+    return joined_domain;
+  }
+
+  /*!
+   * \brief Unifies \p lhs and \p rhs. If \p lhs is first-order and \p rhs is higher-order,
+   * require all arguments and result of \p rhs to unify with \p lhs. Otherwise same as
+   * \p Unify.
+   *
+   * Throws \p Error on failure.
+   */
+  void UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+    if (!lhs->is_higher_order() && rhs->is_higher_order()) {
+      Collapse(lhs, rhs);
+    } else {
+      Unify(lhs, rhs);
+    }
+  }
+
+  /*! \brief Returns true if a domain is known for \p expr. */
+  bool contains(const Expr& expr) const { return expr_to_domain_.count(expr.get()); }
+
+  /*! \brief Returns the domain representing \p expr. */
+  DeviceDomainPtr DomainFor(const Expr& expr) {
+    ICHECK(expr.defined());
+    auto itr = expr_to_domain_.find(expr.get());
+    if (itr != expr_to_domain_.end()) {
+      return Lookup(itr->second);
+    }
+    auto domain = Free(expr->checked_type());
+    expr_to_domain_.emplace(expr.get(), domain);
+    return domain;
+  }
+
+  /*!
+   * \brief Returns the domain representing the callee (i.e. 'op') in the \p call expression. If
+   * the callee is a primitive or special operation we handle it specially. Otherwise defers to
+   * \p DomainFor(call->op).
+   *
+   * This special handling is needed:
+   * - To handle the "on_device" and "device_copy" ops, which constrain their argument and result
+   *   devices according to their attributes.
+   * - To handle some special ops which constrain some of their arguments and/or results to
+   *   the CPU.
+   * - To allow the same primitive to be called on different devices at different call sites.
+   *   Since each call to the op can have a different domain we index the ops by the call
+   *   expression rather than the op itself.
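+   *
+   * For example (illustrative): two calls f(x) and f(y) to the same fused primitive f may be
+   * assigned the callee domains fn(<1>):<1> and fn(<2>):<2> respectively, since domains are
+   * keyed by the call expression rather than by f.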
+   */
+  DeviceDomainPtr DomainForCallee(const Call& call) {
+    auto itr = call_to_callee_domain_.find(call.get());
+    if (itr != call_to_callee_domain_.end()) {
+      return Lookup(itr->second);
+    }
+    std::vector<DeviceDomainPtr> args_and_result;
+
+    auto on_device_props = GetOnDeviceProps(call.get());
+    auto device_copy_props = GetDeviceCopyProps(call.get());
+    if (!device_copy_props.body.defined()) {
+      device_copy_props = GetPrimitiveDeviceCopyProps(call.get());
+    }
+
+    if (on_device_props.body.defined()) {
+      // on_device(expr, device_type=<t>, is_fixed=false)
+      // on_device : fn(<t>):?x?
+      //
+      // on_device(expr, device_type=<t>, is_fixed=true)
+      // on_device: fn(<t>):<t>
+      args_and_result.emplace_back(
+          ForDeviceType(on_device_props.body->checked_type(), on_device_props.device_type));
+      if (on_device_props.is_fixed) {
+        args_and_result.emplace_back(args_and_result.front());
+      } else {
+        args_and_result.emplace_back(Free(on_device_props.body->checked_type()));
+      }
+    } else if (device_copy_props.body.defined()) {
+      // device_copy(expr, src_dev_type=<s>, dst_dev_type=<d>)
+      // device_copy: fn(<s>):<d>
+      args_and_result.emplace_back(
+          ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.src_dev_type));
+      args_and_result.emplace_back(
+          ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.dst_dev_type));
+    } else if (call->op == alloc_storage_op) {
+      ICHECK_EQ(call->args.size(), 2U);
+      // alloc_storage(size, alignment, device_type=<t>)
+      // alloc_storage: fn(<cpu>, <cpu>):<t>
+      const auto* attrs = call->attrs.as<AllocStorageAttrs>();
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(
+          ForDeviceType(call->checked_type(), static_cast<DLDeviceType>(attrs->device_type)));
+    } else if (call->op == alloc_tensor_op) {
+      ICHECK_EQ(call->args.size(), 3U);
+      // alloc_tensor(storage, offset, shape)
+      // alloc_tensor: fn(?x?, <cpu>, <cpu>):?x?
+      auto free_domain = Free(call->checked_type());
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(free_domain);
+    } else if (call->op == shape_func_op) {
+      ICHECK_EQ(call->args.size(), 3U);
+      // shape_func(func, inputs, outputs, is_inputs=[...])
+      // shape_func: fn(..., <cpu>, <cpu>):<cpu>
+      // where ... is a free domain appropriate for func's type
+      args_and_result.emplace_back(Free(call->args[0]->checked_type()));
+      // TODO(mbs): I think this should be on the cpu only when is_input = [false], but
+      // what do we do when we have multiple arguments with different is_input values?
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(cpu_domain_);
+    } else if (call->op == shape_of_op) {
+      ICHECK_EQ(call->args.size(), 1U);
+      // shape_of(tensor)
+      // shape_of: fn(?x?):<cpu>
+      args_and_result.emplace_back(Free(call->args[0]->checked_type()));
+      args_and_result.emplace_back(cpu_domain_);
+    } else if (call->op == invoke_tvm_op) {
+      ICHECK_EQ(call->args.size(), 3U);
+      // invoke_tvm_op(op, inputs, outputs)
+      // invoke_tvm_op: fn(..., ?x?, ?x?):?x?
+      // where ... is a free domain appropriate for op's type
+      auto free_domain = Free(call->checked_type());
+      args_and_result.emplace_back(Free(call->args[0]->checked_type()));
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(free_domain);
+    } else if (call->op == reshape_tensor_op) {
+      ICHECK_EQ(call->args.size(), 2U);
+      // reshape_tensor(data, shape)
+      // reshape_tensor: fn(?x?, <cpu>):?x?
+      auto free_domain = Free(call->checked_type());
+      args_and_result.emplace_back(free_domain);
+      args_and_result.emplace_back(cpu_domain_);
+      args_and_result.emplace_back(free_domain);
+    } else if (call->op->IsInstance<OpNode>()) {
+      // <primitive>(arg1, ..., argn)
+      // <primitive>: fn(?x?, ..., ?x?):?x?
+      // (all args and result must be first-order).
+      auto free_domain = Free(arb_);
+      for (size_t i = 0; i < call->args.size(); ++i) {
+        args_and_result.emplace_back(free_domain);
+      }
+      args_and_result.emplace_back(free_domain);
+    } else {
+      // Defer to normal case where op can be an arbitrary expression.
+      return DomainFor(call->op);
+    }
+    auto domain = MakeDomain(std::move(args_and_result));
+    call_to_callee_domain_.emplace(call.get(), domain);
+    return domain;
+  }
+
+  /*! \brief Unifies the domains for expressions \p lhs and \p rhs. */
+  void UnifyExprExact(const Expr& lhs, const Expr& rhs) {
+    auto lhs_domain = DomainFor(lhs);
+    auto rhs_domain = DomainFor(rhs);
+    try {
+      Unify(lhs_domain, rhs_domain);
+    } catch (const Error& e) {
+      // TODO(mbs): Proper diagnostics.
+      LOG(FATAL) << "Incompatible devices for expressions:" << std::endl
+                 << PrettyPrint(lhs) << std::endl
+                 << "with device:" << std::endl
+                 << ToString(lhs_domain) << std::endl
+                 << "and:" << std::endl
+                 << PrettyPrint(rhs) << std::endl
+                 << "with device:" << std::endl
+                 << ToString(rhs_domain) << std::endl
+                 << e.what();
+    }
+  }
+
+  /*!
+   * \brief Unifies the domain for \p expr with \p expected_domain.
+   */
+  void UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain) {
+    auto actual_domain = DomainFor(expr);
+    try {
+      Unify(actual_domain, expected_domain);
+    } catch (const Error& e) {
+      // TODO(mbs): Proper diagnostics.
+      LOG(FATAL) << "Incompatible devices for expression:" << std::endl
+                 << PrettyPrint(expr) << std::endl
+                 << "with actual device:" << std::endl
+                 << ToString(actual_domain) << std::endl
+                 << "and expected device:" << std::endl
+                 << ToString(expected_domain) << std::endl
+                 << e.what();
+    }
+  }
+
+  /*!
+   * \brief Unifies the domain for \p expr with \p expected_domain.
+   * If \p expected_domain is higher-order but \p expr is first-order, require all arguments
+   * and the result of \p expected_domain to have the same domain as for \p expr.
+   */
+  void UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain) {
+    auto actual_domain = DomainFor(expr);
+    try {
+      UnifyCollapsed(actual_domain, expected_domain);
+    } catch (const Error& e) {
+      // TODO(mbs): Proper diagnostics.
+      LOG(FATAL) << "Incompatible devices for expression:" << std::endl
+                 << PrettyPrint(expr) << std::endl
+                 << "with actual device:" << std::endl
+                 << ToString(actual_domain) << std::endl
+                 << "and expected device:" << std::endl
+                 << ToString(expected_domain) << std::endl
+                 << e.what();
+    }
+  }
+
+  /*! \brief Returns true if \p domain contains any free sub-domains. */
+  bool AnyFree(DeviceDomainPtr domain) {
+    domain = Lookup(domain);
+    if (domain->is_free()) {
+      return true;
+    }
+    for (const auto& sub_domain : domain->args_and_result_) {
+      if (AnyFree(sub_domain)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  /*!
+   * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain.
+   * This can be used to handle functions within tuples, references and ADTs since we don't
+   * attempt to track anything beyond 'the device' for expressions of those first-order types.
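+   *
+   * For example (illustrative): collapsing fn(?a?, ?b?):?c? onto the first-order domain <2>
+   * unifies ?a?, ?b? and ?c? all to <2>.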
+   *
+   * Throws \p Error on failure.
+   */
+  void Collapse(const DeviceDomainPtr& first_order_domain,
+                const DeviceDomainPtr& higher_order_domain) {
+    for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) {
+      Unify(higher_order_domain->function_param(i), first_order_domain);
+    }
+    Unify(higher_order_domain->function_result(), first_order_domain);
+  }
+
+  /*! \brief Force all free domains in \p domain to default to \p default_device_type. */
+  void SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type) {
+    ICHECK_NE(default_device_type, kInvalidDeviceType);
+    domain = Lookup(domain);
+    if (domain->is_free()) {
+      // Will never throw since lhs is free.
+      Unify(domain, std::make_shared<DeviceDomain>(default_device_type));
+    } else if (!domain->args_and_result_.empty()) {
+      for (const auto& sub_domain : domain->args_and_result_) {
+        SetDefault(sub_domain, default_device_type);
+      }
+    }
+  }
+
+  /*!
+   * \brief If \p domain is higher-order and its result domain is free, force it to
+   * \p default_device_type. Then force any remaining free domains to the result domain
+   * (freshly defaulted or original). If \p domain is first-order this is the same as
+   * \p SetDefault.
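+   *
+   * For example (illustrative): given fn(?a?, <1>):?r? and default <2>, the free result ?r?
+   * is first defaulted to <2>, and the remaining free parameter ?a? then defaults to the
+   * result device, i.e. <2>.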
+   */
+  void SetResultDefaultThenParams(const DeviceDomainPtr& domain, DLDeviceType default_device_type) {
+    if (!domain->is_higher_order()) {
+      SetDefault(domain, default_device_type);
+      return;
+    }
+    DLDeviceType result_device_type = ResultDeviceType(domain);
+    if (result_device_type == kInvalidDeviceType) {
+      // If the function result device is still free use the given default.
+      result_device_type = default_device_type;
+    }
+    // Default any remaining free parameters to the function result device.
+    SetDefault(domain, result_device_type);
+  }
+
+  /*! \brief Returns one-line description of \p domain for debugging. */
+  std::string ToString(DeviceDomainPtr domain) {
+    domain = Lookup(domain);
+    std::ostringstream os;
+    if (domain->is_free()) {
+      // first-order free
+      os << "?" << static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get())) << "?";
+    } else if (domain->args_and_result_.empty()) {
+      // first-order bound
+      os << "<" << domain->device_type_ << ">";
+    } else {
+      // higher-order
+      os << "fn(";
+      for (size_t i = 0; i + 1 < domain->args_and_result_.size(); ++i) {
+        if (i > 0) {
+          os << ",";
+        }
+        os << ToString(domain->args_and_result_[i]);
+      }
+      os << "):" << ToString(domain->args_and_result_.back());
+    }
+    return os.str();
+  }
+
+  /*! \brief Returns description of entire system of constraints for debugging */
+  std::string ToString() {
+    std::ostringstream os;
+    for (const auto& pair : expr_to_domain_) {
+      os << "expression:" << std::endl
+         << PrettyPrint(GetRef<Expr>(pair.first)) << std::endl
+         << "domain:" << std::endl
+         << ToString(pair.second) << std::endl
+         << std::endl;
+    }
+    for (const auto& pair : call_to_callee_domain_) {
+      os << "call:" << std::endl
+         << PrettyPrint(GetRef<Call>(pair.first)) << std::endl
+         << "callee domain:" << std::endl
+         << ToString(pair.second) << std::endl
+         << std::endl;
+    }
+    return os.str();
+  }
+
+  /*!
+   * \brief Returns the result domain for \p domain (see defn in DeviceDomain comment).
+   */
+  DeviceDomainPtr ResultDomain(DeviceDomainPtr domain) {
+    domain = Lookup(domain);
+    while (!domain->args_and_result_.empty()) {
+      domain = Lookup(domain->args_and_result_.back());
+    }
+    return domain;
+  }
+
+  /*!
+   * \brief Returns the result (possibly free) device type for \p domain (see defn in DeviceDomain
+   * comment).
+   */
+  DLDeviceType ResultDeviceType(const DeviceDomainPtr& domain) {
+    return ResultDomain(domain)->first_order_device_type();
+  }
+
+ private:
+  /*! \brief Intrinsics we need to handle specially. */
+  const Op& alloc_storage_op = Op::Get("memory.alloc_storage");
+  const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor");
+  const Op& shape_of_op = Op::Get("vm.shape_of");
+  const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op");
+  const Op& shape_func_op = Op::Get("vm.shape_func");
+  const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor");
+  /*! \brief The CPU device type for special operators such as dynamic shape functions. */
+  const DLDeviceType cpu_device_type_ = kDLCPU;
+  /*! \brief Placeholder for any first-order type. */
+  Type arb_ = TupleType();
+  /*! \brief The domain for first-order expressions on the CPU. */
+  DeviceDomainPtr cpu_domain_ = ForDeviceType(arb_, cpu_device_type_);
+
+  /*! \brief Maps expressions to their domains as determined during analysis. */
+  std::unordered_map<const ExprNode*, DeviceDomainPtr> expr_to_domain_;
+
+  /*!
+   * \brief Maps call expressions to the domains for their callee where the callee is a primitive.
+   */
+  std::unordered_map<const CallNode*, DeviceDomainPtr> call_to_callee_domain_;
+
+  /*! \brief Maps device domains to their equivalent domains as determined during unification. */
+  std::unordered_map<DeviceDomainPtr, DeviceDomainPtr, DeviceDomainHash, DeviceDomainEqual>
+      domain_to_equiv_;
+};
+
+/******
+****** Phase 0
+******/
+
+/*!
+ * \brief Rewrites "on_device" calls to handle some special cases.
+ */
+class RewriteOnDevices : public ExprMutator {

Review comment:
       dang this file is getting a bit long..



