You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/12/17 14:53:28 UTC
[tvm] branch main updated: [Relay] s/SEScope/VirtualDevice/g (#9759)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bd61d18 [Relay] s/SEScope/VirtualDevice/g (#9759)
bd61d18 is described below
commit bd61d18c198b362ec6b763e6b116bf5f5edbcefc
Author: Mark Shields <87...@users.noreply.github.com>
AuthorDate: Fri Dec 17 06:53:07 2021 -0800
[Relay] s/SEScope/VirtualDevice/g (#9759)
* [Relay] s/SEScope/VirtualDevice/g
Nobody liked 'SEScope', and 'DeviceMcDeviceFace' is too verbose, so it
seems 'VirtualDevice' has the popular vote.
---
include/tvm/ir/expr.h | 14 +-
include/tvm/ir/function.h | 12 +-
include/tvm/relay/attrs/device_copy.h | 10 +-
include/tvm/relay/attrs/memory.h | 6 +-
include/tvm/relay/attrs/on_device.h | 29 +-
include/tvm/relay/expr.h | 20 +-
include/tvm/relay/function.h | 2 +-
include/tvm/relay/transform.h | 14 +-
include/tvm/target/compilation_config.h | 37 +-
.../tvm/target/{se_scope.h => virtual_device.h} | 162 ++++----
python/tvm/relay/op/annotation/annotation.py | 12 +-
python/tvm/relay/op/tensor.py | 10 +-
python/tvm/relay/transform/transform.py | 10 +-
python/tvm/target/__init__.py | 2 +-
.../tvm/target/{se_scope.py => virtual_device.py} | 9 +-
src/printer/relay_text_printer.cc | 10 +-
src/relay/backend/aot_executor_codegen.cc | 30 +-
src/relay/backend/build_module.cc | 2 +-
src/relay/backend/graph_executor_codegen.cc | 8 +-
src/relay/backend/graph_plan_memory.cc | 36 +-
src/relay/backend/interpreter.cc | 8 +-
src/relay/backend/te_compiler.cc | 58 +--
src/relay/backend/te_compiler.h | 5 +-
src/relay/backend/utils.cc | 26 +-
src/relay/backend/utils.h | 8 +-
src/relay/backend/vm/compiler.cc | 129 +++---
src/relay/backend/vm/compiler.h | 8 +-
src/relay/backend/vm/lambda_lift.cc | 9 +-
src/relay/ir/expr.cc | 46 +--
src/relay/ir/expr_functor.cc | 14 +-
src/relay/ir/function.cc | 4 +-
src/relay/op/memory/device_copy.cc | 26 +-
src/relay/op/memory/device_copy.h | 29 +-
src/relay/op/memory/memory.cc | 4 +-
src/relay/op/memory/memory.h | 4 +-
src/relay/op/memory/on_device.cc | 74 ++--
src/relay/op/memory/on_device.h | 67 ++--
src/relay/transforms/device_aware_visitors.cc | 82 ++--
src/relay/transforms/device_aware_visitors.h | 104 ++---
src/relay/transforms/device_domains.cc | 130 +++---
src/relay/transforms/device_domains.h | 86 ++--
src/relay/transforms/device_planner.cc | 343 ++++++++--------
src/relay/transforms/fold_constant.cc | 22 +-
src/relay/transforms/memory_alloc.cc | 89 +++--
src/relay/transforms/to_a_normal_form.cc | 16 +-
src/target/compilation_config.cc | 49 +--
src/target/{se_scope.cc => virtual_device.cc} | 54 +--
src/tir/analysis/device_constraint_utils.cc | 107 ++---
src/tir/analysis/device_constraint_utils.h | 28 +-
tests/cpp/relay/op/memory/on_device_test.cc | 28 +-
tests/cpp/relay/transforms/device_domains_test.cc | 12 +-
tests/cpp/target/compilation_config_test.cc | 66 +--
tests/cpp/target/se_scope_test.cc | 119 ------
tests/cpp/target/virtual_device_test.cc | 121 ++++++
.../python/relay/op/annotation/test_annotation.py | 22 +-
tests/python/relay/op/test_tensor.py | 20 +-
.../relay/test_pass_dead_code_elimination.py | 12 +-
tests/python/relay/test_pass_plan_devices.py | 442 +++++++++++----------
.../{test_se_scope.py => test_virtual_device.py} | 32 +-
59 files changed, 1514 insertions(+), 1424 deletions(-)
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index a6e5c8d..8937bb7 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -39,8 +39,8 @@ namespace tvm {
using tvm::runtime::String;
-// Forward-declare SEScope to avoid circular imports.
-class SEScope;
+// Forward-declare VirtualDevice to avoid circular imports.
+class VirtualDevice;
/*!
* \brief Base type of all the expressions.
@@ -169,7 +169,7 @@ class RelayExprNode : public BaseExprNode {
inline const TTypeNode* type_as() const;
/*!
- * \brief The virtual device (SEScope) for this node (the result of device planning).
+ * \brief The virtual device (VirtualDevice) for this node (the result of device planning).
* For first-order expressions (non functions), this describes where the result of evaluating the
* expression should be stored. Note that currently, all composite first-order values (tuples,
* references, ADTs) must be stored on the same virtual device. This means that it is not possible
@@ -178,7 +178,7 @@ class RelayExprNode : public BaseExprNode {
*
* For expressions that have the function type, the virtual device describes where the result of
* the call to the function or closure is stored (instead of where the function itself is stored).
- * The SEScope's Target field describes how the body of the function should be compiled.
+ * The VirtualDevice's Target field describes how the body of the function should be compiled.
*
* \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular
* import.
@@ -186,10 +186,10 @@ class RelayExprNode : public BaseExprNode {
mutable ObjectRef virtual_device_;
/*!
- * \return The virtual device (SEScope).
- * If the virtual device is not defined, returns SEScope::FullyUnconstrained().
+ * \return The virtual device (VirtualDevice).
+ * If the virtual device is not defined, returns VirtualDevice::FullyUnconstrained().
*/
- SEScope virtual_device() const;
+ VirtualDevice virtual_device() const;
static constexpr const char* _type_key = "RelayExpr";
static constexpr const uint32_t _type_child_slots = 22;
diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h
index e466cde..051c05d 100644
--- a/include/tvm/ir/function.h
+++ b/include/tvm/ir/function.h
@@ -191,24 +191,24 @@ constexpr const char* kTarget = "target";
constexpr const char* kGlobalSymbol = "global_symbol";
/*!
- * \brief The SEScope which will hold each of the functions parameters.
+ * \brief The \p VirtualDevice which will hold each of the functions parameters.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
- * Type: Array<SEScope>
+ * Type: Array<VirtualDevice>
*/
-constexpr const char* kParamSEScopes = "param_se_scopes";
+constexpr const char* kParamVirtualDevice = "param_virtual_devices";
/*!
- * \brief The SEScope which will hold the function result.
+ * \brief The \p VirtualDevice which will hold the function result.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
- * Type: SEScope
+ * Type: VirtualDevice
*/
-constexpr const char* kResultSEScope = "result_se_scope";
+constexpr const char* kResultVirtualDevice = "result_virtual_device";
} // namespace attr
} // namespace tvm
diff --git a/include/tvm/relay/attrs/device_copy.h b/include/tvm/relay/attrs/device_copy.h
index 6d97ab7..fe0534a 100644
--- a/include/tvm/relay/attrs/device_copy.h
+++ b/include/tvm/relay/attrs/device_copy.h
@@ -25,7 +25,7 @@
#define TVM_RELAY_ATTRS_DEVICE_COPY_H_
#include <tvm/ir/attrs.h>
-#include <tvm/target/se_scope.h>
+#include <tvm/target/virtual_device.h>
#include <string>
@@ -36,13 +36,13 @@ namespace relay {
* \brief Options for the device copy operators.
*/
struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
- SEScope src_se_scope = SEScope::FullyUnconstrained();
- SEScope dst_se_scope = SEScope::FullyUnconstrained();
+ VirtualDevice src_virtual_device = VirtualDevice::FullyUnconstrained();
+ VirtualDevice dst_virtual_device = VirtualDevice::FullyUnconstrained();
TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") {
- TVM_ATTR_FIELD(src_se_scope)
+ TVM_ATTR_FIELD(src_virtual_device)
.describe("The (virtual) device and scope where the op copies data from.");
- TVM_ATTR_FIELD(dst_se_scope)
+ TVM_ATTR_FIELD(dst_virtual_device)
.describe("The (virtual) device and scope where the op copies data to.");
}
};
diff --git a/include/tvm/relay/attrs/memory.h b/include/tvm/relay/attrs/memory.h
index 952d4af..07d6cc7 100644
--- a/include/tvm/relay/attrs/memory.h
+++ b/include/tvm/relay/attrs/memory.h
@@ -26,7 +26,7 @@
#include <tvm/ir/attrs.h>
#include <tvm/relay/expr.h>
-#include <tvm/target/se_scope.h>
+#include <tvm/target/virtual_device.h>
#include <string>
#include <vector>
@@ -43,13 +43,13 @@ Expr ToTupleType(const Type& t, const std::vector<Expr>& exprs);
*/
struct AllocStorageAttrs : public tvm::AttrsNode<AllocStorageAttrs> {
DataType dtype;
- SEScope se_scope = SEScope::FullyUnconstrained();
+ VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained();
TVM_DECLARE_ATTRS(AllocStorageAttrs, "relay.attrs.AllocStorageAttrs") {
TVM_ATTR_FIELD(dtype)
.describe("The dtype of the tensor to allocate.")
.set_default(DataType::Float(32, 1));
- TVM_ATTR_FIELD(se_scope).describe("The SEScope on which to allocate memory.");
+ TVM_ATTR_FIELD(virtual_device).describe("The virtual device on which to allocate memory.");
}
};
diff --git a/include/tvm/relay/attrs/on_device.h b/include/tvm/relay/attrs/on_device.h
index 0931865..3facc3a 100644
--- a/include/tvm/relay/attrs/on_device.h
+++ b/include/tvm/relay/attrs/on_device.h
@@ -25,7 +25,7 @@
#define TVM_RELAY_ATTRS_ON_DEVICE_H_
#include <tvm/ir/attrs.h>
-#include <tvm/target/se_scope.h>
+#include <tvm/target/virtual_device.h>
#include <string>
@@ -37,42 +37,43 @@ namespace relay {
*
* The Relay call:
* \code
- * on_device(sub_expr, se_scope=S)
+ * on_device(sub_expr, virtual_device=S)
* \endcode
- * constrains \p sub_expr to execute and store its result on the \p SEScope \p S.
+ * constrains \p sub_expr to execute and store its result on the \p VirtualDevice \p S.
* However the annotation itself may appear in an expression to be executed and stored on a
- * different \p SEScope. If so the compiler will automatically insert a "device_copy" call to
- * mediate the transition between \p SEScopes.
+ * different \p VirtualDevice. If so the compiler will automatically insert a "device_copy" call to
+ * mediate the transition between \p VirtualDevices.
*
* E.g.: Assuming %x and %y reside on the GPU and %z on the CPU then:
* \code
- * multiply(on_device(add(%x, %y), se_scope=GPU), %z)
+ * multiply(on_device(add(%x, %y), virtual_device=GPU), %z)
* \endcode
* indicates the \p add should execute on the GPU but the \p multiply should execute on the CPU.
* The compiler will rewrite this to:
* \code
- * multiply(device_copy(add(%x, %y), src_se_scope=GPU, dst_se_scope=CPU), %z)
+ * multiply(device_copy(add(%x, %y), src_virtual_device=GPU, dst_virtual_device=CPU), %z)
* \endcode
*
* The \p constraint_body (default true) and \p constraint_result (default false) fields can be
- * used by passes for finer-grained control over how the \p SEScope constraint should be applied.
+ * used by passes for finer-grained control over how the \p VirtualDevice constraint should be
+ * applied.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
/*!
- * \brief The \p SEScope to constraint to apply to the body, result, or both body and result
+ * \brief The \p VirtualDevice constraint to apply to the body, result, or both body and result
* of the "on_device" call.
*/
- SEScope se_scope = SEScope::FullyUnconstrained();
+ VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained();
/*!
* \brief If false (the default), the result of the "on_device" call is not constrained to be
- * \p se_scope.
+ * \p virtual_device.
*/
bool constrain_result = false;
/*!
* \brief If true (the default), the body of the "on_device" call is constrained to be \p
- * se_scope.
+ * virtual_device.
*/
bool constrain_body = true;
@@ -87,9 +88,9 @@ struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
bool is_normal() const { return !constrain_result && constrain_body; }
TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
- TVM_ATTR_FIELD(se_scope)
+ TVM_ATTR_FIELD(virtual_device)
.describe("The (virtual) device to constrain to.")
- .set_default(SEScope::FullyUnconstrained());
+ .set_default(VirtualDevice::FullyUnconstrained());
TVM_ATTR_FIELD(constrain_result)
.describe("Whether the constraint applies to the overall expression")
.set_default(false);
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 8bec724..04dd922 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -28,7 +28,7 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <tvm/ir/op.h>
-#include <tvm/target/se_scope.h>
+#include <tvm/target/virtual_device.h>
#include <functional>
#include <stack>
@@ -158,7 +158,7 @@ class Tuple : public Expr {
* ret_tuple->span = tuple->span.
*/
Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields = Optional<Array<Expr>>(),
- Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
+ Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
Optional<Span> opt_span = Optional<Span>());
/*!
@@ -264,7 +264,7 @@ class Var : public Expr {
*/
Var WithFields(Var var, Optional<Id> opt_vid = Optional<Id>(),
Optional<Type> opt_type_annotation = Optional<Type>(),
- Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
+ Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
Optional<Span> opt_span = Optional<Span>());
/*!
@@ -391,7 +391,7 @@ Call WithFields(Call call, Optional<Expr> opt_op = Optional<Expr>(),
Optional<Array<Expr>> opt_args = Optional<Array<Expr>>(),
Optional<Attrs> opt_attrs = Optional<Attrs>(),
Optional<Array<Type>> opt_type_args = Optional<Array<Type>>(),
- Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
+ Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
Optional<Span> opt_span = Optional<Span>());
/*!
@@ -487,7 +487,7 @@ class Let : public Expr {
Let WithFields(Let let, Optional<Var> opt_var = Optional<Var>(),
Optional<Expr> opt_value = Optional<Expr>(),
Optional<Expr> opt_body = Optional<Expr>(),
- Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
+ Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
Optional<Span> opt_span = Optional<Span>());
/*!
@@ -574,7 +574,7 @@ class If : public Expr {
If WithFields(If if_expr, Optional<Expr> opt_cond = Optional<Expr>(),
Optional<Expr> opt_true_branch = Optional<Expr>(),
Optional<Expr> opt_false_branch = Optional<Expr>(),
- Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
+ Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
Optional<Span> opt_span = Optional<Span>());
/*! \brief Get index-th field out of a tuple. */
@@ -640,7 +640,7 @@ class TupleGetItem : public Expr {
*/
TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple = Optional<Expr>(),
Optional<Integer> opt_index = Optional<Integer>(),
- Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
+ Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
Optional<Span> opt_span = Optional<Span>());
/*! \brief Create a new Reference out of initial value. */
@@ -701,7 +701,7 @@ class RefCreate : public Expr {
* ret_ref_create->value = opt_value.value()).
*/
RefCreate WithFields(RefCreate ref_create, Optional<Expr> opt_value = Optional<Expr>(),
- Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
+ Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
Optional<Span> opt_span = Optional<Span>());
/*! \brief Get value out of Reference. */
@@ -761,7 +761,7 @@ class RefRead : public Expr {
* if opt_ref.value() != ref_read->ref, then ret_ref_read->ref = opt_ref.value()).
*/
RefRead WithFields(RefRead ref_read, Optional<Expr> opt_ref = Optional<Expr>(),
- Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
+ Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
Optional<Span> opt_span = Optional<Span>());
/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
@@ -829,7 +829,7 @@ class RefWrite : public Expr {
*/
RefWrite WithFields(RefWrite ref_write, Optional<Expr> opt_ref = Optional<Expr>(),
Optional<Expr> opt_value = Optional<Expr>(),
- Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
+ Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
Optional<Span> opt_span = Optional<Span>());
/*!
diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h
index 1b8ed44..d9bf7ac 100644
--- a/include/tvm/relay/function.h
+++ b/include/tvm/relay/function.h
@@ -148,7 +148,7 @@ Function WithFields(Function function, Optional<Array<Var>> opt_params = Optiona
Optional<Type> opt_ret_type = Optional<Type>(),
Optional<Array<TypeVar>> opt_ty_params = Optional<Array<TypeVar>>(),
Optional<DictAttrs> opt_attrs = Optional<DictAttrs>(),
- Optional<SEScope> opt_virtual_device = Optional<SEScope>(),
+ Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
Optional<Span> opt_span = Optional<Span>());
/*
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 2d6cdea..dfc49cb 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -31,8 +31,8 @@
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/target/compilation_config.h>
-#include <tvm/target/se_scope.h>
#include <tvm/target/target.h>
+#include <tvm/target/virtual_device.h>
#include <string>
@@ -449,22 +449,22 @@ TVM_DLL Pass RelayToTIRTargetHook();
* \brief A pass for manifesting explicit memory allocations and rewriting
* specific dialects.
*
- * \param cpu_se_scope SEScope for computations and data which must reside on a CPU, such as
- * shapes and shape functions.
+ * \param cpu_virtual_device VirtualDevice for computations and data which must reside on a CPU,
+ * such as shapes and shape functions.
*
* \return The pass.
*/
-TVM_DLL Pass ManifestAlloc(SEScope cpu_se_scope);
+TVM_DLL Pass ManifestAlloc(VirtualDevice cpu_virtual_device);
/*!
- * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the \p SEScope on which
- * every Relay sub-expression should run and the result stored. Captures the result of that
+ * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the \p VirtualDevice on
+ * which every Relay sub-expression should run and the result stored. Captures the result of that
* analysis using new "on_device" and "device_copy" CallNodes.
*
* See tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator}
* for help recovering the device for an arbitrary sub-expression in downstream transformations.
*
- * \param config Describes the targets and default \p SEScope for all primitive operators and
+ * \param config Describes the targets and default \p VirtualDevice for all primitive operators and
* host sub-expressions.
*
* \return The pass.
diff --git a/include/tvm/target/compilation_config.h b/include/tvm/target/compilation_config.h
index 45ff774..1c47a0f 100644
--- a/include/tvm/target/compilation_config.h
+++ b/include/tvm/target/compilation_config.h
@@ -26,12 +26,12 @@
#ifndef TVM_TARGET_COMPILATION_CONFIG_H_
#define TVM_TARGET_COMPILATION_CONFIG_H_
-#include <tvm/target/se_scope.h>
+#include <tvm/target/virtual_device.h>
namespace tvm {
/*!
- * \brief Gathers the \p Targets and distinguished \p SEScopes in canonical form needed to
+ * \brief Gathers the \p Targets and distinguished \p VirtualDevices in canonical form needed to
* compile a Relay module. Centralizes any setup and validation logic needed to transition
* from configuration options conveyed implicitly (eg in \p PassContexts) or explicitly
* (eg a list of \p Targets) to the configuration.
@@ -82,13 +82,13 @@ class CompilationConfigNode : public Object {
Array<Target> primitive_targets;
/*!
- * \brief \p SEScope for primitive operators which are not otherwise constrained to a particular
- * device.
+ * \brief \p VirtualDevice for primitive operators which are not otherwise constrained to a
+ * particular device.
*/
- SEScope default_primitive_se_scope = SEScope::FullyUnconstrained();
+ VirtualDevice default_primitive_virtual_device = VirtualDevice::FullyUnconstrained();
- /*! \brief SEScope for the host. */
- SEScope host_se_scope = SEScope::FullyUnconstrained();
+ /*! \brief VirtualDevice for the host. */
+ VirtualDevice host_virtual_device = VirtualDevice::FullyUnconstrained();
/*!
* \brief If defined then compile and/or run in 'homogenous execution mode'. In this mode all
@@ -104,24 +104,25 @@ class CompilationConfigNode : public Object {
void VisitAttrs(AttrVisitor* v);
/*!
- * \brief Returns a \p SEScope agreeing with \p se_scope on all its constrained fields, however:
+ * \brief Returns a \p VirtualDevice agreeing with \p virtual_device on all its constrained
+ * fields, however:
* - If the target is null then it is filled in from the known available primitive targets by
* matching on device type. Fails if no such target is known.
- * - The returned object is unique for the field values w.r.t. all other \p SEScopes returned
- * by this method.
+ * - The returned object is unique for the field values w.r.t. all other \p VirtualDevices
+ * returned by this method.
*
- * We call the result the 'canonical' \p SEScope. Two canonical \p SEScopes are structurally
- * equal if and only if they are pointer equal.
+ * We call the result the 'canonical' \p VirtualDevice. Two canonical \p VirtualDevices are
+ * structurally equal if and only if they are pointer equal.
*/
- SEScope CanonicalSEScope(const SEScope& se_scope) const;
+ VirtualDevice CanonicalVirtualDevice(const VirtualDevice& virtual_device) const;
static constexpr const char* _type_key = "CompilationConfig";
TVM_DECLARE_FINAL_OBJECT_INFO(CompilationConfigNode, Object)
private:
/*!
- * \brief Establishes the default \p SEScope for primitives and the \p SEScope for the host
- * given:
+ * \brief Establishes the default \p VirtualDevice for primitives and the \p VirtualDevice for the
+ * host given:
* - the vector of available primitive \p Targets.
* - any host \p Target.
* - any "relay.fallback_device_type" attribute on \p pass_ctx.
@@ -134,7 +135,7 @@ class CompilationConfigNode : public Object {
* CAUTION: Recreated the primitive_targets so that they all have the given/constructed
* host_target as their host (cf CheckAndUpdateHostConsistency).
*/
- void EstablishDefaultSEScopes(const transform::PassContext& pass_ctx);
+ void EstablishDefaultVirtualDevices(const transform::PassContext& pass_ctx);
/*!
* \brief Returns a freshly constructed \p Target to represent \p device_type.
@@ -147,9 +148,9 @@ class CompilationConfigNode : public Object {
Target FindPrimitiveTargetOrFail(DLDeviceType device_type) const;
/*!
- * \brief A cache of constructed SEScopes.
+ * \brief A cache of constructed virtual devices.
*/
- mutable SEScopeCache se_scope_cache_;
+ mutable VirtualDeviceCache virtual_device_cache_;
friend class CompilationConfig;
};
diff --git a/include/tvm/target/se_scope.h b/include/tvm/target/virtual_device.h
similarity index 65%
rename from include/tvm/target/se_scope.h
rename to include/tvm/target/virtual_device.h
index 314bf05..07011ea 100644
--- a/include/tvm/target/se_scope.h
+++ b/include/tvm/target/virtual_device.h
@@ -18,12 +18,13 @@
*/
/*!
- * \file tvm/target/se_scope.h
- * \brief A compile time representation for a Storage or Execution Scope.
+ * \file tvm/target/virtual_device.h
+ * \brief A compile time representation for where data is to be stored at runtime, and how to
+ * compile code to compute it.
*/
-#ifndef TVM_TARGET_SE_SCOPE_H_
-#define TVM_TARGET_SE_SCOPE_H_
+#ifndef TVM_TARGET_VIRTUAL_DEVICE_H_
+#define TVM_TARGET_VIRTUAL_DEVICE_H_
#include <tvm/ir/transform.h>
#include <tvm/target/target.h>
@@ -44,9 +45,13 @@ namespace tvm {
using MemoryScope = String;
/*!
- * \brief Describes at compile time where data is to be stored down to the device and memory
- * scope level, or where execution is to take place, down to the device level. It is a quadruple of:
- * - A \p device_type (\p DLDeviceType). May be kInvalidDeviceType if unconstrained.
+ * \brief Describes at compile time the constraints on where data is to be stored at runtime
+ * down to the (virtual) device and memory scope level, and how to compile code to compute that
+ * data. Used by the \p PlanDevices pass to collect and solve (virtual) device constraints for
+ * the whole Relay program.
+ *
+ * Is a quadruple of:
+ * - A \p device_type (\p DLDeviceType). May be \p kInvalidDeviceType if unconstrained.
* - A \p virtual_device_id (\p int). This allows us to distinguish distinct devices
* with the same \p Target, for example in a multi-GPU system. May be -1 if unconstrained.
* See "Virtual Devices" below.
@@ -60,19 +65,19 @@ using MemoryScope = String;
* choose a value consistent with the whole program. However if a \p target is given then the \p
* device_type must equal \p target->kind->device_type.
*
- * Note that currently we assume if a function returns its result on a particular device
+ * Note that currently we assume if a function returns its result on a particular (virtual) device
* then the function body is also executed on that device. See the overview comment in
* src/relay/transforms/device_planner.cc for more details.
*
* By 'data' we include both tensors and additional supporting datastructures such as shapes,
- * Relay AST items, Relay tuples, and Relay references. Typically non-tensor data must reside
- * on a 'CPU'-like device with good support for scalars.
+ * Relay ADT items (including tuples), Relay references, and Relay closures. Typically non-tensor
+ * data must reside on a 'CPU'-like host device with good support for scalars.
*
* By 'execution' we include both (fused) primitive operators, and all the Relay expressions
* surrounding them which coordinates data and control flow. Again, typically non-primitive
* operators must be executed on a 'CPU'-like device with good support for control flow.
*
- * Since TVM targets such a wide range of systems it is not possible for \p SEScope to impose
+ * Since TVM targets such a wide range of systems it is not possible for \p VirtualDevice to impose
* much semantics on these fields, particularly for \p virtual_device_id and \p memory_scope.
* Instead we assume downstream passes and codegen will interpret and validate these fields
* appropriately.
@@ -84,7 +89,7 @@ using MemoryScope = String;
* compile time) describe a physical device on the target system. Obviously the target must agree
* with the device's microarchitecture, but we otherwise don't impose any constraints between them:
* - It's ok to use different \p Targets for the same \p Device, eg to squeeze some extra perf
- * out of a particular primitive.
+ * out of a particular primitive using particular compiler flags.
* - It's ok to use the same \p Target for multiple \p Devices, eg if we have multiple CPUs.
*
* Traditionally TVM assumes at most one \p Target per \p DLDeviceType. We are moving away from that
@@ -133,14 +138,14 @@ using MemoryScope = String;
* a memory scope to only be accessible to a device when code is compiled with particular
* \p Target options.
*
- * \p SEScopes themselves have no system-level understanding. Currently device planning will
- * simply insert "device_copy" operators wherever \p SEScopes are not exactly pointwise equal.
- * We may revisit this in the future as the work on memory pools matures.
+ * \p VirtualDevices themselves have no system-level understanding. Currently the \p PlanDevices
+ * pass will simply insert "device_copy" operators wherever \p VirtualDevices are not exactly
+ * pointwise equal. We may revisit this in the future as the work on memory pools matures.
*
* Joining and Defaulting
* ----------------------
- * It is possible to 'join' two \p SEScopes to yield the most constrained \p SEScope which agrees
- * with both join arguments. Eg:
+ * It is possible to 'join' two \p VirtualDevices to yield the most constrained \p VirtualDevice
+ * which agrees with both join arguments. Eg:
* \code
* Join((kDLCPU, -1, "llvm", ""), (kInvalidDeviceType, 3, null, "global"))
* => (kDLCPU, 3, "llvm", "global")
@@ -156,9 +161,8 @@ using MemoryScope = String;
* \endcode
*
* These operations are needed during device planning.
- *
*/
-class SEScopeNode : public AttrsNode<SEScopeNode> {
+class VirtualDeviceNode : public AttrsNode<VirtualDeviceNode> {
private:
/*!
* \brief The \p DLDeviceType (represented as an int) of the virtual device. If \p target is
@@ -187,7 +191,7 @@ class SEScopeNode : public AttrsNode<SEScopeNode> {
/*!
* \brief The \p Target describing how to compile for the virtual device.
*
- * Null denotes unconstrained. Note that if a target later becomes known for this \p SEScope
+ * Null denotes unconstrained. Note that if a target later becomes known for this \p VirtualDevice
* then it must be consistent with the \p device_type if already known. This is enforced by the
* Join and Default methods.
*/
@@ -201,8 +205,8 @@ class SEScopeNode : public AttrsNode<SEScopeNode> {
MemoryScope memory_scope;
/*!
- * \brief Returns true if scope is fully unconstrained, ie no target/device type, device id
- * or memory scope is specified.
+ * \brief Returns true if virtual device is 'fully unconstrained', ie no target/device type,
+ * device id or memory scope is specified.
*/
bool IsFullyUnconstrained() const {
return !target.defined() && device_type() == kInvalidDeviceType && virtual_device_id == -1 &&
@@ -210,18 +214,18 @@ class SEScopeNode : public AttrsNode<SEScopeNode> {
}
/*!
- * \brief Returns true if scope is fully constrained, ie target, device id and memory scope are
- * all specified.
+ * \brief Returns true if virtual device is 'fully constrained', ie target, device id and memory
+ * scope are all specified.
*/
bool IsFullyConstrained() const {
return target.defined() && virtual_device_id != -1 && !memory_scope.empty();
}
/*!
- * \brief Returns the (virtual) \p Device implied by this \p SEScope. Both the \p device_type and
- * \p virtual_device_must be constrained. The returned \p Device may not correspond to any
- * physical device available at compile time or even runtime: see "Virtual vs Physical Devices"
- * above.
+ * \brief Returns the (virtual) \p Device implied by this \p VirtualDevice. Both the \p
+ * device_type and \p virtual_device_id must be constrained. The returned \p Device may not
+ * correspond to any physical device available at compile time or even runtime: see "Virtual vs
+ * Physical Devices" above.
*/
Device ToDevice() const {
ICHECK(device_type() != kInvalidDeviceType);
@@ -232,7 +236,7 @@ class SEScopeNode : public AttrsNode<SEScopeNode> {
return device;
}
- TVM_DECLARE_ATTRS(SEScopeNode, "SEScope") {
+ TVM_DECLARE_ATTRS(VirtualDeviceNode, "VirtualDevice") {
TVM_ATTR_FIELD(device_type_int)
.describe("The type of the virtual device.")
.set_default(kInvalidDeviceType);
@@ -247,74 +251,72 @@ class SEScopeNode : public AttrsNode<SEScopeNode> {
.set_default("");
}
- friend class SEScope;
+ friend class VirtualDevice;
};
/*!
- * \brief Managed reference class to \p SEScopeNode.
- *
- * \sa SEScopeNode.
+ * \brief Managed reference class to \p VirtualDeviceNode.
*/
-class SEScope : public ObjectRef {
+class VirtualDevice : public ObjectRef {
public:
/*!
- * \brief Construct an SEScope.
- * \param device_type The device type for the virtual device, or kInvalidDeviceType if
+ * \brief Construct a virtual device.
+ * \param device_type The device type for the virtual device, or \p kInvalidDeviceType if
* unconstrained. If \p target is defined then must match its \p target->kind->device_type.
* \param virtual_device_id The device id for the virtual device, or -1 if unconstrained.
* \param target The target describing how to compile for the virtual device, or null if
* unconstrained.
* \param memory_scope The memory scope w.r.t. the virtual device which holds data, or "" if
* unconstrained.
- * \return The SEScope
+ * \return The virtual device.
*/
- explicit SEScope(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1,
- Target target = {}, MemoryScope memory_scope = {});
+ explicit VirtualDevice(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1,
+ Target target = {}, MemoryScope memory_scope = {});
- /*! \brief Returns the unique fully unconstrained \p SEScope. */
- static SEScope FullyUnconstrained();
+ /*! \brief Returns the unique fully unconstrained \p VirtualDevice. */
+ static VirtualDevice FullyUnconstrained();
/*!
- * \brief Returns the \p SEScope for \p device_type and (if not -1) \p virtual_device_id.
+ * \brief Returns the \p VirtualDevice for \p device_type and (if not -1) \p virtual_device_id.
* The target and memory scope will be unconstrained.
*/
- static SEScope ForDeviceType(DLDeviceType device_type, int virtual_device_id = -1) {
+ static VirtualDevice ForDeviceType(DLDeviceType device_type, int virtual_device_id = -1) {
ICHECK_GT(device_type, 0);
- return SEScope(device_type, virtual_device_id);
+ return VirtualDevice(device_type, virtual_device_id);
}
- static SEScope ForDeviceType(int device_type, int virtual_device_id = -1) {
+ static VirtualDevice ForDeviceType(int device_type, int virtual_device_id = -1) {
return ForDeviceType(static_cast<DLDeviceType>(device_type), virtual_device_id);
}
- static SEScope ForDeviceType(const Integer& device_type, int virtual_device_id = -1) {
+ static VirtualDevice ForDeviceType(const Integer& device_type, int virtual_device_id = -1) {
return ForDeviceType(static_cast<int>(device_type->value), virtual_device_id);
}
- /*! \brief Returns the \p SEScope for \p device. */
- static SEScope ForDevice(const Device& device) {
+ /*! \brief Returns the \p VirtualDevice for \p device. */
+ static VirtualDevice ForDevice(const Device& device) {
return ForDeviceType(device.device_type, device.device_id);
}
- /*! \brief Returns the \p SEScope for \p device and \p target. */
- static SEScope ForDeviceAndTarget(const Device& device, Target target) {
- return SEScope(device.device_type, device.device_id, std::move(target));
+ /*! \brief Returns the \p VirtualDevice for \p device and \p target. */
+ static VirtualDevice ForDeviceAndTarget(const Device& device, Target target) {
+ return VirtualDevice(device.device_type, device.device_id, std::move(target));
}
- /*! \brief Returns the \p SEScope for \p target. */
- static SEScope ForTarget(Target target) {
+ /*! \brief Returns the \p VirtualDevice for \p target. */
+ static VirtualDevice ForTarget(Target target) {
DLDeviceType device_type = static_cast<DLDeviceType>(target->kind->device_type);
- return SEScope(device_type, /*virtual_device_id=*/0, std::move(target));
+ return VirtualDevice(device_type, /*virtual_device_id=*/0, std::move(target));
}
- /*! \brief Returns the \p SEScope for \p memory_scope alone. */
- static SEScope ForMemoryScope(MemoryScope memory_scope) {
- return SEScope(kInvalidDeviceType, -1, {}, std::move(memory_scope));
+ /*! \brief Returns the \p VirtualDevice for \p memory_scope alone. */
+ static VirtualDevice ForMemoryScope(MemoryScope memory_scope) {
+ return VirtualDevice(kInvalidDeviceType, -1, {}, std::move(memory_scope));
}
- /*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */
- TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target,
- MemoryScope memory_scope) {
- return SEScope(device.device_type, device.device_id, std::move(target),
- std::move(memory_scope));
+ /*! \brief Returns the \p VirtualDevice for \p device, \p target and \p memory_scope. */
+ TVM_DLL static VirtualDevice ForDeviceTargetAndMemoryScope(const Device& device, Target target,
+ MemoryScope memory_scope) {
+ return VirtualDevice(device.device_type, device.device_id, std::move(target),
+ std::move(memory_scope));
}
/*!
@@ -322,41 +324,43 @@ class SEScope : public ObjectRef {
* \p lhs and \p rhs on all their constrained fields. Returns the null optional if no such
* join exists, ie there's disagreement on at least one constrained field.
*/
- static Optional<SEScope> Join(const SEScope& lhs, const SEScope& rhs);
+ static Optional<VirtualDevice> Join(const VirtualDevice& lhs, const VirtualDevice& rhs);
/*!
* \brief Returns the 'default' of \p lhs and \p rhs. The result will be \p lhs, except any
* unconstrained fields in \p lhs will take their value from \p rhs. Always well-defined.
*/
- static SEScope Default(const SEScope& lhs, const SEScope& rhs);
+ static VirtualDevice Default(const VirtualDevice& lhs, const VirtualDevice& rhs);
- TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SEScope, ObjectRef, SEScopeNode);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VirtualDevice, ObjectRef, VirtualDeviceNode);
- friend class SEScopeCache; // Private implementation helper.
+ friend class VirtualDeviceCache; // Private implementation helper.
};
/*!
- * \brief A cache of \p SEScopes. This can be used:
- * - To avoid ending up with lots of identical instances, since the space of SEScopes for any
+ * \brief A cache of \p VirtualDevices. This can be used:
+ * - To avoid ending up with lots of identical instances, since the space of VirtualDevices for any
* one compilation is very small but the number of points they need to be constructed can
* be very large (eg during device planning).
- * - So we can assume \p SEScopes are pointer equal if and only if they are structurally equal.
- * This simplifies the unification of 'device domains' which are built on \p SEScopes.
+ * - So we can assume \p VirtualDevices are pointer equal if and only if they are structurally
+ * equal. This simplifies the unification of 'device domains' which are built on \p VirtualDevices.
*/
-class SEScopeCache {
+class VirtualDeviceCache {
public:
- /*! \brief Returns the unique \p SEScope representing given fields. */
- SEScope Make(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1,
- Target target = {}, MemoryScope memory_scope = {});
+ /*! \brief Returns the unique \p VirtualDevice representing given fields. */
+ VirtualDevice Make(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1,
+ Target target = {}, MemoryScope memory_scope = {});
- /*! \brief Returns the unique \p SEScope structurally equal to the given \p se_scope. */
- SEScope Unique(const SEScope& scope);
+ /*!
+ * \brief Returns the unique \p VirtualDevice structurally equal to the given \p virtual_device.
+ */
+ VirtualDevice Unique(const VirtualDevice& virtual_device);
private:
- /*! \brief Already constructed SEScopes. */
- std::unordered_set<SEScope, StructuralHash, StructuralEqual> cache_;
+ /*! \brief Already constructed VirtualDevices. */
+ std::unordered_set<VirtualDevice, StructuralHash, StructuralEqual> cache_;
};
} // namespace tvm
-#endif // TVM_TARGET_SE_SCOPE_H_
+#endif // TVM_TARGET_VIRTUAL_DEVICE_H_
diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py
index cb4e628..f2ce6c5 100644
--- a/python/tvm/relay/op/annotation/annotation.py
+++ b/python/tvm/relay/op/annotation/annotation.py
@@ -23,11 +23,11 @@ from . import _make
from .. import op as reg
-def _make_se_scope(device):
+def _make_virtual_device(device):
if isinstance(device, _Device):
- return target.make_se_scope(device)
+ return target.make_virtual_device(device)
if isinstance(device, str):
- return target.make_se_scope(_nd.device(device))
+ return target.make_virtual_device(_nd.device(device))
raise ValueError("expecting a Device or device name, but received a %s" % (type(device)))
@@ -59,7 +59,7 @@ def on_device(body, device, constrain_result=False, constrain_body=True):
result : tvm.relay.Expr
The annotated expression.
"""
- return _make.OnDevice(body, _make_se_scope(device), constrain_result, constrain_body)
+ return _make.OnDevice(body, _make_virtual_device(device), constrain_result, constrain_body)
def function_on_device(function, param_devices, result_device):
@@ -83,7 +83,9 @@ def function_on_device(function, param_devices, result_device):
The annotated function.
"""
return _make.FunctionOnDevice(
- function, [_make_se_scope(d) for d in param_devices], _make_se_scope(result_device)
+ function,
+ [_make_virtual_device(d) for d in param_devices],
+ _make_virtual_device(result_device),
)
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index d9847a4..20b883b 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -27,11 +27,11 @@ from ..expr import Tuple, Expr, Constant
from . import op as reg
-def _make_se_scope(device):
+def _make_virtual_device(device):
if isinstance(device, _Device):
- return target.make_se_scope(device)
+ return target.make_virtual_device(device)
if isinstance(device, str):
- return target.make_se_scope(_nd.device(device))
+ return target.make_virtual_device(_nd.device(device))
raise ValueError("expecting a Device or device name, but received a %s" % (type(device)))
@@ -1211,7 +1211,9 @@ def device_copy(data, src_device, dst_device):
result : tvm.relay.Expr
The copied result.
"""
- return _make.DeviceCopy(data, _make_se_scope(src_device), _make_se_scope(dst_device))
+ return _make.DeviceCopy(
+ data, _make_virtual_device(src_device), _make_virtual_device(dst_device)
+ )
def shape_of(data, dtype="int32"):
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 4369009..bbe4bc2 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -1164,12 +1164,12 @@ def SimplifyExpr():
def PlanDevices(config):
"""
- Uses existing "on_device" and "device_copy" CallNodes to infer the SEScope on which
+ Uses existing "on_device" and "device_copy" calls to infer the virtual device on which
every Relay sub-expression should run and the result stored. Captures the result of that
- analysis using new "on_device" and "device_copy" CallNodes. Sub-expressions which are
- not otherwise constrained are assigned to the default_primitive_se_scope. However data and
- computations which must be hosted on a CPU (such as shapes and shape functions) use the
- cpu_se_scope.
+ analysis using new "on_device" and "device_copy" calls. Sub-expressions which are
+ not otherwise constrained are assigned to the default primitive virtual device described by
+ config. However data and computations which must be hosted on a CPU (such as shapes and
+ shape functions) use the host virtual device of the config.
Parameters
----------
diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py
index 1b90a50..6c13ced 100644
--- a/python/tvm/target/__init__.py
+++ b/python/tvm/target/__init__.py
@@ -71,7 +71,7 @@ from .target import (
riscv_cpu,
hexagon,
)
-from .se_scope import make_se_scope
+from .virtual_device import make_virtual_device
from .compilation_config import make_compilation_config
from .tag import list_tags
from .generic_func import GenericFunc
diff --git a/python/tvm/target/se_scope.py b/python/tvm/target/virtual_device.py
similarity index 72%
rename from python/tvm/target/se_scope.py
rename to python/tvm/target/virtual_device.py
index 83df5ae..a88d405 100644
--- a/python/tvm/target/se_scope.py
+++ b/python/tvm/target/virtual_device.py
@@ -14,9 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Python bindings for creating SEScopes."""
+"""Python bindings for creating VirtualDevices."""
from . import _ffi_api
-def make_se_scope(device, target=None, memory_scope=""):
- return _ffi_api.SEScope_ForDeviceTargetAndMemoryScope(device, target, memory_scope)
+# TODO(mbs): We need an official Python class representation given the importance of this structure.
+
+
+def make_virtual_device(device, target=None, memory_scope=""):
+ return _ffi_api.VirtualDevice_ForDeviceTargetAndMemoryScope(device, target, memory_scope)
diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc
index d0c2cfe..fdc6c37 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/printer/relay_text_printer.cc
@@ -37,7 +37,7 @@
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
-#include <tvm/target/se_scope.h>
+#include <tvm/target/virtual_device.h>
#include <tvm/tir/function.h>
#include "../ir/attr_functor.h"
@@ -906,14 +906,14 @@ Doc RelayTextPrinter::PrintAttributeValue(const ObjectRef& value, bool force_met
printed_attr << Doc::StrLiteral(GetRef<String>(str_obj));
} else if (force_meta) {
printed_attr = meta_->GetMetaNode(Downcast<ObjectRef>(value));
- } else if (const auto* se_scope_node = value.as<SEScopeNode>()) {
+ } else if (const auto* virtual_device_node = value.as<VirtualDeviceNode>()) {
if (show_meta_data_) {
- printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(se_scope_node));
+ printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(virtual_device_node));
} else {
- // Special case: The ReprPrinter for SEScopeNodes is much easier to work with while
+ // Special case: The ReprPrinter for VirtualDeviceNodes is much easier to work with while
// debugging.
std::ostringstream os;
- os << GetRef<SEScope>(se_scope_node);
+ os << GetRef<VirtualDevice>(virtual_device_node);
return Doc::Text(os.str());
}
} else if (const auto* base_attr_node = value.as<BaseAttrsNode>()) {
diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index 9ea1e42..d901f8a 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -135,18 +135,19 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
void VisitExpr_(const TupleNode* op) final {
std::vector<int64_t> storage_ids;
- std::vector<SEScope> se_scopes;
+ std::vector<VirtualDevice> virtual_devices;
std::vector<int64_t> storage_sizes_in_bytes;
Expr expr = GetRef<Expr>(op);
for (Expr field : op->fields) {
auto sid = GetStorage(field);
storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end());
- se_scopes.insert(se_scopes.end(), sid->se_scopes.begin(), sid->se_scopes.end());
+ virtual_devices.insert(virtual_devices.end(), sid->virtual_devices.begin(),
+ sid->virtual_devices.end());
storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(),
sid->storage_sizes_in_bytes.begin(),
sid->storage_sizes_in_bytes.end());
}
- storage_device_map_[expr] = StorageInfo(storage_ids, se_scopes, storage_sizes_in_bytes);
+ storage_device_map_[expr] = StorageInfo(storage_ids, virtual_devices, storage_sizes_in_bytes);
AssignReturnSid(expr);
}
@@ -155,7 +156,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
auto sids = GetStorage(op->tuple);
ICHECK_LT(static_cast<size_t>(op->index), sids->storage_ids.size());
storage_device_map_[expr] =
- StorageInfo({sids->storage_ids[op->index]}, {sids->se_scopes[op->index]},
+ StorageInfo({sids->storage_ids[op->index]}, {sids->virtual_devices[op->index]},
{sids->storage_sizes_in_bytes[op->index]});
AssignReturnSid(expr);
}
@@ -221,24 +222,25 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
*/
void CreateStorage(const ExprNode* op) {
Expr expr = GetRef<Expr>(op);
- return CreateStorage(expr, GetSEScope(expr));
+ return CreateStorage(expr, GetVirtualDevice(expr));
}
/*!
- * \brief Create storage to hold the result of evaluating \p expr in \p se_scope.
+ * \brief Create storage to hold the result of evaluating \p expr in \p virtual_device.
*/
- void CreateStorage(const Expr& expr, const SEScope& se_scope) {
- ICHECK(!se_scope->IsFullyUnconstrained()) << "invalid SEScope for expr:" << std::endl
- << PrettyPrint(expr);
+ void CreateStorage(const Expr& expr, const VirtualDevice& virtual_device) {
+ ICHECK(!virtual_device->IsFullyUnconstrained())
+ << "invalid virtual device for expr:" << std::endl
+ << PrettyPrint(expr);
std::vector<int64_t> storage_ids;
- std::vector<SEScope> se_scopes;
+ std::vector<VirtualDevice> virtual_devices;
std::vector<int64_t> storage_sizes_in_bytes;
for (const auto& ttype : FlattenTupleType(expr->checked_type())) {
storage_ids.push_back(next_available_sid_++);
- se_scopes.push_back(se_scope);
+ virtual_devices.push_back(virtual_device);
storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype));
}
- storage_device_map_[expr] = StorageInfo(std::move(storage_ids), std::move(se_scopes),
+ storage_device_map_[expr] = StorageInfo(std::move(storage_ids), std::move(virtual_devices),
std::move(storage_sizes_in_bytes));
}
@@ -736,7 +738,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
use_unpacked_api_ = executor_config->GetAttr<Bool>("unpacked-api").value_or(Bool(false));
// TODO(mbs): Plumb from compiler config
- SEScope host_se_scope = SEScope::ForTarget(target_host_);
+ VirtualDevice host_virtual_device = VirtualDevice::ForTarget(target_host_);
IRModule lowered_mod = tec::LowerTEPass(
mod_name,
@@ -753,7 +755,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment);
},
- host_se_scope)(mod);
+ host_virtual_device)(mod);
auto lowered_main = lowered_mod->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index ab86dbf..ccfd304 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -424,7 +424,7 @@ class RelayBuildModule : public runtime::ModuleNode {
lowered_funcs.Set(ext_dev, IRModule());
}
- const Target& host_target = config_->host_se_scope->target;
+ const Target& host_target = config_->host_virtual_device->target;
const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate");
// Generate a placeholder function that attaches linked params as its arguments.
diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc
index 16b1ddb..f61fe9b 100644
--- a/src/relay/backend/graph_executor_codegen.cc
+++ b/src/relay/backend/graph_executor_codegen.cc
@@ -245,7 +245,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_);
},
- config->host_se_scope)(mod);
+ config->host_virtual_device)(mod);
Optional<backend::FunctionInfo> main_func_info =
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
@@ -328,10 +328,10 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
node->attrs_["storage_id"] = std::move(storage_ids);
// type
std::vector<int64_t> device_types;
- for (const auto& se_scope : storage_info->se_scopes) {
+ for (const auto& virtual_device : storage_info->virtual_devices) {
// TODO(mbs): Keeping only the device type.
- ICHECK_GT(se_scope->device_type(), 0);
- device_types.push_back(se_scope->device_type());
+ ICHECK_GT(virtual_device->device_type(), 0);
+ device_types.push_back(virtual_device->device_type());
}
size_t num_unknown_devices = std::count(device_types.begin(), device_types.end(), 0);
if (num_unknown_devices != 0 && num_unknown_devices != device_types.size()) {
diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc
index 3ee3187..2ad27a0 100644
--- a/src/relay/backend/graph_plan_memory.cc
+++ b/src/relay/backend/graph_plan_memory.cc
@@ -53,19 +53,21 @@ struct StorageToken {
size_t max_bytes{0};
/*! \brief The corresponding tensor type. */
TensorType ttype{nullptr};
- /*! \brief SEScope on which the memory will reside. */
- SEScope se_scope = SEScope::FullyUnconstrained();
+ /*! \brief VirtualDevice on which the memory will reside. */
+ VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained();
/*! \brief The storage id */
int64_t storage_id{-1};
- bool is_valid() const { return !se_scope->IsFullyUnconstrained(); }
+ bool is_valid() const { return !virtual_device->IsFullyUnconstrained(); }
- bool is_compatible(const StorageToken& that) const { return se_scope == that.se_scope; }
+ bool is_compatible(const StorageToken& that) const {
+ return virtual_device == that.virtual_device;
+ }
std::string ToString() const {
std::ostringstream os;
os << "{storage_id: " << storage_id << ", max_bytes: " << max_bytes
- << ", ttype: " << PrettyPrint(ttype) << ", se_scope: " << se_scope << "}";
+ << ", ttype: " << PrettyPrint(ttype) << ", virtual_device: " << virtual_device << "}";
return os.str();
}
};
@@ -167,14 +169,14 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor {
* the result of evaluating \p op.
*/
void CreateToken(const ExprNode* expr_node, bool can_realloc) {
- return CreateTokenOnDevice(expr_node, GetSEScope(GetRef<Expr>(expr_node)), can_realloc);
+ return CreateTokenOnDevice(expr_node, GetVirtualDevice(GetRef<Expr>(expr_node)), can_realloc);
}
/*!
* \brief Allocates (or reuses if \p can_realloc is true) a storage token for holding
* the result of evaluating \p op on \p device_type.
*/
- virtual void CreateTokenOnDevice(const ExprNode* op, const SEScope& se_scope,
+ virtual void CreateTokenOnDevice(const ExprNode* op, const VirtualDevice& virtual_device,
bool can_realloc) = 0;
};
@@ -193,13 +195,14 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor {
protected:
using StorageAllocaBaseVisitor::VisitExpr_;
- void CreateTokenOnDevice(const ExprNode* op, const SEScope& se_scope, bool can_realloc) override {
+ void CreateTokenOnDevice(const ExprNode* op, const VirtualDevice& virtual_device,
+ bool can_realloc) override {
ICHECK(!token_map_.count(op));
std::vector<StorageToken*> tokens;
for (const auto& ttype : FlattenTupleType(op->checked_type())) {
auto* token = arena_->make<StorageToken>();
token->ttype = ttype;
- token->se_scope = se_scope;
+ token->virtual_device = virtual_device;
tokens.push_back(token);
}
token_map_[op] = tokens;
@@ -256,8 +259,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
for (const auto& kv : token_map_) {
std::vector<int64_t> storage_ids;
storage_ids.reserve(kv.second.size());
- std::vector<SEScope> se_scopes;
- se_scopes.reserve(kv.second.size());
+ std::vector<VirtualDevice> virtual_devices;
+ virtual_devices.reserve(kv.second.size());
std::vector<int64_t> sid_sizes_byte;
sid_sizes_byte.reserve(kv.second.size());
@@ -268,10 +271,10 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
}
num_nodes++;
storage_ids.push_back(tok->storage_id);
- se_scopes.push_back(tok->se_scope);
+ virtual_devices.push_back(tok->virtual_device);
sid_sizes_byte.push_back(GetMemorySize(tok));
}
- auto storage_info = backend::StorageInfo(std::move(storage_ids), std::move(se_scopes),
+ auto storage_info = backend::StorageInfo(std::move(storage_ids), std::move(virtual_devices),
std::move(sid_sizes_byte));
smap.Set(GetRef<Expr>(kv.first), storage_info);
}
@@ -286,20 +289,21 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
protected:
// override create token by getting token as prototype requirements.
- void CreateTokenOnDevice(const ExprNode* op, const SEScope& se_scope, bool can_realloc) final {
+ void CreateTokenOnDevice(const ExprNode* op, const VirtualDevice& virtual_device,
+ bool can_realloc) final {
ICHECK(!token_map_.count(op));
auto it = prototype_.find(op);
ICHECK(it != prototype_.end());
std::vector<StorageToken*> tokens;
for (StorageToken* tok : it->second) {
- ICHECK(tok->se_scope == se_scope);
+ ICHECK(tok->virtual_device == virtual_device);
if (can_realloc) {
tokens.push_back(Request(tok));
} else {
// Allocate a new token,
StorageToken* allocated_tok = Alloc(tok, GetMemorySize(tok));
- allocated_tok->se_scope = tok->se_scope;
+ allocated_tok->virtual_device = tok->virtual_device;
// ensure it never get de-allocated.
allocated_tok->ref_counter += 1;
tokens.push_back(allocated_tok);
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 82a0455..2bea810 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -474,7 +474,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
// whether the shape and or data needs to be passed, and flattening of tuples.
// Similarly, num_shape_outputs will account for flattening of tuples.
- // TODO(mbs): Take this from the host_se_scope.
+ // TODO(mbs): Take this from the host_virtual_device.
Device shape_device;
shape_device.device_type = static_cast<DLDeviceType>(prim_shape_target->kind->device_type);
shape_device.device_id = 0;
@@ -754,7 +754,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
return InvokePrimitiveOp(call_lowered_props.lowered_func, all_prim_fn_vars,
config_->optional_homogeneous_target, prim_shape_fn_var,
all_prim_shape_fn_vars, prim_shape_fn_states, num_shape_inputs,
- num_shape_outputs, config_->host_se_scope->target, args);
+ num_shape_outputs, config_->host_virtual_device->target, args);
} else { // All other calls
// Evaluate all arguments
std::vector<ObjectRef> args;
@@ -945,7 +945,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
* functions needed by the rewritten module.
*/
IRModule Prepare(IRModule mod, CompilationConfig config) {
- SEScope host_se_scope = config->host_se_scope;
+ VirtualDevice host_virtual_device = config->host_virtual_device;
// Run minimal transforms on module to establish invariants needed by interpreter.
transform::Sequential seq(
{transform::SimplifyInference(),
@@ -962,7 +962,7 @@ IRModule Prepare(IRModule mod, CompilationConfig config) {
/*expand_constructor=*/true, /*expand_global_var=*/false),
transform::InferType(),
tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ },
- std::move(host_se_scope))});
+ std::move(host_virtual_device))});
transform::PassContext pass_ctx = transform::PassContext::Current();
With<transform::PassContext> ctx(pass_ctx);
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 528df64..901661d 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -503,13 +503,13 @@ using AnalysisRemapping = std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual
class LowerTensorExprMutator : public DeviceAwareExprMutator {
public:
LowerTensorExprMutator(const IRModule& module, ProcessFn process_fn, String module_name,
- TECompiler compiler, SEScope host_se_scope)
+ TECompiler compiler, VirtualDevice host_virtual_device)
: DeviceAwareExprMutator(module),
module_(module),
process_fn_(std::move(process_fn)),
module_name_(std::move(module_name)),
compiler_(std::move(compiler)),
- host_se_scope_(std::move(host_se_scope)),
+ host_virtual_device_(std::move(host_virtual_device)),
debug_op_(Op::Get("debug")) {}
/*!
@@ -609,7 +609,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
// Shape function keys use the underlying primitive function as their 'function',
// but the generic 'cpu' target as the target since all shape functions run
// on the host cpu irrespective of where the primitive runs.
- CCacheKey shape_key(func, host_se_scope_->target);
+ CCacheKey shape_key(func, host_virtual_device_->target);
CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
// Capture the shape function's global var and parameters 'states' in call
@@ -707,8 +707,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
DeviceCopyProps device_copy_props = GetDeviceCopyProps(function_node->body);
if (device_copy_props.body.defined()) {
ICHECK_EQ(new_args.size(), 1);
- return DeviceCopy(new_args[0], device_copy_props.src_se_scope,
- device_copy_props.dst_se_scope);
+ return DeviceCopy(new_args[0], device_copy_props.src_virtual_device,
+ device_copy_props.dst_virtual_device);
}
}
@@ -746,9 +746,9 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
target = Target("ext_dev");
} else {
// The target corresponding to the call_node expression's annotation.
- SEScope se_scope = GetSEScope(GetRef<Call>(call_node));
- ICHECK(!se_scope->IsFullyUnconstrained());
- target = se_scope->target;
+ VirtualDevice virtual_device = GetVirtualDevice(GetRef<Call>(call_node));
+ ICHECK(!virtual_device->IsFullyUnconstrained());
+ target = virtual_device->target;
ICHECK(target.defined());
}
@@ -769,10 +769,10 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
String module_name_;
TECompiler compiler_;
/*!
- * \brief The \p SEScope for the host, which is where all shape-related data and computation
+ * \brief The \p VirtualDevice for the host, which is where all shape-related data and computation
* must live.
*/
- SEScope host_se_scope_;
+ VirtualDevice host_virtual_device_;
// Cache ops that need to be frequently used later to reduce lookup overhead.
const Op& debug_op_;
};
@@ -808,10 +808,11 @@ Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets) {
}
Pass LowerTensorExpr(const String& module_name, TECompiler compiler, ProcessFn process_fn,
- SEScope host_se_scope) {
+ VirtualDevice host_virtual_device) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function func, IRModule module, PassContext ctx) {
- LowerTensorExprMutator lower_te(module, process_fn, module_name, compiler, host_se_scope);
+ LowerTensorExprMutator lower_te(module, process_fn, module_name, compiler,
+ host_virtual_device);
return Downcast<Function>(lower_te.Mutate(func));
};
return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
@@ -828,7 +829,7 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa
}
// This is a Map<device,Map<storage_id, size>>
- // TODO(mbs): Collapsing SEScopes to just device type.
+ // TODO(mbs): Collapsing VirtualDevices to just device type.
std::unordered_map<DLDeviceType, std::unordered_map<int, int>, backend::EnumClassHash>
sid_workspace;
// This is a Map<device, size_of_inputs_and_outputs>
@@ -841,10 +842,10 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa
for (const auto& kv : storage_info_map) {
const backend::StorageInfo& storage_info = kv.second;
const std::vector<int64_t>& storage_ids = storage_info->storage_ids;
- const std::vector<SEScope>& se_scopes = storage_info->se_scopes;
- CHECK_EQ(storage_ids.size(), se_scopes.size());
- for (uint32_t i = 0; i < se_scopes.size(); i++) {
- DLDeviceType device_type = se_scopes[i]->device_type();
+ const std::vector<VirtualDevice>& virtual_devices = storage_info->virtual_devices;
+ CHECK_EQ(storage_ids.size(), virtual_devices.size());
+ for (uint32_t i = 0; i < virtual_devices.size(); i++) {
+ DLDeviceType device_type = virtual_devices[i]->device_type();
sid_workspace[device_type][storage_ids[i]] = 0;
device_io[device_type] = 0;
device_consts[device_type] = 0;
@@ -877,18 +878,18 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa
<< "has size " << size_bytes << " and storage info:" << std::endl
<< storage_info;
const std::vector<int64_t>& storage_ids = storage_info->storage_ids;
- const std::vector<SEScope>& se_scopes = storage_info->se_scopes;
+ const std::vector<VirtualDevice>& virtual_devices = storage_info->virtual_devices;
if (expr->IsInstance<ConstantNode>()) {
- for (const auto& se_scope : se_scopes) {
- DLDeviceType device_type = se_scope->device_type();
+ for (const auto& virtual_device : virtual_devices) {
+ DLDeviceType device_type = virtual_device->device_type();
ICHECK_EQ(device_consts.count(device_type), 1);
device_consts[device_type] += size_bytes;
}
} else if (expr->IsInstance<VarNode>() || expr.same_as(func->body)) {
- CHECK_GE(se_scopes.size(), 1) << "must be at least one device";
- for (const auto& se_scope : se_scopes) {
- DLDeviceType device_type = se_scope->device_type();
+ CHECK_GE(virtual_devices.size(), 1) << "must be at least one device";
+ for (const auto& virtual_device : virtual_devices) {
+ DLDeviceType device_type = virtual_device->device_type();
device_io[device_type] += size_bytes;
}
} else {
@@ -899,7 +900,7 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa
// Here we record the largest size of the tensor
// that share the same storage id, because storage_id will
// be shared between multiple tensors that are not live simultaneously.
- DLDeviceType device_type = se_scopes[i]->device_type();
+ DLDeviceType device_type = virtual_devices[i]->device_type();
if (size_bytes > sid_workspace[device_type][storage_ids[i]]) {
sid_workspace[device_type][storage_ids[i]] = size_bytes;
}
@@ -1045,7 +1046,7 @@ void UpdateFunctionMetadata(BaseFunc func,
}
IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn,
- SEScope host_se_scope) {
+ VirtualDevice host_virtual_device) {
TECompiler compiler(module);
// TODO(mbs): This is all unnecessarily convoluted. Better would be to accumulate the rewritten
@@ -1061,7 +1062,7 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr
// - Calls to functions tagged with "Primitive" are compiled to PrimFuncs, and calls updated
// (using call_lowered convention).
IRModule updated_module = LowerTensorExpr(module_name, compiler, std::move(process_fn),
- std::move(host_se_scope))(module);
+ std::move(host_virtual_device))(module);
// The Functions tagged with "Compiler" are now residing in the cache ready to be
// compiled by LowerExternalFunctions. However we still need a record of them in the
@@ -1161,10 +1162,11 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
return per_target_modules;
}
-Pass LowerTEPass(const String& module_name, ProcessFn process_fn, SEScope host_se_scope) {
+Pass LowerTEPass(const String& module_name, ProcessFn process_fn,
+ VirtualDevice host_virtual_device) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
PassContext ctx) {
- return LowerTE(module, module_name, process_fn, host_se_scope);
+ return LowerTE(module, module_name, process_fn, host_virtual_device);
};
return tvm::transform::Sequential(
diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h
index 60dd5fe..b6f2218 100644
--- a/src/relay/backend/te_compiler.h
+++ b/src/relay/backend/te_compiler.h
@@ -214,10 +214,11 @@ IRModule LowerTE(
* \param module_name The name of this module
* \param process_fn Callback allowing one-level up code generators to process
* each function that we lower
- * \param host_se_scope \p SEScope for host data and computations
+ * \param host_virtual_device \p VirtualDevice for host data and computations
* \returns The pass which lowers primative functions to TIR
*/
-transform::Pass LowerTEPass(const String& module_name, ProcessFn process_fn, SEScope host_se_scope);
+transform::Pass LowerTEPass(const String& module_name, ProcessFn process_fn,
+ VirtualDevice host_virtual_device);
} // namespace tec
} // namespace relay
} // namespace tvm
diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc
index 252c43f9..608d4cd 100644
--- a/src/relay/backend/utils.cc
+++ b/src/relay/backend/utils.cc
@@ -43,9 +43,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
for (auto id : node->storage_ids) {
p->stream << id << ",";
}
- p->stream << "], se_scopes=[";
- for (const auto& se_scope : node->se_scopes) {
- p->stream << se_scope << ",";
+ p->stream << "], virtual_devices=[";
+ for (const auto& virtual_device : node->virtual_devices) {
+ p->stream << virtual_device << ",";
}
p->stream << "], storage_size_in_bytes=[";
for (auto bytes : node->storage_sizes_in_bytes) {
@@ -54,13 +54,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "])";
});
-StorageInfo::StorageInfo(std::vector<int64_t> storage_ids, std::vector<SEScope> se_scopes,
+StorageInfo::StorageInfo(std::vector<int64_t> storage_ids,
+ std::vector<VirtualDevice> virtual_devices,
std::vector<int64_t> storage_sizes_in_bytes) {
- ICHECK_EQ(storage_ids.size(), se_scopes.size());
+ ICHECK_EQ(storage_ids.size(), virtual_devices.size());
ICHECK_EQ(storage_ids.size(), storage_sizes_in_bytes.size());
auto node = make_object<StorageInfoNode>();
node->storage_ids = std::move(storage_ids);
- node->se_scopes = std::move(se_scopes);
+ node->virtual_devices = std::move(virtual_devices);
node->storage_sizes_in_bytes = std::move(storage_sizes_in_bytes);
data_ = std::move(node);
}
@@ -74,17 +75,18 @@ TVM_REGISTER_GLOBAL("relay.ir.StorageInfo")
for (auto s : sids) {
sids_v.push_back(s);
}
- std::vector<SEScope> se_scopes_v;
- se_scopes_v.reserve(device_types.size());
+ std::vector<VirtualDevice> virtual_devices_v;
+ virtual_devices_v.reserve(device_types.size());
for (const auto& device_type : device_types) {
- se_scopes_v.emplace_back(SEScope::ForDeviceType(device_type));
+ virtual_devices_v.emplace_back(VirtualDevice::ForDeviceType(device_type));
}
std::vector<int64_t> size_in_bytes_v;
size_in_bytes_v.reserve(sizes_in_bytes.size());
for (auto s : sizes_in_bytes) {
size_in_bytes_v.push_back(s);
}
- return StorageInfo(std::move(sids_v), std::move(se_scopes_v), std::move(size_in_bytes_v));
+ return StorageInfo(std::move(sids_v), std::move(virtual_devices_v),
+ std::move(size_in_bytes_v));
});
TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageIds").set_body_typed([](StorageInfo si) {
@@ -98,8 +100,8 @@ TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageIds").set_body_typed([](StorageI
// This is the legacy interface for devices as DLDeviceTypes (represented by integers)
TVM_REGISTER_GLOBAL("relay.ir.StorageInfoDeviceTypes").set_body_typed([](StorageInfo si) {
Array<tvm::Integer> device_types;
- for (const auto& se_scope : si->se_scopes) {
- device_types.push_back(se_scope->device_type());
+ for (const auto& virtual_device : si->virtual_devices) {
+ device_types.push_back(virtual_device->device_type());
}
return device_types;
});
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index 64f7c65..df25a86 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -31,7 +31,7 @@
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/target/codegen.h>
-#include <tvm/target/se_scope.h>
+#include <tvm/target/virtual_device.h>
#include <tvm/te/operation.h>
#include <string>
@@ -62,8 +62,8 @@ class StorageInfoNode : public Object {
// TODO(mbs): Switch from struct-of-array to array-of-struct repr throughout.
/*! \brief The set of storage ids where the expression is stored. */
std::vector<int64_t> storage_ids;
- /* \brief The SEScopes these expressions are stored within. */
- std::vector<SEScope> se_scopes;
+ /* \brief The virtual devices these expressions are stored within. */
+ std::vector<VirtualDevice> virtual_devices;
/* \brief The sizes of each storage element, in bytes. */
std::vector<int64_t> storage_sizes_in_bytes;
@@ -77,7 +77,7 @@ class StorageInfoNode : public Object {
/*! \brief The storage information for a single expression. */
class StorageInfo : public ObjectRef {
public:
- StorageInfo(std::vector<int64_t> storage_ids, std::vector<SEScope> se_scopes,
+ StorageInfo(std::vector<int64_t> storage_ids, std::vector<VirtualDevice> virtual_devices,
std::vector<int64_t> storage_sizes_in_bytes);
TVM_DEFINE_OBJECT_REF_METHODS(StorageInfo, ObjectRef, StorageInfoNode);
};
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 23aee45..73f4b67 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -235,12 +235,12 @@ std::vector<int64_t> ToAllocTensorShape(NDArray shape) {
class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
public:
- VMFunctionCompiler(VMCompilerContext* context, SEScope host_se_scope)
+ VMFunctionCompiler(VMCompilerContext* context, VirtualDevice host_virtual_device)
: DeviceAwareExprFunctor(context->module),
last_register_(0),
registers_num_(0),
context_(context),
- host_se_scope_(std::move(host_se_scope)) {}
+ host_virtual_device_(std::move(host_virtual_device)) {}
VMFunction Compile(const GlobalVar& var, const Function& func) {
std::vector<Index> param_device_indexes;
@@ -252,21 +252,21 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
// Do that flattening on-the-fly here.
Function inner_func = Downcast<Function>(func->body);
std::vector<Var> params;
- std::vector<SEScope> param_se_scopes;
+ std::vector<VirtualDevice> param_virtual_devices;
params.reserve(func->params.size() + inner_func->params.size());
- param_se_scopes.reserve(func->params.size() + inner_func->params.size());
+ param_virtual_devices.reserve(func->params.size() + inner_func->params.size());
param_device_indexes.reserve(func->params.size() + inner_func->params.size());
for (size_t i = 0; i < func->params.size(); ++i) {
params.emplace_back(func->params[i]);
- SEScope param_se_scope = GetFunctionParamSEScope(func.get(), i);
- param_se_scopes.push_back(param_se_scope);
- param_device_indexes.push_back(GetDeviceIndex(param_se_scope));
+ VirtualDevice param_virtual_device = GetFunctionParamVirtualDevice(func.get(), i);
+ param_virtual_devices.push_back(param_virtual_device);
+ param_device_indexes.push_back(GetDeviceIndex(param_virtual_device));
}
for (size_t i = 0; i < inner_func->params.size(); ++i) {
params.emplace_back(inner_func->params[i]);
- SEScope param_se_scope = GetFunctionParamSEScope(inner_func.get(), i);
- param_se_scopes.push_back(param_se_scope);
- param_device_indexes.push_back(GetDeviceIndex(param_se_scope));
+ VirtualDevice param_virtual_device = GetFunctionParamVirtualDevice(inner_func.get(), i);
+ param_virtual_devices.push_back(param_virtual_device);
+ param_device_indexes.push_back(GetDeviceIndex(param_virtual_device));
}
std::vector<TypeVar> type_params;
type_params.reserve(func->type_params.size() + inner_func->type_params.size());
@@ -278,12 +278,13 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
}
Function flattened_func = Function(params, inner_func->body, inner_func->ret_type,
type_params, func->attrs, func->span);
- VisitExpr(MaybeFunctionOnDevice(flattened_func, param_se_scopes,
- GetFunctionResultSEScope(inner_func.get())));
+ VisitExpr(MaybeFunctionOnDevice(flattened_func, param_virtual_devices,
+ GetFunctionResultVirtualDevice(inner_func.get())));
} else {
param_device_indexes.reserve(func->params.size());
for (size_t i = 0; i < func->params.size(); ++i) {
- param_device_indexes.push_back(GetDeviceIndex(GetFunctionParamSEScope(func.get(), i)));
+ param_device_indexes.push_back(
+ GetDeviceIndex(GetFunctionParamVirtualDevice(func.get(), i)));
}
VisitExpr(func);
}
@@ -333,42 +334,44 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
}
/*!
- * \brief Returns the "device index" to represent \p se_scope for primitives
+ * \brief Returns the "device index" to represent \p virtual_device for primitives
* in emitted code. Note that the host device is always at index 0.
*/
- Index GetDeviceIndex(const SEScope& se_scope) {
- ICHECK(!se_scope->IsFullyUnconstrained());
- auto itr = std::find(context_->se_scopes_.begin(), context_->se_scopes_.end(), se_scope);
- if (itr != context_->se_scopes_.end()) {
- return std::distance(context_->se_scopes_.begin(), itr);
+ Index GetDeviceIndex(const VirtualDevice& virtual_device) {
+ ICHECK(!virtual_device->IsFullyUnconstrained());
+ auto itr = std::find(context_->virtual_devices_.begin(), context_->virtual_devices_.end(),
+ virtual_device);
+ if (itr != context_->virtual_devices_.end()) {
+ return std::distance(context_->virtual_devices_.begin(), itr);
}
- ICHECK_GT(context_->se_scopes_.size(), 0);
- ICHECK_NE(se_scope, host_se_scope_); // the host scope is always at index 0
+ ICHECK_GT(context_->virtual_devices_.size(), 0);
+ ICHECK_NE(virtual_device, host_virtual_device_); // the host scope is always at index 0
- if (se_scope->device_type() == context_->se_scopes_.front()->device_type()) {
+ if (virtual_device->device_type() == context_->virtual_devices_.front()->device_type()) {
// It's ok if we see distinct scopes which share the host device type. This is because
- // we allow the SEScope for the host to be different from the SEScope for primitive
- // operations which both happen to be on the same device (typically CPU).
+ // we allow the VirtualDevice for the host to be different from the VirtualDevice for
+ // primitive operations which both happen to be on the same device (typically CPU).
return 0;
}
- // However, otherwise we allow at most one SEScope per device type.
+ // However, otherwise we allow at most one VirtualDevice per device type.
// TODO(mbs): This will eventually need to account for memory scopes somehow so device_copy
// instructions can do the right thing.
- itr = std::find_if(context_->se_scopes_.begin() + 1, context_->se_scopes_.end(),
- [&se_scope](const SEScope& existing_se_scope) {
- return existing_se_scope->device_type() == se_scope->device_type();
+ itr = std::find_if(context_->virtual_devices_.begin() + 1, context_->virtual_devices_.end(),
+ [&virtual_device](const VirtualDevice& existing_virtual_device) {
+ return existing_virtual_device->device_type() ==
+ virtual_device->device_type();
});
- CHECK(itr == context_->se_scopes_.end())
+ CHECK(itr == context_->virtual_devices_.end())
<< "The VM does not currently support using more than one device with the same device type "
"for primitives, however the program is using the distinct scopes "
- << se_scope << " and " << *itr << " of device type " << se_scope->device_type();
+ << virtual_device << " and " << *itr << " of device type " << virtual_device->device_type();
- ICHECK(se_scope != host_se_scope_);
- Index index = context_->se_scopes_.size();
- VLOG(2) << "se_scope[" << index << "] = " << se_scope;
- context_->se_scopes_.push_back(se_scope);
+ ICHECK(virtual_device != host_virtual_device_);
+ Index index = context_->virtual_devices_.size();
+ VLOG(2) << "virtual_device[" << index << "] = " << virtual_device;
+ context_->virtual_devices_.push_back(virtual_device);
return index;
}
@@ -380,7 +383,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
NDArray data = const_node->data;
size_t const_index = context_->constants.size();
auto con = GetRef<Constant>(const_node);
- Index device_index = GetDeviceIndex(GetSEScope(con));
+ Index device_index = GetDeviceIndex(GetVirtualDevice(con));
VLOG(2) << "constant[" << const_index << "] on device[" << device_index << "]";
context_->const_device_indexes.push_back(device_index);
context_->constants.push_back(const_node->data);
@@ -542,8 +545,8 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
// TODO(mbs): device_copy cleanup.
VisitExpr(device_copy_props.body);
RegName src_reg = last_register_;
- Index src_index = GetDeviceIndex(device_copy_props.src_se_scope);
- Index dst_index = GetDeviceIndex(device_copy_props.dst_se_scope);
+ Index src_index = GetDeviceIndex(device_copy_props.src_virtual_device);
+ Index dst_index = GetDeviceIndex(device_copy_props.dst_virtual_device);
// Since scopes distinguish by targets (including any target hosts) but at runtime we
// deal only with devices, the copy may be unnecessary.
if (src_index != dst_index) {
@@ -619,7 +622,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
auto dtype = alloc_attrs->dtype;
Emit(Instruction::AllocStorage(size_register, alignment, dtype,
- GetDeviceIndex(alloc_attrs->se_scope),
+ GetDeviceIndex(alloc_attrs->virtual_device),
NewRegister()));
})
.Match("vm.shape_of",
@@ -819,8 +822,8 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
size_t registers_num_;
/*! \brief Global shared meta data */
VMCompilerContext* context_;
- /*! \brief SEScope for data and computation which must reside on a CPU. */
- SEScope host_se_scope_;
+ /*! \brief VirtualDevice for data and computation which must reside on a CPU. */
+ VirtualDevice host_virtual_device_;
};
PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
@@ -873,9 +876,9 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host)
config_ = CompilationConfig(PassContext::Current(), std::move(targets), std::move(target_host));
// The first device is always for the host.
- CHECK(context_.se_scopes_.empty());
- VLOG(2) << "se_scope[0] = " << config_->host_se_scope << " (host)";
- context_.se_scopes_.push_back(config_->host_se_scope);
+ CHECK(context_.virtual_devices_.empty());
+ VLOG(2) << "virtual_device[0] = " << config_->host_virtual_device << " (host)";
+ context_.virtual_devices_.push_back(config_->host_virtual_device);
// Run the optimizations necessary to target the VM.
context_.module = OptimizeModuleImpl(std::move(mod));
@@ -896,7 +899,7 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host)
continue;
}
auto func = GetRef<Function>(n);
- VMFunctionCompiler func_compiler(&context_, config_->host_se_scope);
+ VMFunctionCompiler func_compiler(&context_, config_->host_virtual_device);
auto vm_func = func_compiler.Compile(gvar, func);
size_t func_index = context_.global_map.at(gvar);
@@ -911,12 +914,12 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host)
}
// Populate virtual devices and the host device index.
- for (const auto& se_scope : context_.se_scopes_) {
- ICHECK(!se_scope->IsFullyUnconstrained());
- ICHECK_GT(se_scope->device_type(), 0);
+ for (const auto& virtual_device : context_.virtual_devices_) {
+ ICHECK(!virtual_device->IsFullyUnconstrained());
+ ICHECK_GT(virtual_device->device_type(), 0);
// TODO(mbs): We forget the memory scope.
- exec_->virtual_devices.push_back(
- Device{/*device_type=*/se_scope->device_type(), /*device_id=*/se_scope->virtual_device_id});
+ exec_->virtual_devices.push_back(Device{/*device_type=*/virtual_device->device_type(),
+ /*device_id=*/virtual_device->virtual_device_id});
}
exec_->host_device_index = kHostDeviceIndex;
@@ -952,25 +955,25 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host)
}
}
-transform::Sequential VMCompiler::MemoryOpt(const SEScope& host_se_scope) {
+transform::Sequential VMCompiler::MemoryOpt(const VirtualDevice& host_virtual_device) {
Array<Pass> pass_seqs;
// Remove unused functions
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
// Manifest the allocations.
- pass_seqs.push_back(transform::ManifestAlloc(host_se_scope));
+ pass_seqs.push_back(transform::ManifestAlloc(host_virtual_device));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
// Fuse & lower any new shape functions and device_copies.
- pass_seqs.push_back(FuseAndLowerOperators(host_se_scope));
+ pass_seqs.push_back(FuseAndLowerOperators(host_virtual_device));
// Manifest the allocations needed for the shape functions.
- pass_seqs.push_back(transform::ManifestAlloc(host_se_scope));
+ pass_seqs.push_back(transform::ManifestAlloc(host_virtual_device));
// Fuse & lower any new allocations.
- pass_seqs.push_back(FuseAndLowerOperators(host_se_scope));
+ pass_seqs.push_back(FuseAndLowerOperators(host_virtual_device));
// TODO(mbrookhart, jroesch, masahi): this pass is very slow, and is
// incomplete to provide memory resuse optimizations. Disable it until we can
@@ -982,10 +985,10 @@ transform::Sequential VMCompiler::MemoryOpt(const SEScope& host_se_scope) {
pass_seqs.push_back(transform::FoldConstant());
// Fuse & lower yet again
- pass_seqs.push_back(FuseAndLowerOperators(host_se_scope));
+ pass_seqs.push_back(FuseAndLowerOperators(host_virtual_device));
// Create allocations for math introduced by dynamic region math.
- pass_seqs.push_back(transform::ManifestAlloc(host_se_scope));
+ pass_seqs.push_back(transform::ManifestAlloc(host_virtual_device));
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());
@@ -998,7 +1001,7 @@ transform::Sequential VMCompiler::MemoryOpt(const SEScope& host_se_scope) {
return transform::Sequential(std::move(pass_seqs));
}
-transform::Sequential VMCompiler::FuseAndLowerOperators(const SEScope& host_se_scope) {
+transform::Sequential VMCompiler::FuseAndLowerOperators(const VirtualDevice& host_virtual_device) {
Array<Pass> pass_seqs;
// Hoist operators to "primitive" Functions.
pass_seqs.push_back(FuseOps());
@@ -1011,7 +1014,7 @@ transform::Sequential VMCompiler::FuseAndLowerOperators(const SEScope& host_se_s
backend::UpdateConstants(func, &params_);
}
},
- host_se_scope));
+ host_virtual_device));
// Since lowered functions are bound in the IRModule, we can now eliminate any unused
// let-bound functions.
pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));
@@ -1022,8 +1025,8 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets,
const Target& target_host) {
config_ = CompilationConfig(PassContext::Current(), targets, target_host);
// The first device always corresponds to the host.
- CHECK(context_.se_scopes_.empty());
- context_.se_scopes_.push_back(config_->host_se_scope);
+ CHECK(context_.virtual_devices_.empty());
+ context_.virtual_devices_.push_back(config_->host_virtual_device);
// TODO(mbs): exec_ is not allocated. What is the API here?
CHECK(exec_ == nullptr);
return OptimizeModuleImpl(std::move(mod));
@@ -1082,13 +1085,13 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
backend::UpdateConstants(func, &params_);
}
},
- config_->host_se_scope));
+ config_->host_virtual_device));
// Since lowered functions are bound in the IRModule, we can now eliminate any unused
// let-bound functions.
pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));
- // Now that we have PrimFuncs, flow and solve SEScope constraints again to account for
+ // Now that we have PrimFuncs, flow and solve VirtualDevice constraints again to account for
// any memory scopes which lowering has settled on.
pass_seqs.push_back(transform::PlanDevices(config_));
@@ -1099,7 +1102,7 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
// external codegen.
pass_seqs.push_back(transform::Inline());
- pass_seqs.push_back(MemoryOpt(config_->host_se_scope));
+ pass_seqs.push_back(MemoryOpt(config_->host_virtual_device));
pass_seqs.push_back(transform::InferType());
transform::Sequential seq(pass_seqs);
diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h
index b8dd9d6..906e514 100644
--- a/src/relay/backend/vm/compiler.h
+++ b/src/relay/backend/vm/compiler.h
@@ -80,8 +80,8 @@ struct VMCompilerContext {
std::vector<Index> const_device_indexes;
// Map from names of primitive functions already allocated to their primitive function index.
std::unordered_map<std::string, Index> primitive_map;
- // The SEScopes corresponding to each device index.
- std::vector<SEScope> se_scopes_;
+ // The virtual devices corresponding to each device index.
+ std::vector<VirtualDevice> virtual_devices_;
};
class VMCompiler : public runtime::ModuleNode {
@@ -136,8 +136,8 @@ class VMCompiler : public runtime::ModuleNode {
IRModule OptimizeModuleImpl(IRModule mod);
- transform::Sequential MemoryOpt(const SEScope& host_se_scope);
- transform::Sequential FuseAndLowerOperators(const SEScope& host_se_scope);
+ transform::Sequential MemoryOpt(const VirtualDevice& host_virtual_device);
+ transform::Sequential FuseAndLowerOperators(const VirtualDevice& host_virtual_device);
/*!
* \brief Populate the global function names in a map where the value is used
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
index ffd0e46..0457459 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -112,7 +112,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
auto free_type_vars = FreeTypeVars(func, module_);
Array<Var> captured_vars;
- std::vector<SEScope> captured_var_se_scopes;
+ std::vector<VirtualDevice> captured_var_virtual_devices;
bool recursive = false;
for (const auto& var : free_vars) {
if (!letrec_.empty() && var == letrec_.back()) {
@@ -120,7 +120,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
continue;
}
captured_vars.push_back(var);
- captured_var_se_scopes.push_back(GetSEScope(var));
+ captured_var_virtual_devices.push_back(GetVirtualDevice(var));
}
// Freshen all the captured vars.
@@ -132,7 +132,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
rebinding_map.Set(free_var, var);
}
- SEScope result_se_scope = GetSEScope(func_node->body);
+ VirtualDevice result_virtual_device = GetVirtualDevice(func_node->body);
if (recursive) {
if (!captured_vars.empty()) {
@@ -195,7 +195,8 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
lifted_func =
Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(),
free_type_vars, /*attrs=*/{}, func->span);
- lifted_func = MaybeFunctionOnDevice(lifted_func, captured_var_se_scopes, result_se_scope);
+ lifted_func =
+ MaybeFunctionOnDevice(lifted_func, captured_var_virtual_devices, result_virtual_device);
lifted_func = MarkClosure(lifted_func);
}
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index 18e83f9..b680a49 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -23,15 +23,15 @@
*/
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
-#include <tvm/target/se_scope.h>
+#include <tvm/target/virtual_device.h>
namespace tvm {
-SEScope RelayExprNode::virtual_device() const {
+VirtualDevice RelayExprNode::virtual_device() const {
if (virtual_device_.defined()) {
- return Downcast<SEScope>(this->virtual_device_);
+ return Downcast<VirtualDevice>(this->virtual_device_);
}
- return SEScope::FullyUnconstrained();
+ return VirtualDevice::FullyUnconstrained();
}
namespace relay {
@@ -86,9 +86,9 @@ TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array<relay::Expr>
return Tuple(fields, span);
});
Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields,
- Optional<SEScope> opt_virtual_device, Optional<Span> opt_span) {
+ Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
Array<Expr> fields = opt_fields.value_or(tuple->fields);
- SEScope virtual_device = opt_virtual_device.value_or(tuple->virtual_device());
+ VirtualDevice virtual_device = opt_virtual_device.value_or(tuple->virtual_device());
Span span = opt_span.value_or(tuple->span);
bool all_fields_unchanged = true;
@@ -132,10 +132,10 @@ Var::Var(Id vid, Type type_annotation, Span span) {
}
Var WithFields(Var var, Optional<Id> opt_vid, Optional<Type> opt_type_annotation,
- Optional<SEScope> opt_virtual_device, Optional<Span> opt_span) {
+ Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
Id vid = opt_vid.value_or(var->vid);
Type type_annotation = opt_type_annotation.value_or(var->type_annotation);
- SEScope virtual_device = opt_virtual_device.value_or(var->virtual_device());
+ VirtualDevice virtual_device = opt_virtual_device.value_or(var->virtual_device());
Span span = opt_span.value_or(var->span);
bool unchanged = vid.same_as(var->vid) && type_annotation.same_as(var->type_annotation) &&
@@ -180,12 +180,12 @@ Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span s
Call WithFields(Call call, Optional<Expr> opt_op, Optional<Array<Expr>> opt_args,
Optional<Attrs> opt_attrs, Optional<Array<Type>> opt_type_args,
- Optional<SEScope> opt_virtual_device, Optional<Span> opt_span) {
+ Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
Expr op = opt_op.value_or(call->op);
Array<Expr> args = opt_args.value_or(call->args);
Attrs attrs = opt_attrs.value_or(call->attrs);
Array<Type> type_args = opt_type_args.value_or(call->type_args);
- SEScope virtual_device = opt_virtual_device.value_or(call->virtual_device());
+ VirtualDevice virtual_device = opt_virtual_device.value_or(call->virtual_device());
Span span = opt_span.value_or(call->span);
bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && span.same_as(call->span);
@@ -253,11 +253,11 @@ Let::Let(Var var, Expr value, Expr body, Span span) {
}
Let WithFields(Let let, Optional<Var> opt_var, Optional<Expr> opt_value, Optional<Expr> opt_body,
- Optional<SEScope> opt_virtual_device, Optional<Span> opt_span) {
+ Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
Var var = opt_var.value_or(let->var);
Expr value = opt_value.value_or(let->value);
Expr body = opt_body.value_or(let->body);
- SEScope virtual_device = opt_virtual_device.value_or(let->virtual_device());
+ VirtualDevice virtual_device = opt_virtual_device.value_or(let->virtual_device());
Span span = opt_span.value_or(let->span);
bool unchanged = var.same_as(let->var) && value.same_as(let->value) && body.same_as(let->body) &&
@@ -296,12 +296,12 @@ If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) {
}
If WithFields(If if_expr, Optional<Expr> opt_cond, Optional<Expr> opt_true_branch,
- Optional<Expr> opt_false_branch, Optional<SEScope> opt_virtual_device,
+ Optional<Expr> opt_false_branch, Optional<VirtualDevice> opt_virtual_device,
Optional<Span> opt_span) {
Expr cond = opt_cond.value_or(if_expr->cond);
Expr true_branch = opt_true_branch.value_or(if_expr->true_branch);
Expr false_branch = opt_false_branch.value_or(if_expr->false_branch);
- SEScope virtual_device = opt_virtual_device.value_or(if_expr->virtual_device());
+ VirtualDevice virtual_device = opt_virtual_device.value_or(if_expr->virtual_device());
Span span = opt_span.value_or(if_expr->span);
bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) &&
@@ -341,11 +341,11 @@ TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) {
}
TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple,
- Optional<Integer> opt_index, Optional<SEScope> opt_virtual_device,
+ Optional<Integer> opt_index, Optional<VirtualDevice> opt_virtual_device,
Optional<Span> opt_span) {
Expr tuple = opt_tuple.value_or(tuple_get_item->tuple);
Integer index = opt_index.value_or(tuple_get_item->index);
- SEScope virtual_device = opt_virtual_device.value_or(tuple->virtual_device());
+ VirtualDevice virtual_device = opt_virtual_device.value_or(tuple->virtual_device());
Span span = opt_span.value_or(tuple_get_item->span);
bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) &&
@@ -380,9 +380,9 @@ RefCreate::RefCreate(Expr value, Span span) {
}
RefCreate WithFields(RefCreate ref_create, Optional<Expr> opt_value,
- Optional<SEScope> opt_virtual_device, Optional<Span> opt_span) {
+ Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
Expr value = opt_value.value_or(ref_create->value);
- SEScope virtual_device = opt_virtual_device.value_or(ref_create->virtual_device());
+ VirtualDevice virtual_device = opt_virtual_device.value_or(ref_create->virtual_device());
Span span = opt_span.value_or(ref_create->span);
bool unchanged = value.same_as(ref_create->value) && span.same_as(ref_create->span);
@@ -414,10 +414,10 @@ RefRead::RefRead(Expr ref, Span span) {
data_ = std::move(n);
}
-RefRead WithFields(RefRead ref_read, Optional<Expr> opt_ref, Optional<SEScope> opt_virtual_device,
- Optional<Span> opt_span) {
+RefRead WithFields(RefRead ref_read, Optional<Expr> opt_ref,
+ Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
Expr ref = opt_ref.value_or(ref_read->ref);
- SEScope virtual_device = opt_virtual_device.value_or(ref_read->virtual_device());
+ VirtualDevice virtual_device = opt_virtual_device.value_or(ref_read->virtual_device());
Span span = opt_span.value_or(ref_read->span);
bool unchanged = ref.same_as(ref_read->ref) && span.same_as(ref_read->span);
@@ -449,10 +449,10 @@ RefWrite::RefWrite(Expr ref, Expr value, Span span) {
}
RefWrite WithFields(RefWrite ref_write, Optional<Expr> opt_ref, Optional<Expr> opt_value,
- Optional<SEScope> opt_virtual_device, Optional<Span> opt_span) {
+ Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) {
Expr ref = opt_ref.value_or(ref_write->ref);
Expr value = opt_value.value_or(ref_write->value);
- SEScope virtual_device = opt_virtual_device.value_or(ref_write->virtual_device());
+ VirtualDevice virtual_device = opt_virtual_device.value_or(ref_write->virtual_device());
Span span = opt_span.value_or(ref_write->span);
bool unchanged = ref.same_as(ref_write->ref) && value.same_as(ref_write->value) &&
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index a08de39..2d6f75a 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -478,11 +478,11 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
if (const FunctionNode* func = expr.as<FunctionNode>()) {
Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
Array<Var> new_params;
- std::vector<SEScope> new_param_se_scopes;
+ std::vector<VirtualDevice> new_param_virtual_devices;
for (size_t i = 0; i < func->params.size(); ++i) {
if (!args_map.count(func->params[i])) {
new_params.push_back(func->params[i]);
- new_param_se_scopes.push_back(GetFunctionParamSEScope(func, i));
+ new_param_virtual_devices.push_back(GetFunctionParamVirtualDevice(func, i));
}
}
if (new_body.same_as(func->body) && new_params.size() == func->params.size()) {
@@ -490,7 +490,8 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
}
auto ret =
Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span);
- ret = MaybeFunctionOnDevice(ret, new_param_se_scopes, GetFunctionResultSEScope(func));
+ ret =
+ MaybeFunctionOnDevice(ret, new_param_virtual_devices, GetFunctionResultVirtualDevice(func));
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> set;
for (const auto& v : FreeVars(expr)) {
set.insert(v);
@@ -498,19 +499,20 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
for (const auto& v : FreeVars(ret)) {
if (set.count(v) == 0) {
new_params.push_back(v);
- if (!GetFunctionResultSEScope(func)->IsFullyUnconstrained()) {
+ if (!GetFunctionResultVirtualDevice(func)->IsFullyUnconstrained()) {
// TODO(mbs): The function has been annotated with a device, which means we are supposed
// to be preserving device annotations on every transformation. However there's no
// such context for the free vars in args_map.
LOG(WARNING) << "introduced free var '" << PrettyPrint(v)
<< "' into function body but no device is known for it";
}
- new_param_se_scopes.push_back(SEScope::FullyUnconstrained());
+ new_param_virtual_devices.push_back(VirtualDevice::FullyUnconstrained());
}
}
ret =
Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span);
- ret = MaybeFunctionOnDevice(ret, new_param_se_scopes, GetFunctionResultSEScope(func));
+ ret =
+ MaybeFunctionOnDevice(ret, new_param_virtual_devices, GetFunctionResultVirtualDevice(func));
ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
return std::move(ret);
} else {
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index 4c5b867..4330540 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -42,14 +42,14 @@ Function::Function(tvm::Array<Var> params, Expr body, Type ret_type,
Function WithFields(Function function, Optional<Array<Var>> opt_params, Optional<Expr> opt_body,
Optional<Type> opt_ret_type, Optional<Array<TypeVar>> opt_ty_params,
- Optional<DictAttrs> opt_attrs, Optional<SEScope> opt_virtual_device,
+ Optional<DictAttrs> opt_attrs, Optional<VirtualDevice> opt_virtual_device,
Optional<Span> opt_span) {
Array<Var> params = opt_params.value_or(function->params);
Expr body = opt_body.value_or(function->body);
Type ret_type = opt_ret_type.value_or(function->ret_type);
Array<TypeVar> ty_params = opt_ty_params.value_or(function->type_params);
DictAttrs attrs = opt_attrs.value_or(function->attrs);
- SEScope virtual_device = opt_virtual_device.value_or(function->virtual_device());
+ VirtualDevice virtual_device = opt_virtual_device.value_or(function->virtual_device());
Span span = opt_span.value_or(function->span);
bool unchanged = body.same_as(function->body) && ret_type.same_as(function->ret_type) &&
diff --git a/src/relay/op/memory/device_copy.cc b/src/relay/op/memory/device_copy.cc
index 690ad4a..a59e25c 100644
--- a/src/relay/op/memory/device_copy.cc
+++ b/src/relay/op/memory/device_copy.cc
@@ -50,12 +50,12 @@ const Op& DeviceCopyOp() {
return op;
}
-Expr DeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope) {
- ICHECK(!src_se_scope->IsFullyUnconstrained());
- ICHECK(!dst_se_scope->IsFullyUnconstrained());
+Expr DeviceCopy(Expr expr, VirtualDevice src_virtual_device, VirtualDevice dst_virtual_device) {
+ ICHECK(!src_virtual_device->IsFullyUnconstrained());
+ ICHECK(!dst_virtual_device->IsFullyUnconstrained());
auto attrs = make_object<DeviceCopyAttrs>();
- attrs->src_se_scope = std::move(src_se_scope);
- attrs->dst_se_scope = std::move(dst_se_scope);
+ attrs->src_virtual_device = std::move(src_virtual_device);
+ attrs->dst_virtual_device = std::move(dst_virtual_device);
Span span = expr->span;
return Call(DeviceCopyOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{},
std::move(span));
@@ -63,12 +63,13 @@ Expr DeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope) {
TVM_REGISTER_GLOBAL("relay.op._make.DeviceCopy").set_body_typed(DeviceCopy);
-Expr MaybeDeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope) {
- if (src_se_scope == dst_se_scope) {
+Expr MaybeDeviceCopy(Expr expr, VirtualDevice src_virtual_device,
+ VirtualDevice dst_virtual_device) {
+ if (src_virtual_device == dst_virtual_device) {
// No copy needed.
return expr;
}
- return DeviceCopy(std::move(expr), std::move(src_se_scope), std::move(dst_se_scope));
+ return DeviceCopy(std::move(expr), std::move(src_virtual_device), std::move(dst_virtual_device));
}
RELAY_REGISTER_OP("device_copy")
@@ -98,13 +99,14 @@ DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) {
const auto* device_copy_attrs = call_node->attrs.as<DeviceCopyAttrs>();
ICHECK(device_copy_attrs != nullptr) << "device_copy requires DeviceCopyAttrs";
// Follow nesting:
- // device_copy(device_copy(expr, src_se_scope=S, dst_se_scope=T),
- // src_se_scope=T, dst_se_scope=U) ==> {expr, S, U}
+ // device_copy(device_copy(expr, src_virtual_device=S, dst_virtual_device=T),
+ // src_virtual_device=T, dst_virtual_device=U) ==> {expr, S, U}
auto inner = GetDeviceCopyProps(call_node->args[0]);
if (inner.body.defined()) {
- return {inner.body, inner.src_se_scope, device_copy_attrs->dst_se_scope};
+ return {inner.body, inner.src_virtual_device, device_copy_attrs->dst_virtual_device};
} else {
- return {call_node->args[0], device_copy_attrs->src_se_scope, device_copy_attrs->dst_se_scope};
+ return {call_node->args[0], device_copy_attrs->src_virtual_device,
+ device_copy_attrs->dst_virtual_device};
}
}
return {};
diff --git a/src/relay/op/memory/device_copy.h b/src/relay/op/memory/device_copy.h
index 728deb7..bb74324 100644
--- a/src/relay/op/memory/device_copy.h
+++ b/src/relay/op/memory/device_copy.h
@@ -40,42 +40,41 @@ const Op& DeviceCopyOp();
/*!
* \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated and
- * stored at \p src_se_scope but then copied to \p dst_se_scope.
+ * stored at \p src_virtual_device but then copied to \p dst_virtual_device.
*/
-Expr DeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope);
+Expr DeviceCopy(Expr expr, VirtualDevice src_virtual_device, VirtualDevice dst_virtual_device);
/*!
* \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated and
- * stored at \p src_se_scope but then copied to \p dst_se_scope.However, return \p expr
- * directly if \p src_se_scope and \p dst_se_scope are (structurally) the same.
+ * stored at \p src_virtual_device but then copied to \p dst_virtual_device. However, return \p expr
+ * directly if \p src_virtual_device and \p dst_virtual_device are (structurally) the same.
*/
-Expr MaybeDeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope);
+Expr MaybeDeviceCopy(Expr expr, VirtualDevice src_virtual_device, VirtualDevice dst_virtual_device);
/*! \brief Result of \p GetDeviceCopyProps. */
struct DeviceCopyProps {
Expr body; // = null
- SEScope src_se_scope = SEScope::FullyUnconstrained();
- SEScope dst_se_scope = SEScope::FullyUnconstrained();
+ VirtualDevice src_virtual_device = VirtualDevice::FullyUnconstrained();
+ VirtualDevice dst_virtual_device = VirtualDevice::FullyUnconstrained();
DeviceCopyProps() = default;
- DeviceCopyProps(Expr body, SEScope src_se_scope, SEScope dst_se_scope)
+ DeviceCopyProps(Expr body, VirtualDevice src_virtual_device, VirtualDevice dst_virtual_device)
: body(std::move(body)),
- src_se_scope(std::move(src_se_scope)),
- dst_se_scope(std::move(dst_se_scope)) {}
+ src_virtual_device(std::move(src_virtual_device)),
+ dst_virtual_device(std::move(dst_virtual_device)) {}
};
/*!
- * \brief Returns the body expression, source, and destination \p SEScopes for \p call_node
+ * \brief Returns the body expression, source, and destination \p VirtualDevices for \p call_node
* if it is a "device_copy" CallNode. Otherwise returns the null expression and unconstrained
- * device and scopes.
+ * virtual device.
*/
DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node);
/*!
- * \brief Returns the body expression, source, and destination \p SEScopes for \p expr if it
- * is a "device_copy" Call. Otherwise returns the null expression and unconstrained device and
- * scopes.
+ * \brief Returns the body expression, source, and destination \p VirtualDevices for \p expr if it
+ * is a "device_copy" Call. Otherwise returns the null expression and unconstrained virtual device.
*/
DeviceCopyProps GetDeviceCopyProps(const Expr& expr);
diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc
index 315ad9c..b546bd5 100644
--- a/src/relay/op/memory/memory.cc
+++ b/src/relay/op/memory/memory.cc
@@ -50,10 +50,10 @@ TVM_REGISTER_NODE_TYPE(AllocTensorAttrs);
// The passing value in attrs and args doesn't seem super great.
// We should consider a better solution, i.e the type relation
// being able to see the arguments as well?
-Expr AllocStorage(Expr size, Expr alignment, SEScope se_scope, DataType dtype_hint) {
+Expr AllocStorage(Expr size, Expr alignment, VirtualDevice virtual_device, DataType dtype_hint) {
auto attrs = make_object<AllocStorageAttrs>();
attrs->dtype = dtype_hint;
- attrs->se_scope = std::move(se_scope);
+ attrs->virtual_device = std::move(virtual_device);
static const Op& op = Op::Get("memory.alloc_storage");
return Call(op, {std::move(size), std::move(alignment)}, Attrs(std::move(attrs)), {});
}
diff --git a/src/relay/op/memory/memory.h b/src/relay/op/memory/memory.h
index 9e93afd..690854c 100644
--- a/src/relay/op/memory/memory.h
+++ b/src/relay/op/memory/memory.h
@@ -25,7 +25,7 @@
#ifndef TVM_RELAY_OP_MEMORY_MEMORY_H_
#define TVM_RELAY_OP_MEMORY_MEMORY_H_
-#include <tvm/target/se_scope.h>
+#include <tvm/target/virtual_device.h>
#include <vector>
@@ -34,7 +34,7 @@
namespace tvm {
namespace relay {
-Expr AllocStorage(Expr size, Expr alignment, SEScope se_scope, DataType dtype_hint);
+Expr AllocStorage(Expr size, Expr alignment, VirtualDevice virtual_device, DataType dtype_hint);
/*! \brief Returns the "memory.alloc_tensor" operator. */
const Op& MemoryAllocTensorOp();
Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype,
diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc
index 0fd86d3..48e93cc 100644
--- a/src/relay/op/memory/on_device.cc
+++ b/src/relay/op/memory/on_device.cc
@@ -43,11 +43,12 @@ const Op& OnDeviceOp() {
return op;
}
-Call OnDevice(Expr body, SEScope se_scope, bool constrain_result, bool constrain_body) {
- ICHECK((!constrain_result && !constrain_body) || !se_scope->IsFullyUnconstrained());
+Call OnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result, bool constrain_body) {
+ ICHECK((!constrain_result && !constrain_body) || !virtual_device->IsFullyUnconstrained());
auto attrs = make_object<OnDeviceAttrs>();
- attrs->se_scope =
- (constrain_result || constrain_body) ? std::move(se_scope) : SEScope::FullyUnconstrained();
+ attrs->virtual_device = (constrain_result || constrain_body)
+ ? std::move(virtual_device)
+ : VirtualDevice::FullyUnconstrained();
attrs->constrain_result = constrain_result;
attrs->constrain_body = constrain_body;
Span span = body->span; // about to be moved
@@ -57,8 +58,9 @@ Call OnDevice(Expr body, SEScope se_scope, bool constrain_result, bool constrain
TVM_REGISTER_GLOBAL("relay.op.annotation._make.OnDevice").set_body_typed(OnDevice);
-Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result, bool constrain_body) {
- if (se_scope->IsFullyUnconstrained()) {
+Expr MaybeOnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result,
+ bool constrain_body) {
+ if (virtual_device->IsFullyUnconstrained()) {
// Nothing to annotate with.
return body;
}
@@ -72,40 +74,40 @@ Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result, bool cons
}
if (body->IsInstance<FunctionNode>()) {
// If a primitive function then it is device polymorphic. Otherwise the device is captured
- // by the function's "result_se_scope" attribute.
+ // by the function's "result_virtual_device" attribute.
return body;
}
OnDeviceProps props = GetOnDeviceProps(body);
if (props.body.defined()) {
// The user is asking for
- // on_device(on_device(body, se_scope=inner), se_scope=outer)
+ // on_device(on_device(body, virtual_device=inner), virtual_device=outer)
// ^ ^ ^
// outer middle inner
// First recover the implied constraints (if any) for outer and inner, and check they don't
// contradict.
- const SEScope& inner = props.se_scope;
- const SEScope& outer = se_scope;
+ const VirtualDevice& inner = props.virtual_device;
+ const VirtualDevice& outer = virtual_device;
bool constrain_outer = constrain_result;
bool constrain_inner = props.constrain_body;
if (constrain_outer && constrain_inner) {
- ICHECK(inner == outer)
- << "Cannot constrain result and body of nested on_device calls to different SEScopes";
+ ICHECK(inner == outer) << "Cannot constrain result and body of nested on_device calls to "
+ "different virtual devices";
}
// There are two possible ways the middle sub-expression may be constrained, check they don't
// contradict.
bool constrain_middle_via_outer = constrain_body;
bool constrain_middle_via_inner = props.constrain_result;
if (constrain_middle_via_outer && constrain_middle_via_inner) {
- ICHECK(inner == outer)
- << "Cannot constrain intermediate result of nested on_device calls to different SEScopes";
+ ICHECK(inner == outer) << "Cannot constrain intermediate result of nested on_device calls to "
+ "different virtual devices";
}
// We can now ignore the middle constraint.
- // If the outer on_device has any constraint then use se_scope given for it.
- // Otherwise we can use the existing inner se_scope.
+ // If the outer on_device has any constraint then use virtual_device given for it.
+ // Otherwise we can use the existing inner virtual_device.
return OnDevice(props.body, (constrain_inner || constrain_outer) ? outer : inner,
constrain_outer, constrain_inner);
} else {
- return OnDevice(body, std::move(se_scope), constrain_result, constrain_body);
+ return OnDevice(body, std::move(virtual_device), constrain_result, constrain_body);
}
}
@@ -127,7 +129,7 @@ OnDeviceProps GetOnDeviceProps(const CallNode* call_node) {
ICHECK(call_node->attrs.defined()) << "on_device requires attributes";
const auto* on_device_attrs = call_node->attrs.as<OnDeviceAttrs>();
ICHECK(on_device_attrs != nullptr) << "on_device requires OnDeviceAttrs";
- return {call_node->args[0], on_device_attrs->se_scope, on_device_attrs->constrain_result,
+ return {call_node->args[0], on_device_attrs->virtual_device, on_device_attrs->constrain_result,
on_device_attrs->constrain_body};
}
return {};
@@ -140,38 +142,42 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) {
return {};
}
-Function FunctionOnDevice(Function function, Array<SEScope> param_se_scopes,
- SEScope result_se_scope) {
- return WithAttrs(std::move(function), {{tvm::attr::kParamSEScopes, std::move(param_se_scopes)},
- {tvm::attr::kResultSEScope, std::move(result_se_scope)}});
+Function FunctionOnDevice(Function function, Array<VirtualDevice> param_virtual_devices,
+ VirtualDevice result_virtual_device) {
+ return WithAttrs(std::move(function),
+ {{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)},
+ {tvm::attr::kResultVirtualDevice, std::move(result_virtual_device)}});
}
TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice);
-Function MaybeFunctionOnDevice(Function function, Array<SEScope> param_se_scopes,
- SEScope result_se_scope) {
- if (std::all_of(param_se_scopes.begin(), param_se_scopes.end(),
- [](const SEScope& se_scope) { return se_scope->IsFullyUnconstrained(); }) &&
- result_se_scope->IsFullyUnconstrained()) {
+Function MaybeFunctionOnDevice(Function function, Array<VirtualDevice> param_virtual_devices,
+ VirtualDevice result_virtual_device) {
+ if (std::all_of(param_virtual_devices.begin(), param_virtual_devices.end(),
+ [](const VirtualDevice& virtual_device) {
+ return virtual_device->IsFullyUnconstrained();
+ }) &&
+ result_virtual_device->IsFullyUnconstrained()) {
// Nothing to annotate.
return function;
}
- return FunctionOnDevice(function, std::move(param_se_scopes), std::move(result_se_scope));
+ return FunctionOnDevice(function, std::move(param_virtual_devices),
+ std::move(result_virtual_device));
}
-SEScope GetFunctionResultSEScope(const FunctionNode* function_node) {
- auto opt_se_scope = function_node->GetAttr<SEScope>(tvm::attr::kResultSEScope);
- return opt_se_scope.value_or(SEScope::FullyUnconstrained());
+VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node) {
+ auto opt_virtual_device = function_node->GetAttr<VirtualDevice>(tvm::attr::kResultVirtualDevice);
+ return opt_virtual_device.value_or(VirtualDevice::FullyUnconstrained());
}
-SEScope GetFunctionParamSEScope(const FunctionNode* function_node, size_t i) {
+VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i) {
ICHECK_LT(i, function_node->params.size())
<< "param index " << i << " out of range for function of arity "
<< function_node->params.size();
- auto opt_array = function_node->GetAttr<Array<SEScope>>(tvm::attr::kParamSEScopes);
+ auto opt_array = function_node->GetAttr<Array<VirtualDevice>>(tvm::attr::kParamVirtualDevice);
if (!opt_array) {
// No annotation.
- return SEScope::FullyUnconstrained();
+ return VirtualDevice::FullyUnconstrained();
}
ICHECK_EQ(opt_array.value().size(), function_node->params.size())
<< "annotation parameters do not match function arity";
diff --git a/src/relay/op/memory/on_device.h b/src/relay/op/memory/on_device.h
index 2ebaf03..7489e3b 100644
--- a/src/relay/op/memory/on_device.h
+++ b/src/relay/op/memory/on_device.h
@@ -39,25 +39,25 @@ namespace relay {
const Op& OnDeviceOp();
/*!
- * \brief Wraps \p body in an "on_device" CallNode for \p se_scope.
+ * \brief Wraps \p body in an "on_device" CallNode for \p virtual_device.
*
* See \p OnDeviceAttrs for an overview.
*/
-Call OnDevice(Expr body, SEScope se_scope, bool constrain_result = false,
+Call OnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result = false,
bool constrain_body = true);
/*! \brief Result of \p GetOnDeviceProps. */
struct OnDeviceProps {
Expr body; // = null
- SEScope se_scope = SEScope::FullyUnconstrained();
+ VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained();
bool constrain_result = false;
bool constrain_body = false;
OnDeviceProps() = default;
- OnDeviceProps(Expr body, SEScope se_scope, bool constrain_result, bool constrain_body)
+ OnDeviceProps(Expr body, VirtualDevice virtual_device, bool constrain_result, bool constrain_body)
: body(std::move(body)),
- se_scope(std::move(se_scope)),
+ virtual_device(std::move(virtual_device)),
constrain_result(constrain_result),
constrain_body(constrain_body) {}
@@ -70,7 +70,8 @@ struct OnDeviceProps {
* props.
*/
inline Call OnDeviceWithProps(Expr body, const OnDeviceProps& props) {
- return OnDevice(std::move(body), props.se_scope, props.constrain_result, props.constrain_body);
+ return OnDevice(std::move(body), props.virtual_device, props.constrain_result,
+ props.constrain_body);
}
/*!
@@ -80,50 +81,50 @@ inline Call OnDeviceWithProps(Expr body, const OnDeviceProps& props) {
* choices.
*/
inline Call OnDeviceCopyOk(Expr body) {
- return OnDevice(std::move(body), SEScope::FullyUnconstrained(),
+ return OnDevice(std::move(body), VirtualDevice::FullyUnconstrained(),
/*constrain_result=*/false, /*constrain_body=*/false);
}
/*!
- * \brief Wraps \p expr in an "on_device" CallNode for \p se_scope and \p constraint if the
- * \p SEScope for \p expr cannot otherwise be recovered by the lexical scoping convention.
+ * \brief Wraps \p expr in an "on_device" CallNode for \p virtual_device and \p constraint if the
+ * \p VirtualDevice for \p expr cannot otherwise be recovered by the lexical scoping convention.
* This means we will NOT wrap if:
- * - \p se_scope is full unconstrained, which signals there are no device annotations
+ * - \p virtual_device is fully unconstrained, which signals there are no device annotations
* already in play.
* - \p expr is an operator or primitive function literal. These are device polymorphic.
* - \p expr is a non-primitive function literal. The device is captured by the
- * "result_se_scope" attribute on the function itself.
+ * "result_virtual_device" attribute on the function itself.
* - \p expr is a global var. The device is on the function attributes the global is bound to.
* - \p expr is a local var. The device is tracked by the device aware visitors for us.
* - \p expr is a constructor. These are device polymorphic.
* Nested on_device calls will never be constructed, they are instead merged on-the-fly.
*/
-Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result = false,
+Expr MaybeOnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result = false,
bool constrain_body = true);
/*! \brief As for MaybeOnDevice, but with both body and result constrained. */
-inline Expr MaybeOnDeviceFixed(Expr body, SEScope se_scope) {
- return MaybeOnDevice(std::move(body), std::move(se_scope), /*constrain_result=*/true,
+inline Expr MaybeOnDeviceFixed(Expr body, VirtualDevice virtual_device) {
+ return MaybeOnDevice(std::move(body), std::move(virtual_device), /*constrain_result=*/true,
/*constrain_body=*/true);
}
/*! \brief As for MaybeOnDevice, but with fields other than body taken from \p props. */
inline Expr MaybeOnDeviceWithProps(Expr body, const OnDeviceProps& props) {
- return MaybeOnDevice(std::move(body), props.se_scope, props.constrain_result,
+ return MaybeOnDevice(std::move(body), props.virtual_device, props.constrain_result,
props.constrain_body);
}
/*!
- * \brief Returns the body expression, \p SEScope, and constraint field for \p call_node if it
+ * \brief Returns the body expression, \p VirtualDevice, and constraint field for \p call_node if it
* is an "on_device" CallNode. Otherwise returns the null expression, the unconstrained
- * \p SEScope, and \p kBody.
+ * \p VirtualDevice, and \p kBody.
*/
OnDeviceProps GetOnDeviceProps(const CallNode* call_node);
/*!
- * \brief Returns the body expression, \p SEScope, and constraint field for \p expr if it is an
- * "on_device" CallNode. Otherwise returns the null expression, the unconstrained \p SEScope,
- * and \p kBody.
+ * \brief Returns the body expression, \p VirtualDevice, and constraint field for \p expr if it is
+ * an "on_device" CallNode. Otherwise returns the null expression, the unconstrained \p
+ * VirtualDevice, and \p kBody.
*/
OnDeviceProps GetOnDeviceProps(const Expr& expr);
@@ -154,29 +155,31 @@ const NodeType* AsIgnoringOnDevice(const Expr& expr) {
}
/*!
- * \brief Returns \p function annotated with "param_se_scopes" and "result_se_scope"
- * attributes capturing parameter and result \p SEScopes respectively.
+ * \brief Returns \p function annotated with "param_virtual_devices" and "result_virtual_device"
+ * attributes capturing parameter and result \p VirtualDevices respectively.
*/
-Function FunctionOnDevice(Function function, Array<SEScope> param_se_scopes, SEScope body_se_scope);
+Function FunctionOnDevice(Function function, Array<VirtualDevice> param_virtual_devices,
+ VirtualDevice body_virtual_device);
/*!
* \brief As for \p FunctionOnDevice, but returns \p function unchanged if all parameters and
- * result \p SEScopes are unconstrained.
+ * result \p VirtualDevices are unconstrained.
*/
-Function MaybeFunctionOnDevice(Function function, Array<SEScope> param_se_scopes,
- SEScope result_se_scope);
+Function MaybeFunctionOnDevice(Function function, Array<VirtualDevice> param_virtual_devices,
+ VirtualDevice result_virtual_device);
/*!
- * \brief Returns the \p SEScope for the resut of \p function_node, or the unconstrained
- * \p SEScope if function does not have the "result_se_scope" annotation.
+ * \brief Returns the \p VirtualDevice for the result of \p function_node, or the unconstrained
+ * \p VirtualDevice if function does not have the "result_virtual_device" annotation.
*/
-SEScope GetFunctionResultSEScope(const FunctionNode* function_node);
+VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node);
/*!
- * \brief Returns the \p SEScope for the \p i'th parameter of \p function_node, or
- * the unconstrained \p SEScope if function does not have the "param_se_scopes" annotation.
+ * \brief Returns the \p VirtualDevice for the \p i'th parameter of \p function_node, or
+ * the unconstrained \p VirtualDevice if function does not have the "param_virtual_devices"
+ * annotation.
*/
-SEScope GetFunctionParamSEScope(const FunctionNode* function_node, size_t i);
+VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i);
} // namespace relay
} // namespace tvm
diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc
index 29965d2..10584da 100644
--- a/src/relay/transforms/device_aware_visitors.cc
+++ b/src/relay/transforms/device_aware_visitors.cc
@@ -38,52 +38,52 @@ LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional<IRModule>& maybe_mod)
if (maybe_mod) {
for (const auto& kv : maybe_mod.value()->functions) {
if (const auto* function_node = kv.second.as<FunctionNode>()) {
- SEScope se_scope = GetFunctionResultSEScope(function_node);
- if (!se_scope->IsFullyUnconstrained()) {
- VLOG(2) << "global '" << kv.first->name_hint << "' has scope " << se_scope;
- global_var_se_scopes_.emplace(kv.first, se_scope);
+ VirtualDevice virtual_device = GetFunctionResultVirtualDevice(function_node);
+ if (!virtual_device->IsFullyUnconstrained()) {
+ VLOG(2) << "global '" << kv.first->name_hint << "' has virtual device " << virtual_device;
+ global_var_virtual_devices_.emplace(kv.first, virtual_device);
}
}
}
}
}
-SEScope LexicalOnDeviceMixin::GetSEScope(const Expr& expr) const {
+VirtualDevice LexicalOnDeviceMixin::GetVirtualDevice(const Expr& expr) const {
OnDeviceProps props = GetOnDeviceProps(expr);
if (props.body.defined() && props.is_fixed()) {
- return props.se_scope;
+ return props.virtual_device;
} else if (const auto* var_node = expr.as<VarNode>()) {
// Lookup variable binding.
- auto itr = var_se_scopes_.find(GetRef<Var>(var_node));
- if (itr != var_se_scopes_.end()) {
+ auto itr = var_virtual_devices_.find(GetRef<Var>(var_node));
+ if (itr != var_virtual_devices_.end()) {
return itr->second;
}
// else: fallthrough to unconstrained
} else if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
// Lookup global variable.
- auto itr = global_var_se_scopes_.find(GetRef<GlobalVar>(global_var_node));
- if (itr != global_var_se_scopes_.end()) {
+ auto itr = global_var_virtual_devices_.find(GetRef<GlobalVar>(global_var_node));
+ if (itr != global_var_virtual_devices_.end()) {
return itr->second;
}
// else: fallthrough to unconstrained
} else if (const auto* function_node = expr.as<FunctionNode>()) {
if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
- if (!expr_se_scopes_.empty()) {
+ if (!expr_virtual_devices_.empty()) {
// Use the currently in-scope device type.
- return expr_se_scopes_.back();
+ return expr_virtual_devices_.back();
}
// else: fallthrough to unconstrained
} else {
- return GetFunctionResultSEScope(function_node);
+ return GetFunctionResultVirtualDevice(function_node);
}
} else {
- if (!expr_se_scopes_.empty()) {
+ if (!expr_virtual_devices_.empty()) {
// Use the currently in-scope device type.
- return expr_se_scopes_.back();
+ return expr_virtual_devices_.back();
}
// else: fallthrough to unconstrained
}
- return SEScope::FullyUnconstrained();
+ return VirtualDevice::FullyUnconstrained();
}
void LexicalOnDeviceMixin::EnterFunctionBody() { ++function_nesting_; }
@@ -93,34 +93,34 @@ void LexicalOnDeviceMixin::ExitFunctionBody() {
--function_nesting_;
}
-void LexicalOnDeviceMixin::PushSEScope(const SEScope& se_scope) {
- if (se_scope->IsFullyUnconstrained()) {
+void LexicalOnDeviceMixin::PushVirtualDevice(const VirtualDevice& virtual_device) {
+ if (virtual_device->IsFullyUnconstrained()) {
return;
}
- expr_se_scopes_.emplace_back(se_scope);
+ expr_virtual_devices_.emplace_back(virtual_device);
}
-void LexicalOnDeviceMixin::PopSEScope() {
- if (expr_se_scopes_.empty()) {
+void LexicalOnDeviceMixin::PopVirtualDevice() {
+ if (expr_virtual_devices_.empty()) {
return;
}
- expr_se_scopes_.pop_back();
+ expr_virtual_devices_.pop_back();
}
-void LexicalOnDeviceMixin::PushBoundVar(Var var, const SEScope& se_scope) {
- if (se_scope->IsFullyUnconstrained()) {
+void LexicalOnDeviceMixin::PushBoundVar(Var var, const VirtualDevice& virtual_device) {
+ if (virtual_device->IsFullyUnconstrained()) {
return;
}
- ICHECK(var_se_scopes_.find(var) == var_se_scopes_.end());
- var_se_scopes_.emplace(std::move(var), se_scope);
+ ICHECK(var_virtual_devices_.find(var) == var_virtual_devices_.end());
+ var_virtual_devices_.emplace(std::move(var), virtual_device);
}
void LexicalOnDeviceMixin::PopBoundVar(const Var& var) {
- auto itr = var_se_scopes_.find(var);
- if (itr == var_se_scopes_.end()) {
+ auto itr = var_virtual_devices_.find(var);
+ if (itr == var_virtual_devices_.end()) {
return;
}
- var_se_scopes_.erase(itr);
+ var_virtual_devices_.erase(itr);
}
// TODO(mbs): We'd probably have less tedious code duplication if we redefined the memoizing
@@ -133,17 +133,17 @@ void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) {
} else {
// Function parameters come into scope.
for (size_t i = 0; i < function_node->params.size(); ++i) {
- PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i));
+ PushBoundVar(function_node->params[i], GetFunctionParamVirtualDevice(function_node, i));
}
// Entering scope of function body.
- PushSEScope(GetFunctionResultSEScope(function_node));
+ PushVirtualDevice(GetFunctionResultVirtualDevice(function_node));
EnterFunctionBody();
DeviceAwareVisitExpr_(function_node);
// Leaving scope of function body.
ExitFunctionBody();
- PopSEScope();
+ PopVirtualDevice();
// Function parameters go out of scope.
for (size_t i = 0; i < function_node->params.size(); ++i) {
PopBoundVar(function_node->params[i]);
@@ -158,7 +158,7 @@ void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) {
while (const auto* inner_let_node = expr.as<LetNode>()) {
// Let-bound var (in pre visited version) goes into scope.
// (We'll just assume this is a letrec).
- PushBoundVar(inner_let_node->var, GetSEScope(inner_let_node->value));
+ PushBoundVar(inner_let_node->var, GetVirtualDevice(inner_let_node->value));
PreVisitLetBinding_(inner_let_node->var, inner_let_node->value);
bindings.emplace_back(inner_let_node);
expr = inner_let_node->body;
@@ -178,10 +178,10 @@ void DeviceAwareExprVisitor::VisitExpr_(const CallNode* call_node) {
OnDeviceProps props = GetOnDeviceProps(call_node);
if (props.body.defined() && props.is_fixed()) {
// Entering lexical scope of fixed "on_device" call.
- PushSEScope(props.se_scope);
+ PushVirtualDevice(props.virtual_device);
VisitExpr(props.body);
// Leaving lexical scope of "on_device" call.
- PopSEScope();
+ PopVirtualDevice();
} else {
DeviceAwareVisitExpr_(call_node);
}
@@ -219,17 +219,17 @@ Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) {
} else {
// Function parameters come into scope.
for (size_t i = 0; i < function_node->params.size(); ++i) {
- PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i));
+ PushBoundVar(function_node->params[i], GetFunctionParamVirtualDevice(function_node, i));
}
// Entering scope of function body.
- PushSEScope(GetFunctionResultSEScope(function_node));
+ PushVirtualDevice(GetFunctionResultVirtualDevice(function_node));
EnterFunctionBody();
Expr result = DeviceAwareVisitExpr_(function_node);
// Leaving scope of function body.
ExitFunctionBody();
- PopSEScope();
+ PopVirtualDevice();
// Function parameters go out of scope.
for (size_t i = 0; i < function_node->params.size(); ++i) {
PopBoundVar(function_node->params[i]);
@@ -246,7 +246,7 @@ Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) {
while (const auto* inner_let_node = expr.as<LetNode>()) {
// Let-bound var (in pre visited version) goes into scope.
// (We'll just assume this is a letrec.)
- PushBoundVar(inner_let_node->var, GetSEScope(inner_let_node->value));
+ PushBoundVar(inner_let_node->var, GetVirtualDevice(inner_let_node->value));
std::pair<Var, Expr> pair = PreVisitLetBinding_(inner_let_node->var, inner_let_node->value);
bindings.emplace_back(pair.first, pair.second, inner_let_node->span, inner_let_node);
expr = inner_let_node->body;
@@ -269,10 +269,10 @@ Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) {
OnDeviceProps props = GetOnDeviceProps(call_node);
if (props.body.defined() && props.is_fixed()) {
// Entering lexical scope of fixed "on_device" call.
- PushSEScope(props.se_scope);
+ PushVirtualDevice(props.virtual_device);
Expr expr = VisitExpr(props.body);
// Leaving lexical scope of "on_device" call.
- PopSEScope();
+ PopVirtualDevice();
return MaybeOnDeviceWithProps(expr, props);
} else {
return DeviceAwareVisitExpr_(call_node);
diff --git a/src/relay/transforms/device_aware_visitors.h b/src/relay/transforms/device_aware_visitors.h
index 044cda8..9340c03 100644
--- a/src/relay/transforms/device_aware_visitors.h
+++ b/src/relay/transforms/device_aware_visitors.h
@@ -42,7 +42,7 @@ namespace relay {
namespace transform {
/*!
- * \brief Helper class for expression transformers which need to keep track of the \p SEScope
+ * \brief Helper class for expression transformers which need to keep track of the \p VirtualDevice
* holding the results of expressions. This is recovered from function attributes and "on_device"
* CallNodes added by the PlanDevices pass.
*
@@ -53,11 +53,11 @@ class LexicalOnDeviceMixin {
explicit LexicalOnDeviceMixin(const Optional<IRModule>& maybe_mod);
/*!
- * \brief Returns the \p SEScope on which the result of \p expr should/will be stored, assuming
- * {Push,Pop}{SEScope,BoundVar} have been correctly called. May return the unconstrained
- * \p SEScope if the device planning pass has not been run.
+ * \brief Returns the \p VirtualDevice on which the result of \p expr should/will be stored,
+ * assuming {Push,Pop}{VirtualDevice,BoundVar} have been correctly called. May return the
+ * unconstrained \p VirtualDevice if the device planning pass has not been run.
*/
- SEScope GetSEScope(const Expr& expr) const;
+ VirtualDevice GetVirtualDevice(const Expr& expr) const;
/*! \brief Indicate a function body is being entered. */
void EnterFunctionBody();
@@ -65,19 +65,21 @@ class LexicalOnDeviceMixin {
/*! \brief Indicate a function body has been processed. */
void ExitFunctionBody();
- /*! \brief Push an \p SEScope onto the lexical SEScope stack. Ignore if unconstrained. */
- void PushSEScope(const SEScope& se_scope);
+ /*! \brief Push a \p VirtualDevice onto the lexical VirtualDevice stack. Ignore if unconstrained.
+ */
+ void PushVirtualDevice(const VirtualDevice& virtual_device);
- /*! \brief Pop an \p SEScope from the lexical SEScope stack. Ignore if stack is empty. */
- void PopSEScope();
+ /*! \brief Pop a \p VirtualDevice from the lexical VirtualDevice stack. Ignore if stack is empty.
+ */
+ void PopVirtualDevice();
- /*! \brief Remember that \p var will be stored at \p se_scope. Ignore if unconstrained.
+ /*! \brief Remember that \p var will be stored at \p virtual_device. Ignore if unconstrained.
*
* CAUTION: Despite the name we don't support re-entering the same function body.
*/
- void PushBoundVar(Var var, const SEScope& se_scope);
+ void PushBoundVar(Var var, const VirtualDevice& virtual_device);
- /*! \brief Remove the binding for \p var to its \p SEScope. Ignore if var is not bound. */
+ /*! \brief Remove the binding for \p var to its \p VirtualDevice. Ignore if var is not bound. */
void PopBoundVar(const Var& var);
/*!
@@ -93,36 +95,37 @@ class LexicalOnDeviceMixin {
int function_nesting_ = 0;
/*!
- * \brief The stack of lexically enclosing "on_device" \p SEScopes, from outermost to
+ * \brief The stack of lexically enclosing "on_device" \p VirtualDevices, from outermost to
* innermost. When visiting an expression other than a variable we can assume the expression's
- * result is to be stored on \p expr_se_scopes.back().
+ * result is to be stored on \p expr_virtual_devices_.back().
*/
- std::vector<SEScope> expr_se_scopes_;
+ std::vector<VirtualDevice> expr_virtual_devices_;
/*!
- * \brief A map from in-scope local variables to their \p SEScopes. We may assume the variable is
- * only ever bound to a value stored on this \p SEScope at runtime.
+ * \brief A map from in-scope local variables to their \p VirtualDevices. We may assume the
+ * variable is only ever bound to a value stored on this \p VirtualDevice at runtime.
*
* Note: We're playing it safe and keying by object refs here just in case the Relay expression
* being rewritten has no module or other global to keep it alive.
*/
- std::unordered_map<Var, SEScope, runtime::ObjectPtrHash, runtime::ObjectPtrEqual> var_se_scopes_;
+ std::unordered_map<Var, VirtualDevice, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
+ var_virtual_devices_;
/*!
- * \brief A map from global variables to their \p SEScopes, ie the "result_se_scope" of the
- * function they are bound to in the module we are working on. We calculate and store this
+ * \brief A map from global variables to their \p VirtualDevices, ie the "result_virtual_device"
+ * of the function they are bound to in the module we are working on. We calculate and store this
* explicitly so that we don't need to hold on to any module, which is often in the process of
* being rewritten.
*/
- std::unordered_map<GlobalVar, SEScope, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
- global_var_se_scopes_;
+ std::unordered_map<GlobalVar, VirtualDevice, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
+ global_var_virtual_devices_;
};
template <typename FType>
class DeviceAwareExprFunctor;
/*!
- * \brief ExprFunctor which tracks \p SEScopes. We only support 'visitor' style implementation
+ * \brief ExprFunctor which tracks \p VirtualDevices. We only support 'visitor' style implementation
* with no additional arguments, thus this is equivalent to \p DeviceAwareExprVisitor without
* any memoization.
*/
@@ -143,21 +146,21 @@ class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(cons
} else {
// Function parameters come into scope.
for (size_t i = 0; i < function_node->params.size(); ++i) {
- PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i));
+ PushBoundVar(function_node->params[i], GetFunctionParamVirtualDevice(function_node, i));
}
// Entering scope of function body.
- SEScope se_scope = GetFunctionResultSEScope(function_node);
- VLOG(2) << "entering " << se_scope << " for function:" << std::endl
+ VirtualDevice virtual_device = GetFunctionResultVirtualDevice(function_node);
+ VLOG(2) << "entering " << virtual_device << " for function:" << std::endl
<< PrettyPrint(GetRef<Function>(function_node));
- PushSEScope(se_scope);
+ PushVirtualDevice(virtual_device);
EnterFunctionBody();
DeviceAwareVisitExpr_(function_node);
// Leaving scope of function body.
ExitFunctionBody();
- PopSEScope();
- VLOG(2) << "leaving " << se_scope << " for function:" << std::endl
+ PopVirtualDevice();
+ VLOG(2) << "leaving " << virtual_device << " for function:" << std::endl
<< PrettyPrint(GetRef<Function>(function_node));
// Function parameters go out of scope.
for (size_t i = 0; i < function_node->params.size(); ++i) {
@@ -173,9 +176,10 @@ class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(cons
while (const auto* inner_let_node = expr.as<LetNode>()) {
// Let-bound var (in pre visited version) goes into scope.
// (We'll just assume this is a letrec.)
- SEScope se_scope = GetSEScope(inner_let_node->value);
- VLOG(2) << "var '" << inner_let_node->var->name_hint() << "' has scope " << se_scope;
- PushBoundVar(inner_let_node->var, se_scope);
+ VirtualDevice virtual_device = GetVirtualDevice(inner_let_node->value);
+ VLOG(2) << "var '" << inner_let_node->var->name_hint() << "' has virtual device "
+ << virtual_device;
+ PushBoundVar(inner_let_node->var, virtual_device);
PreVisitLetBinding_(inner_let_node->var, inner_let_node->value);
bindings.emplace_back(inner_let_node);
expr = inner_let_node->body;
@@ -196,13 +200,13 @@ class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(cons
OnDeviceProps props = GetOnDeviceProps(call_node);
if (props.body.defined() && props.is_fixed()) {
// Entering lexical scope of "on_device" call.
- VLOG(2) << "entering " << props.se_scope << " for on_device:" << std::endl
+ VLOG(2) << "entering " << props.virtual_device << " for on_device:" << std::endl
<< PrettyPrint(GetRef<Call>(call_node));
- PushSEScope(props.se_scope);
+ PushVirtualDevice(props.virtual_device);
VisitExpr(props.body);
// Leaving lexical scope of "on_device" call.
- PopSEScope();
- VLOG(2) << "leaving " << props.se_scope << " for on_device:" << std::endl
+ PopVirtualDevice();
+ VLOG(2) << "leaving " << props.virtual_device << " for on_device:" << std::endl
<< PrettyPrint(GetRef<Call>(call_node));
} else {
DeviceAwareVisitExpr_(call_node);
@@ -210,8 +214,8 @@ class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(cons
}
/*!
- * \brief These are as for VisitExpr_. \p SEScopes for expressions and function parameters will be
- * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For
+ * \brief These are as for VisitExpr_. \p VirtualDevices for expressions and function parameters
+ * will be tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For
* functions the function_nesting count will already include that of \p function_node.
*/
@@ -254,7 +258,7 @@ class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(cons
virtual void PostVisitLetBlock_(const LetNode* let_node) {}
};
-/*! \brief ExprVisitor which tracks \p SEScopes. */
+/*! \brief ExprVisitor which tracks \p VirtualDevices. */
class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin {
public:
explicit DeviceAwareExprVisitor(const Optional<IRModule>& maybe_mod)
@@ -267,8 +271,8 @@ class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin {
void VisitExpr_(const CallNode* call_node) final;
/*!
- * \brief These are as for VisitExpr_. \p SEScopes for expressions and function parameters will be
- * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For
+ * \brief These are as for VisitExpr_. \p VirtualDevices for expressions and function parameters
+ * will be tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For
* functions the function_nesting count will already include that of \p function_node.
*/
virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node);
@@ -281,9 +285,9 @@ class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin {
virtual void PreVisitLetBlock_(const LetNode* let_node);
/*!
- * \brief Visit a let-bound expression before the let body has been visited. \p SEScopes for the
- * let-bound variable will be tracked automatically. Default implementation just visits var and
- * value.
+ * \brief Visit a let-bound expression before the let body has been visited. \p VirtualDevices for
+ * the let-bound variable will be tracked automatically. Default implementation just visits var
+ * and value.
*/
virtual void PreVisitLetBinding_(const Var& var, const Expr& value);
@@ -300,7 +304,7 @@ class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin {
virtual void PostVisitLetBlock_(const LetNode* let_node);
};
-/*! \brief ExprMutator which tracks \p SEScopes. */
+/*! \brief ExprMutator which tracks \p VirtualDevices. */
class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin {
public:
explicit DeviceAwareExprMutator(const Optional<IRModule>& maybe_mod)
@@ -311,8 +315,8 @@ class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin {
Expr VisitExpr_(const CallNode* call_node) final;
/*!
- * \brief These are as for VisitExpr_. \p SEScopes for expressions and function parameters will be
- * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For
+ * \brief These are as for VisitExpr_. \p VirtualDevices for expressions and function parameters
+ * will be tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For
* functions the function_nesting count will already include that of \p function_node.
*/
virtual Expr DeviceAwareVisitExpr_(const FunctionNode* function_node);
@@ -325,9 +329,9 @@ class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin {
virtual void PreVisitLetBlock_(const LetNode* let_node);
/*!
- * \brief Visit a let-bound expression before the let body has been visited. \p SEScopes for the
- * let-bound variable will be tracked automatically. Default implementation just visits var and
- * value.
+ * \brief Visit a let-bound expression before the let body has been visited. \p VirtualDevices for
+ * the let-bound variable will be tracked automatically. Default implementation just visits var
+ * and value.
*/
virtual std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value);
diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc
index fd46a6d..95249f9 100644
--- a/src/relay/transforms/device_domains.cc
+++ b/src/relay/transforms/device_domains.cc
@@ -37,43 +37,44 @@ namespace relay {
namespace transform {
DeviceDomains::DeviceDomains(CompilationConfig config) : config_(std::move(config)) {
- host_domain_ = MakeFirstOrderDomain(config_->host_se_scope);
+ host_domain_ = MakeFirstOrderDomain(config_->host_virtual_device);
}
-DeviceDomainPtr DeviceDomains::MakeFirstOrderDomain(const SEScope& se_scope) {
- if (se_scope->IsFullyConstrained()) {
- auto itr = fully_constrained_se_scope_to_domain_.find(se_scope);
- if (itr != fully_constrained_se_scope_to_domain_.end()) {
+DeviceDomainPtr DeviceDomains::MakeFirstOrderDomain(const VirtualDevice& virtual_device) {
+ if (virtual_device->IsFullyConstrained()) {
+ auto itr = fully_constrained_virtual_device_to_domain_.find(virtual_device);
+ if (itr != fully_constrained_virtual_device_to_domain_.end()) {
return itr->second;
}
- DeviceDomainPtr domain = std::make_shared<DeviceDomain>(se_scope);
- fully_constrained_se_scope_to_domain_.emplace(se_scope, domain);
+ DeviceDomainPtr domain = std::make_shared<DeviceDomain>(virtual_device);
+ fully_constrained_virtual_device_to_domain_.emplace(virtual_device, domain);
return domain;
} else {
- return std::make_shared<DeviceDomain>(se_scope);
+ return std::make_shared<DeviceDomain>(virtual_device);
}
}
-DeviceDomainPtr DeviceDomains::MakeDomain(const Type& type, const SEScope& se_scope) {
+DeviceDomainPtr DeviceDomains::MakeDomain(const Type& type, const VirtualDevice& virtual_device) {
if (const auto* func_type_node = type.as<FuncTypeNode>()) {
std::vector<DeviceDomainPtr> args_and_result;
args_and_result.reserve(func_type_node->arg_types.size() + 1);
for (const auto& arg_type : func_type_node->arg_types) {
- args_and_result.emplace_back(MakeDomain(arg_type, SEScope::FullyUnconstrained()));
+ args_and_result.emplace_back(MakeDomain(arg_type, VirtualDevice::FullyUnconstrained()));
}
- args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, se_scope));
+ args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, virtual_device));
return std::make_shared<DeviceDomain>(std::move(args_and_result));
} else {
- return MakeFirstOrderDomain(se_scope);
+ return MakeFirstOrderDomain(virtual_device);
}
}
-DeviceDomainPtr DeviceDomains::ForSEScope(const Type& type, const SEScope& non_canonical_se_scope) {
- // Generally se_scope will have come from an annotation so resolve it to ensure we have
+DeviceDomainPtr DeviceDomains::ForVirtualDevice(const Type& type,
+ const VirtualDevice& non_canonical_virtual_device) {
+ // Generally the virtual device will have come from an annotation so resolve it to ensure we have
// its canonical representation.
- SEScope se_scope = config_->CanonicalSEScope(non_canonical_se_scope);
- ICHECK(!se_scope->IsFullyUnconstrained());
- return MakeDomain(type, se_scope);
+ VirtualDevice virtual_device = config_->CanonicalVirtualDevice(non_canonical_virtual_device);
+ ICHECK(!virtual_device->IsFullyUnconstrained());
+ return MakeDomain(type, virtual_device);
}
DeviceDomainPtr DeviceDomains::Lookup(DeviceDomainPtr domain) {
@@ -110,17 +111,18 @@ DeviceDomainPtr DeviceDomains::JoinOrNull(const DeviceDomainPtr& lhs, const Devi
<< "do not have the same kind and can't be unified.";
if (lhs->args_and_result_.empty()) {
// Directly compare first-order.
- if (rhs->se_scope_->IsFullyUnconstrained()) {
+ if (rhs->virtual_device_->IsFullyUnconstrained()) {
return lhs;
}
- if (lhs->se_scope_->IsFullyUnconstrained()) {
+ if (lhs->virtual_device_->IsFullyUnconstrained()) {
return rhs;
}
- Optional<SEScope> joined_se_scope = SEScope::Join(lhs->se_scope_, rhs->se_scope_);
- if (!joined_se_scope) {
+ Optional<VirtualDevice> joined_virtual_device =
+ VirtualDevice::Join(lhs->virtual_device_, rhs->virtual_device_);
+ if (!joined_virtual_device) {
return nullptr;
}
- return MakeFirstOrderDomain(config_->CanonicalSEScope(joined_se_scope.value()));
+ return MakeFirstOrderDomain(config_->CanonicalVirtualDevice(joined_virtual_device.value()));
} else {
// Recurse for higher-order.
std::vector<DeviceDomainPtr> args_and_result;
@@ -205,41 +207,42 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) {
// all the argument and result devices domains must be equal, ignoring memory scopes.
// So at this point we'll let all the arguments and result be free so that memory scopes can
// differ.
- // TODO(mbs): As per header comments, need to revisit when can setup sub-SEScope constraints.
+ // TODO(mbs): As per header comments, need to revisit when we can set up sub-virtual device
+ // constraints.
return DomainFor(call_lowered_props.lowered_func);
} else if (on_device_props.body.defined()) {
// By default:
- // on_device(expr, se_scope=<t>)
+ // on_device(expr, virtual_device=<t>)
// on_device : fn(<t>):?x?
// However we'll interpret the constrain_body and constrain_result fields to decide
// on free vs constrained domains for the argument and result respectively.
if (on_device_props.constrain_body) {
args_and_result.emplace_back(
- ForSEScope(on_device_props.body->checked_type(), on_device_props.se_scope));
+ ForVirtualDevice(on_device_props.body->checked_type(), on_device_props.virtual_device));
} else {
args_and_result.emplace_back(Free(on_device_props.body->checked_type()));
}
if (on_device_props.constrain_result) {
args_and_result.emplace_back(
- ForSEScope(on_device_props.body->checked_type(), on_device_props.se_scope));
+ ForVirtualDevice(on_device_props.body->checked_type(), on_device_props.virtual_device));
} else {
args_and_result.emplace_back(Free(on_device_props.body->checked_type()));
}
} else if (device_copy_props.body.defined()) {
- // device_copy(expr, src_se_scope=<s>, dst_se_scope=<d>)
+ // device_copy(expr, src_virtual_device=<s>, dst_virtual_device=<d>)
// device_copy: fn(<s>):<d>
- args_and_result.emplace_back(
- ForSEScope(device_copy_props.body->checked_type(), device_copy_props.src_se_scope));
- args_and_result.emplace_back(
- ForSEScope(device_copy_props.body->checked_type(), device_copy_props.dst_se_scope));
+ args_and_result.emplace_back(ForVirtualDevice(device_copy_props.body->checked_type(),
+ device_copy_props.src_virtual_device));
+ args_and_result.emplace_back(ForVirtualDevice(device_copy_props.body->checked_type(),
+ device_copy_props.dst_virtual_device));
} else if (call->op == alloc_storage_op) {
ICHECK_EQ(call->args.size(), 2U);
- // alloc_storage(size, alignment, se_scope=<t>)
+ // alloc_storage(size, alignment, virtual_device=<t>)
// alloc_storage: fn(<cpu>, <cpu>):<t>
const auto* attrs = call->attrs.as<AllocStorageAttrs>();
args_and_result.emplace_back(host_domain_);
args_and_result.emplace_back(host_domain_);
- args_and_result.emplace_back(ForSEScope(call->checked_type(), attrs->se_scope));
+ args_and_result.emplace_back(ForVirtualDevice(call->checked_type(), attrs->virtual_device));
} else if (call->op == alloc_tensor_op) {
ICHECK_EQ(call->args.size(), 3U);
// alloc_tensor(storage, offset, shape)
@@ -277,7 +280,7 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) {
// <primitive>(arg1, ..., argn)
// <primitive>: fn(?x?, ..., ?x?):?x?
// (all args and result must be first-order).
- auto free_domain = MakeFirstOrderDomain(SEScope::FullyUnconstrained());
+ auto free_domain = MakeFirstOrderDomain(VirtualDevice::FullyUnconstrained());
for (size_t i = 0; i < call->args.size(); ++i) {
args_and_result.emplace_back(free_domain);
}
@@ -314,12 +317,12 @@ void DeviceDomains::UnifyExprExact(const Expr& lhs, const Expr& rhs) {
auto rhs_domain = DomainFor(rhs);
if (UnifyOrNull(lhs_domain, rhs_domain) == nullptr) {
// TODO(mbs): Proper diagnostics.
- LOG(FATAL) << "Incompatible SEScopes for expressions:" << std::endl
+ LOG(FATAL) << "Incompatible virtual devices for expressions:" << std::endl
<< PrettyPrint(lhs) << std::endl
- << "with scope:" << std::endl
+ << "with virtual device:" << std::endl
<< ToString(lhs_domain) << "and:" << std::endl
<< PrettyPrint(rhs) << std::endl
- << "with scope:" << std::endl
+ << "with virtual device:" << std::endl
<< ToString(rhs_domain);
}
}
@@ -332,21 +335,21 @@ void DeviceDomains::OptionalUnifyExprExact(const Expr& lhs, const Expr& rhs) {
if (UnifyOrNull(lhs_domain, rhs_domain) == nullptr) {
// Rollback
domain_to_equiv_ = domain_to_equiv_snapshot;
- VLOG(2) << "Unable to unify SEScopes for expression:" << std::endl
+ VLOG(2) << "Unable to unify virtual devices for expression:" << std::endl
<< PrettyPrint(lhs) << std::endl
- << "with scope:" << std::endl
+ << "with virtual device:" << std::endl
<< ToString(lhs_domain) << std::endl
<< "and expression:" << std::endl
<< PrettyPrint(rhs) << std::endl
- << "with scope:" << std::endl
+ << "with virtual device:" << std::endl
<< ToString(rhs_domain) << std::endl
- << ". Leaving scopes non-unified.";
+ << ". Leaving virtual devices non-unified.";
} else {
- VLOG(2) << "Unified SEScopes for expression:" << std::endl
+ VLOG(2) << "Unified virtual devices for expression:" << std::endl
<< PrettyPrint(lhs) << std::endl
<< "and expression:" << std::endl
<< PrettyPrint(rhs) << std::endl
- << "to scope:" << std::endl
+ << "to virtual devices:" << std::endl
<< ToString(lhs_domain);
}
}
@@ -355,11 +358,11 @@ void DeviceDomains::UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expe
auto actual_domain = DomainFor(expr);
if (UnifyOrNull(actual_domain, expected_domain) == nullptr) {
// TODO(mbs): Proper diagnostics.
- LOG(FATAL) << "Incompatible SEScopes for expression:" << std::endl
+ LOG(FATAL) << "Incompatible virtual devices for expression:" << std::endl
<< PrettyPrint(expr) << std::endl
- << "with actual scope:" << std::endl
+ << "with actual virtual device:" << std::endl
<< ToString(actual_domain) << std::endl
- << "and expected scope:" << std::endl
+ << "and expected virtual device:" << std::endl
<< ToString(expected_domain);
}
}
@@ -369,11 +372,11 @@ void DeviceDomains::UnifyExprCollapsed(const Expr& expr_first_order,
auto actual_domain_first_order = DomainFor(expr_first_order);
if (!UnifyCollapsedOrFalse(actual_domain_first_order, expected_domain_maybe_higher_order)) {
// TODO(mbs): Proper diagnostics.
- LOG(FATAL) << "Incompatible SEScopes for expression:" << std::endl
+ LOG(FATAL) << "Incompatible virtual devices for expression:" << std::endl
<< PrettyPrint(expr_first_order) << std::endl
- << "with actual scope:" << std::endl
+ << "with actual virtual devices:" << std::endl
<< ToString(actual_domain_first_order) << std::endl
- << "and expected scope:" << std::endl
+ << "and expected virtual device:" << std::endl
<< ToString(expected_domain_maybe_higher_order);
}
}
@@ -382,7 +385,7 @@ bool DeviceDomains::IsFullyConstrained(DeviceDomainPtr domain) {
domain = Lookup(domain);
if (domain->args_and_result_.empty()) {
// First-order.
- return domain->se_scope_->IsFullyConstrained();
+ return domain->virtual_device_->IsFullyConstrained();
} else {
// Higher-order.
return std::all_of(
@@ -391,30 +394,31 @@ bool DeviceDomains::IsFullyConstrained(DeviceDomainPtr domain) {
}
}
-void DeviceDomains::SetDefault(DeviceDomainPtr domain, const SEScope& default_se_scope) {
- ICHECK(!default_se_scope->IsFullyUnconstrained());
+void DeviceDomains::SetDefault(DeviceDomainPtr domain,
+ const VirtualDevice& default_virtual_device) {
+ ICHECK(!default_virtual_device->IsFullyUnconstrained());
domain = Lookup(domain);
if (domain->args_and_result_.empty()) {
- DeviceDomainPtr defaulted_domain_ptr =
- UnifyOrNull(domain, MakeFirstOrderDomain(config_->CanonicalSEScope(
- SEScope::Default(domain->se_scope_, default_se_scope))));
+ DeviceDomainPtr defaulted_domain_ptr = UnifyOrNull(
+ domain, MakeFirstOrderDomain(config_->CanonicalVirtualDevice(
+ VirtualDevice::Default(domain->virtual_device_, default_virtual_device))));
ICHECK_NOTNULL(defaulted_domain_ptr);
} else {
for (const auto& sub_domain : domain->args_and_result_) {
- SetDefault(sub_domain, default_se_scope);
+ SetDefault(sub_domain, default_virtual_device);
}
}
}
void DeviceDomains::SetResultDefaultThenParams(const DeviceDomainPtr& domain_maybe_higher_order,
- const SEScope& default_se_scope) {
+ const VirtualDevice& default_virtual_device) {
if (domain_maybe_higher_order->args_and_result_.empty()) {
- SetDefault(domain_maybe_higher_order, default_se_scope);
+ SetDefault(domain_maybe_higher_order, default_virtual_device);
} else {
// First set default for result domain.
- SetDefault(ResultDomain(domain_maybe_higher_order), default_se_scope);
+ SetDefault(ResultDomain(domain_maybe_higher_order), default_virtual_device);
// Then use current result domain as default for everything else.
- SetDefault(domain_maybe_higher_order, ResultSEScope(domain_maybe_higher_order));
+ SetDefault(domain_maybe_higher_order, ResultVirtualDevice(domain_maybe_higher_order));
}
}
@@ -431,11 +435,11 @@ std::string DeviceDomains::ToString(DeviceDomainPtr domain) {
std::ostringstream os;
if (domain->args_and_result_.empty()) {
// First-order.
- if (!domain->se_scope_->IsFullyConstrained()) {
+ if (!domain->virtual_device_->IsFullyConstrained()) {
os << "?" << static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get())) << "?";
}
- if (!domain->se_scope_->IsFullyUnconstrained()) {
- os << domain->se_scope_;
+ if (!domain->virtual_device_->IsFullyUnconstrained()) {
+ os << domain->virtual_device_;
}
} else {
// higher-order
diff --git a/src/relay/transforms/device_domains.h b/src/relay/transforms/device_domains.h
index 223c7d4..983ecb4 100644
--- a/src/relay/transforms/device_domains.h
+++ b/src/relay/transforms/device_domains.h
@@ -30,7 +30,7 @@
#include <tvm/relay/type.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/target/compilation_config.h>
-#include <tvm/target/se_scope.h>
+#include <tvm/target/virtual_device.h>
#include <memory>
#include <string>
@@ -51,7 +51,7 @@ class DeviceDomains;
*
* \code
* D ::= ?x? -- first order, free
- * | <se_scope> -- first order, bound to specific device and memory scope
+ * | <virtual_device> -- first order, bound to specific virtual device
* | fn(D1, ..., Dn):Dr -- higher order
* \endcode
*
@@ -59,31 +59,32 @@ class DeviceDomains;
* a notion of the 'result domain' of a domain:
* \code
* result_domain(?x?) = ?x?
- * result_domain(<se_scope>) = <se_scope>
+ * result_domain(<virtual_device>) = <virtual_device>
* result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr)
* \endcode
*
- * TODO(mbs): We currently don't allow sub-SEScope constraints. Eg for a function we can
- * express that the argument and result SEScopes must be exactly equal, but we cannot express
+ * TODO(mbs): We currently don't allow sub-VirtualDevice constraints. Eg for a function we can
+ * express that the argument and result VirtualDevices must be exactly equal, but we cannot express
* that though the devices and targets for arguments and results must be equal, it is ok for
* memory scopes to differ. At the moment we can get away with this since we run PlanDevices
* twice: once with all memory scopes unconstrained, then again with just memory scopes as
* the new property to flow. However we're on thin ice here and better would be to allow
- * constraints on SEScopes to be exploded into their device/target component and their
- * memory scope component. Should we fold layout constraints into SEScopes then they would
+ * constraints on VirtualDevices to be exploded into their device/target component and their
+ * memory scope component. Should we fold layout constraints into VirtualDevices then they would
* probably be grouped with memory scopes.
*/
class DeviceDomain {
public:
/*!
- * \brief Constructs a first-order domain for \p se_scope, which may be
- * fully free (ie se_scope is unconstrained), partially free (ie se_scope has at least on
- * of its target, device id or memory scopes known), or fully fixed (ie se_scope has its target,
- * device id and memory scopes set).
+ * \brief Constructs a first-order domain for \p virtual_device, which may be
+ * fully free (ie virtual_device is unconstrained), partially free (ie virtual_device has at
+ * least on of its target, device id or memory scopes known), or fully fixed (ie virtual_device
+ * has its target, device id and memory scopes set).
*
* CAUTION: Use DeviceDomains::MakeFirstOrderDomain instead of this ctor.
*/
- explicit DeviceDomain(SEScope se_scope) : se_scope_(std::move(se_scope)) {}
+ explicit DeviceDomain(VirtualDevice virtual_device)
+ : virtual_device_(std::move(virtual_device)) {}
/*!
* \brief Constructs a higher-order domain, where \p args_and_result contain the
@@ -92,13 +93,14 @@ class DeviceDomain {
* CAUTION: Use DeviceDomains::MakeHigherOrderDomain instead of this ctor.
*/
explicit DeviceDomain(std::vector<DeviceDomainPtr> args_and_result)
- : se_scope_(SEScope::FullyUnconstrained()), args_and_result_(std::move(args_and_result)) {}
+ : virtual_device_(VirtualDevice::FullyUnconstrained()),
+ args_and_result_(std::move(args_and_result)) {}
bool is_higher_order() const { return !args_and_result_.empty(); }
- SEScope first_order_se_scope() const {
+ VirtualDevice first_order_virtual_device() const {
ICHECK(args_and_result_.empty()) << "expecting domain to be first-order";
- return se_scope_;
+ return virtual_device_;
}
size_t function_arity() const {
@@ -124,7 +126,7 @@ class DeviceDomain {
* (for example, the \p target and \p device_type are constrained but the \p virtual_device_id and
* \p memory_scope are still unconstrained), or fully constrained (everything is known).
*/
- const SEScope se_scope_;
+ const VirtualDevice virtual_device_;
/*!
* \brief If this is a function domain then the sub-domains for each of the function's
@@ -146,10 +148,10 @@ class DeviceDomains {
const CompilationConfig& config() const { return config_; }
/*!
- * \brief Returns the domain representing \p se_scope. If \p se_scope is fully constrained
- * then the domain will be unique that \p se_scope.
+ * \brief Returns the domain representing \p virtual_device. If \p virtual_device is fully
+ * constrained then the domain will be unique to that \p virtual_device.
*/
- DeviceDomainPtr MakeFirstOrderDomain(const SEScope& se_scope);
+ DeviceDomainPtr MakeFirstOrderDomain(const VirtualDevice& virtual_device);
/*!
* \brief Returns a higher-order domain with \p args_and_results.
@@ -159,21 +161,24 @@ class DeviceDomains {
}
/*!
- * \brief Returns a domain appropriate for \p type who's result domain is bound to \p se_scope.
- * If \p type is a function then all parameter domains will be completely free. It is valid for
- * \p se_scope to be fully unconstrained.
+ * \brief Returns a domain appropriate for \p type whose result domain is bound to \p
+ * virtual_device. If \p type is a function then all parameter domains will be completely free. It
+ * is valid for \p virtual_device to be fully unconstrained.
*/
- DeviceDomainPtr MakeDomain(const Type& type, const SEScope& se_scope);
+ DeviceDomainPtr MakeDomain(const Type& type, const VirtualDevice& virtual_device);
/*!
- * \brief Returns a domain with the given result appropriate \p non_canonical_se_scope,
- * which cannot be fully unconstrained. We first canonicalize the scope to unsure it has
+ * \brief Returns a domain with the given result appropriate \p non_canonical_virtual_device,
+ * which cannot be fully unconstrained. We first canonicalize the virtual device to ensure it has
* a target and is unique.
*/
- DeviceDomainPtr ForSEScope(const Type& type, const SEScope& non_canonical_se_scope);
+ DeviceDomainPtr ForVirtualDevice(const Type& type,
+ const VirtualDevice& non_canonical_virtual_device);
/*! \brief Returns a free domain appropriate for \p type. */
- DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, SEScope::FullyUnconstrained()); }
+ DeviceDomainPtr Free(const Type& type) {
+ return MakeDomain(type, VirtualDevice::FullyUnconstrained());
+ }
/*! \brief Returns the domain representing the equivalence class containing \p domain. */
DeviceDomainPtr Lookup(DeviceDomainPtr domain);
@@ -274,16 +279,16 @@ class DeviceDomains {
/*! \brief Returns true if \p domain is fully constrainted. */
bool IsFullyConstrained(DeviceDomainPtr domain);
- /*! \brief Force all \p SEScopes in \p domain to default to \p default_se_scope. */
- void SetDefault(DeviceDomainPtr domain, const SEScope& default_se_scope);
+ /*! \brief Force all \p VirtualDevices in \p domain to default to \p default_virtual_device. */
+ void SetDefault(DeviceDomainPtr domain, const VirtualDevice& default_virtual_device);
/*!
- * \brief If \p domain is higher-order default it's result domain to \p default_se_scope.
- * Then force all remaining \p SEScopes to the result domain (freshly defaulted or original).
- * If \p domain is first-order same as \p SetDefault.
+ * \brief If \p domain is higher-order, default its result domain to \p default_virtual_device.
+ * Then force all remaining \p VirtualDevices to the result domain (freshly defaulted or
+ * original). If \p domain is first-order same as \p SetDefault.
*/
void SetResultDefaultThenParams(const DeviceDomainPtr& domain_maybe_higher_order,
- const SEScope& default_se_scope);
+ const VirtualDevice& default_virtual_device);
/*!
* \brief Returns the result domain for \p domain (see defn in DeviceDomain comment).
@@ -291,11 +296,11 @@ class DeviceDomains {
DeviceDomainPtr ResultDomain(DeviceDomainPtr domain);
/*!
- * \brief Returns the result \p SEScope (possibly unconstrained) for \p domain
+ * \brief Returns the result \p VirtualDevice (possibly unconstrained) for \p domain
* (see defn in DeviceDomain comment).
*/
- SEScope ResultSEScope(const DeviceDomainPtr& domain) {
- return ResultDomain(domain)->first_order_se_scope();
+ VirtualDevice ResultVirtualDevice(const DeviceDomainPtr& domain) {
+ return ResultDomain(domain)->first_order_virtual_device();
}
/*! \brief Returns one-line description of \p domain for debugging. */
@@ -332,16 +337,17 @@ class DeviceDomains {
std::unordered_map<DeviceDomainPtr, DeviceDomainPtr> domain_to_equiv_;
/*!
- * \brief Maps fully constrained \p SEScopes to their corresponding domains. By sharing those
- * domains we can ensure:
+ * \brief Maps fully constrained \p VirtualDevices to their corresponding domains. By sharing
+ * those domains we can ensure:
*
* \code
* domain0 != domain1 && domain0 fully constrained && domain1 fully constrained
* ==> domain0 and domain1 are incompatible
* \endcode
*/
- std::unordered_map<SEScope, DeviceDomainPtr, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
- fully_constrained_se_scope_to_domain_;
+ std::unordered_map<VirtualDevice, DeviceDomainPtr, runtime::ObjectPtrHash,
+ runtime::ObjectPtrEqual>
+ fully_constrained_virtual_device_to_domain_;
};
} // namespace transform
diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc
index bad8363..d40dd6c 100644
--- a/src/relay/transforms/device_planner.cc
+++ b/src/relay/transforms/device_planner.cc
@@ -19,14 +19,14 @@
/*!
* \file src/relay/transforms/device_planner.cc
- * \brief Determines a unique \p SEScope to hold the result of every Relay sub-expression.
+ * \brief Determines a unique \p VirtualDevice to hold the result of every Relay sub-expression.
* This pass can be run multiple times, and can be run both before and after lowering.
*
- * TODO(mbs): Rename SEScope |-> VirtualDevice, and use 'virtual device' (or just 'device')
+ * TODO(mbs): Use 'virtual device' (or just 'device')
* throughout.
*
* We say a Relay expression E is 'on device D' if the result of executing E is stored on D.
- * We represent D by an \p SEScope, which means we can track anywhere from an arbitrary device
+ * We represent D by a \p VirtualDevice, which means we can track anywhere from an arbitrary device
* of some \p DLDeviceType to a specific memory scope on a specific (virtual) \p Device who's
* code is compiled with a specific \p Target.
*
@@ -37,17 +37,17 @@
* resolve any remaining undetermined devices, and encoding the results on the output in a form
* that's reasonably friendly to downstream passes.
*
- * Specific \p SEScopes flow into the constraints from five places:
+ * Specific \p VirtualDevices flow into the constraints from five places:
* - Existing "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a
- * 'src_se_scope' and 'dst_se_scope' \p SEScope. Those constrain the argument and context of
- * the call respectively. It is ok if source and destination devices are the same, such no-op
- * copies will be removed after accounting for the device preference.
- * - Existing "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify an 'se_scope',
- * which constrains the argument of the call, but (usually, see below) leaves the context
- * unconstrained. These are called 'annotations' in the rest of the code, have no operational
- * significance by themselves, but may trigger the insertion of a new "device_copy" call by
- * this pass. In two situations the result of an "on_device" CallNode may also be constrained
- * to the given 'se_scope':
+ * 'src_virtual_device' and 'dst_virtual_device' \p VirtualDevice. Those constrain the argument
+ * and context of the call respectively. It is ok if source and destination devices are the same;
+ * such no-op copies will be removed after accounting for the device preference.
+ * - Existing "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a
+ * 'virtual_device', which constrains the argument of the call, but (usually, see below) leaves the
+ * context unconstrained. These are called 'annotations' in the rest of the code, have no
+ * operational significance by themselves, but may trigger the insertion of a new "device_copy" call
+ * by this pass. In two situations the result of an "on_device" CallNode may also be constrained to
+ * the given 'virtual_device':
* - The "on_device" call occurs at the top-level of a function body, or occurs as an
* immediately let-bound expression. In this situation the extra degree of freedom in
* the function result and let-binding leads to surprising device copies, so we simply
@@ -56,12 +56,12 @@
* it ourselves during an earlier invocation of this pass. This helps make this pass
* idempotent.
* - Some special operators require their arguments or results to be on the 'host' (typcially
- * a CPU) \p SEScope, see below.
+ * a CPU) \p VirtualDevice, see below.
* - Any \p PrimFuncs in the \p IRModule (if \p LowerTEPass has already run) may constrain their
- * argument buffers to have a specific memory scope, which is part of \p SEScope.
- * - Annotations left over from a previous run of this pass, such as 'param_se_scopes' and
- * 'result_se_scope' function attributes we introduce below. This is so the pass is idempotent
- * and can be re-run to flow additional memory scope constraints.
+ * argument buffers to have a specific memory scope, which is part of \p VirtualDevice.
+ * - Annotations left over from a previous run of this pass, such as 'param_virtual_devices' and
+ * 'result_virtual_device' function attributes we introduce below. This is so the pass is
+ * idempotent and can be re-run to flow additional memory scope constraints.
*
* We proceed in four phases:
*
@@ -114,8 +114,8 @@
*
* Phase 2
* -------
- * After flowing constraints we apply some defaulting heuristics (using a global default \p SEScope)
- * to fix the device for any as-yet unconstrained sub-expressions.
+ * After flowing constraints we apply some defaulting heuristics (using a global default \p
+ * VirtualDevice) to fix the device for any as-yet unconstrained sub-expressions.
* - Unconstrained function result devices default to the global default device.
* - Unconstrained function parameters devices default to the device for the function result.
* - Unconstrained let-bound expression devices default to the device for the overall let.
@@ -127,9 +127,9 @@
* Phase 3
* -------
* Finally, the result of this analysis is reified into the result as:
- * - Additional "param_se_scopes" (an \p Array<SEScope>) and "result_se_scope" (an \p SEScope)
- * attributes for every function (both top-level and local). These describe the devices for
- * the function's parameters and the result.
+ * - Additional "param_virtual_devices" (an \p Array<VirtualDevice>) and "result_virtual_device"
+ * (a \p VirtualDevice) attributes for every function (both top-level and local). These describe
+ * the devices for the function's parameters and the result.
* - Additional "device_copy" CallNodes where a copy is required in order to respect the
* intent of the original "on_device" CallNodes.
* - Additional "on_device" CallNodes where the device type of an expression is not trivially
@@ -155,11 +155,11 @@
* passes must preserve the lexical scoping of the "on_device" CallNodes. E.g. conversion
* to ANF must respect the lexical scoping convention:
* \code
- * f(on_device(g(h(a, b), c), se_scope=CPU))
+ * f(on_device(g(h(a, b), c), virtual_device=CPU))
* ==>
- * let %x0 = on_device(h(a, b), se_scope=CPU)
- * let %x1 = on_device(g(%x0), se_scope=CPU)
- * f(on_device(%x1, se_scope=CPU))
+ * let %x0 = on_device(h(a, b), virtual_device=CPU)
+ * let %x1 = on_device(g(%x0), virtual_device=CPU)
+ * f(on_device(%x1, virtual_device=CPU))
* \endcode
*
* This pass can be run before FuseOps so that it can use device-specific fusion rules.
@@ -188,7 +188,7 @@
* minimize cross-device calls by moving device copies out of functions. E.g.:
* \code
* def @f() { // execute on CPU
- * let x = on_device(...GPU computation..., se_scope=GPU);
+ * let x = on_device(...GPU computation..., virtual_device=GPU);
* device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU)
* }
* def @main() {
@@ -220,7 +220,7 @@
* \code
* let f = fn(x, y) { ... }
* let g = fn(f, z) { f(z, z) }
- * g(f, on_device(..., se_scope=CPU))
+ * g(f, on_device(..., virtual_device=CPU))
* \endcode
* the parameters \p x and \p y will be on the CPU.
*
@@ -298,14 +298,14 @@ namespace {
*
* - Don't let the device for %x remain unconstrained:
* \code
- * let %x = on_device(e, se_scope=d)
- * ==> let %x = on_device(e, se_scope=d, constraint=kBoth)
+ * let %x = on_device(e, virtual_device=d)
+ * ==> let %x = on_device(e, virtual_device=d, constraint=kBoth)
* \endcode
*
* - Don't let the function result remain unconstrained:
* \code
- * fn(%x) { on_device(e, se_scope=d) }
- * ==> fn(%x) { on_device(e, se_scope=d, constraint=kBoth)
+ * fn(%x) { on_device(e, virtual_device=d) }
+ * ==> fn(%x) { on_device(e, virtual_device=d, constraint=kBoth)
* \endcode
*
* - Project-then-copy rather than copy-then-project:
@@ -321,7 +321,7 @@ namespace {
* call_lowered(@prim, (a, b))
* ==> copy_ok(call_lowered(@prim, (copy_ok(a), copy_ok(b))))
* where
- * copy_ok(x) = on_device(x, se_scope=SEScope::FullyUnconstrained,
+ * copy_ok(x) = on_device(x, virtual_device=VirtualDevice::FullyUnconstrained,
* constrain_body=False, constrain_result=False)
* \endcode
*/
@@ -338,7 +338,7 @@ class RewriteOnDevices : public ExprMutator {
if (props.body.defined() && props.is_normal()) {
VLOG(2) << "wrapping tuple get item:" << std::endl
<< PrettyPrint(GetRef<TupleGetItem>(tuple_get_item_node)) << std::endl
- << "with \"on_device\" for SEScope " << props.se_scope;
+ << "with \"on_device\" for VirtualDevice " << props.virtual_device;
return OnDeviceWithProps(tuple_get_item, props);
} else {
return tuple_get_item;
@@ -355,8 +355,8 @@ class RewriteOnDevices : public ExprMutator {
if (props.body.defined() && props.is_normal()) {
VLOG(2) << "revising let-bound expression of let:" << std::endl
<< PrettyPrint(expr) << std::endl
- << "to be fixed to SEScope " << props.se_scope;
- value = MaybeOnDeviceFixed(props.body, props.se_scope);
+ << "to be fixed to VirtualDevice " << props.virtual_device;
+ value = MaybeOnDeviceFixed(props.body, props.virtual_device);
}
bindings.emplace_back(inner_let, value);
expr = inner_let_node->body;
@@ -375,8 +375,8 @@ class RewriteOnDevices : public ExprMutator {
if (props.body.defined() && props.is_normal()) {
VLOG(2) << "revising body of function:" << std::endl
<< PrettyPrint(GetRef<Function>(function_node)) << std::endl
- << "to be fixed to SEScope " << props.se_scope;
- body = MaybeOnDeviceFixed(props.body, props.se_scope);
+ << "to be fixed to VirtualDevice " << props.virtual_device;
+ body = MaybeOnDeviceFixed(props.body, props.virtual_device);
}
return WithFields(GetRef<Function>(function_node), function_node->params, std::move(body));
}
@@ -412,12 +412,12 @@ class RewriteOnDevices : public ExprMutator {
* It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter.
*
* Eg from \code add(%x, %y) \endcode we know \p %x and \p %y must be on the same device. Later,
- * from \code on_device(%x, se_scope=d) \endcode we know \p %x must be on device \p d, and thus
- * so must \p %y.
+ * from \code on_device(%x, virtual_device=d) \endcode we know \p %x must be on device \p d, and
+ * thus so must \p %y.
*
* Constraints can flow in interesting ways. E.g. in:
* \code
- * let %f = fn(%x, %y) { add(%x, on_device(%y, se_scope=d)) }
+ * let %f = fn(%x, %y) { add(%x, on_device(%y, virtual_device=d)) }
* let %g = fn(%f, %x, %y) { %f(%x, %y) }
* %g(%f, %a, %b)
* \endcode
@@ -468,21 +468,21 @@ class DeviceAnalyzer : public ExprVisitor {
ICHECK(func_type_node);
ICHECK_EQ(func_domain->function_arity(), func_type_node->arg_types.size());
- Array<SEScope> se_scopes =
+ Array<VirtualDevice> virtual_devices =
tir::GetPrimFuncArgAndResultConstraints(prim_func, GetRef<FuncType>(func_type_node));
// Build the implied domain (in terms of the function's Relay type) implied by any memory scope
// constrains in the function's buffers, for both arguments and results.
std::vector<DeviceDomainPtr> args_and_result_domains;
- args_and_result_domains.reserve(se_scopes.size());
+ args_and_result_domains.reserve(virtual_devices.size());
for (size_t i = 0; i < func_type_node->arg_types.size(); ++i) {
- const SEScope& param_se_scope = se_scopes[i];
- VLOG(2) << "param_se_scope[" << i << "] = " << param_se_scope;
- args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(param_se_scope));
+ const VirtualDevice& param_virtual_device = virtual_devices[i];
+ VLOG(2) << "param_virtual_device[" << i << "] = " << param_virtual_device;
+ args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(param_virtual_device));
}
- const SEScope& ret_se_scope = se_scopes.back();
- VLOG(2) << "ret_se_scope = " << ret_se_scope;
- args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(ret_se_scope));
+ const VirtualDevice& ret_virtual_device = virtual_devices.back();
+ VLOG(2) << "ret_virtual_device = " << ret_virtual_device;
+ args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(ret_virtual_device));
return domains_->MakeHigherOrderDomain(std::move(args_and_result_domains));
}
@@ -520,13 +520,14 @@ class DeviceAnalyzer : public ExprVisitor {
// The above must match.
if (domains_->UnifyOrNull(func_domain, implied_domain) == nullptr) { // higher-order
// TODO(mbs): Proper diagnostics.
- LOG(FATAL) << "Function parameters and result SEScopes do not match those of call. Call:"
- << std::endl
- << PrettyPrint(call) << std::endl
- << "with function virtual devices:" << std::endl
- << domains_->ToString(func_domain) << std::endl
- << "and implied call virtual devices:" << std::endl
- << domains_->ToString(implied_domain);
+ LOG(FATAL)
+ << "Function parameters and result VirtualDevices do not match those of call. Call:"
+ << std::endl
+ << PrettyPrint(call) << std::endl
+ << "with function virtual devices:" << std::endl
+ << domains_->ToString(func_domain) << std::endl
+ << "and implied call virtual devices:" << std::endl
+ << domains_->ToString(implied_domain);
}
VLOG(2) << "final call function domain:" << std::endl
@@ -584,27 +585,28 @@ class DeviceAnalyzer : public ExprVisitor {
VisitExpr(function_node->params[i]);
}
- // If the function already has SEScope attributes then we can further constrain the
+ // If the function already has VirtualDevice attributes then we can further constrain the
// function's domain to match them.
- if (!GetFunctionResultSEScope(function_node)->IsFullyUnconstrained()) {
+ if (!GetFunctionResultVirtualDevice(function_node)->IsFullyUnconstrained()) {
std::vector<DeviceDomainPtr> args_and_result;
for (size_t i = 0; i < function_node->params.size(); ++i) {
- args_and_result.emplace_back(domains_->ForSEScope(
- function_node->params[i]->checked_type(), GetFunctionParamSEScope(function_node, i)));
+ args_and_result.emplace_back(
+ domains_->ForVirtualDevice(function_node->params[i]->checked_type(),
+ GetFunctionParamVirtualDevice(function_node, i)));
}
- args_and_result.emplace_back(domains_->ForSEScope(function_node->body->checked_type(),
- GetFunctionResultSEScope(function_node)));
+ args_and_result.emplace_back(domains_->ForVirtualDevice(
+ function_node->body->checked_type(), GetFunctionResultVirtualDevice(function_node)));
auto annotation_domain = domains_->MakeHigherOrderDomain(std::move(args_and_result));
if (domains_->UnifyOrNull(func_domain, annotation_domain) == nullptr) { // higher-order
// TODO(mbs): Proper diagnostics.
- LOG(FATAL)
- << "Function SEScopes are incompatible with its \"on_device\" annotation. Function:"
- << std::endl
- << PrettyPrint(function) << std::endl
- << "with function virtual devices:" << std::endl
- << domains_->ToString(func_domain) << std::endl
- << "and annotation virtual devices:" << std::endl
- << domains_->ToString(annotation_domain);
+ LOG(FATAL) << "Function VirtualDevices are incompatible with its \"on_device\" annotation. "
+ "Function:"
+ << std::endl
+ << PrettyPrint(function) << std::endl
+ << "with function virtual devices:" << std::endl
+ << domains_->ToString(func_domain) << std::endl
+ << "and annotation virtual devices:" << std::endl
+ << domains_->ToString(annotation_domain);
}
}
@@ -783,7 +785,7 @@ class FreeOnDeviceDefaulter : public ExprVisitor {
* \code
* def @main(%x, %y, %z) {
* let %a = add(%x, %y);
- * multiply(%a, on_device(%z, se_scope=d))
+ * multiply(%a, on_device(%z, virtual_device=d))
* }
* \endcode
* we know the parameter \p %z must be on device \p d, but the devices for \p %x and \p %y,
@@ -801,7 +803,8 @@ class DeviceDefaulter : public ExprVisitor {
std::unique_ptr<DeviceDomains> Default() {
VLOG_CONTEXT << "DeviceDefaulter";
- VLOG(0) << "defaulting to SEScope " << domains_->config()->default_primitive_se_scope;
+ VLOG(0) << "defaulting to VirtualDevice "
+ << domains_->config()->default_primitive_virtual_device;
for (const auto& kv : mod_->functions) {
if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
VLOG(2) << "defaulting devices for '" << kv.first->name_hint << "'";
@@ -825,7 +828,7 @@ class DeviceDefaulter : public ExprVisitor {
if (!domains_->IsFullyConstrained(func_domain)) {
VLOG(2) << "before defaulting function:" << std::endl << domains_->ToString(func_domain);
domains_->SetResultDefaultThenParams(func_domain,
- domains_->config()->default_primitive_se_scope);
+ domains_->config()->default_primitive_virtual_device);
VLOG(2) << "after defaulting function:" << std::endl << domains_->ToString(func_domain);
}
VisitExpr(function_node->body);
@@ -845,7 +848,7 @@ class DeviceDefaulter : public ExprVisitor {
// defaulted.
VLOG(2) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain);
domains_->SetResultDefaultThenParams(func_domain,
- domains_->config()->default_primitive_se_scope);
+ domains_->config()->default_primitive_virtual_device);
VLOG(2) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain);
}
return ExprVisitor::VisitExpr_(call_node);
@@ -858,12 +861,12 @@ class DeviceDefaulter : public ExprVisitor {
Let let = Downcast<Let>(expr);
// If the let-var device is still free force it to match the overall let.
auto let_domain = domains_->DomainFor(let); // may be higher-order
- SEScope let_se_scope = domains_->ResultSEScope(let_domain);
- ICHECK(!let_se_scope->IsFullyUnconstrained());
+ VirtualDevice let_virtual_device = domains_->ResultVirtualDevice(let_domain);
+ ICHECK(!let_virtual_device->IsFullyUnconstrained());
auto let_var_domain = domains_->DomainFor(let->var); // may be higher-order
if (!domains_->IsFullyConstrained(let_var_domain)) {
VLOG(2) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain);
- domains_->SetDefault(let_var_domain, let_se_scope);
+ domains_->SetDefault(let_var_domain, let_virtual_device);
VLOG(2) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain);
}
VisitExpr(let->var);
@@ -889,7 +892,7 @@ class DeviceDefaulter : public ExprVisitor {
* - Discard any existing "on_device" CallNodes since their job is done. Similarly, discard
* any existing "device_copy" CallNodes which are no-ops.
*
- * - Functions are given "param_se_scopes" and "result_se_scope" attributes to capture
+ * - Functions are given "param_virtual_devices" and "result_virtual_device" attributes to capture
* the device type for its parameters and result.
*
* - Additional "device_copy" CallNodes are inserted wherever there's a transition between
@@ -910,10 +913,10 @@ class DeviceDefaulter : public ExprVisitor {
*
* For example, we'll end up with programs that look like:
* \code
- * def @main(%x, %y, param_se_scopes=[...], result_se_scope=...) {
- * let %a = on_device(..., se_scope=..., is_fixed=True)
- * @f(%a, device_copy(on_device(..., se_scope=..., is_fixed=True),
- * src_se_scope=..., dst_se_scope=...))
+ * def @main(%x, %y, param_virtual_devices=[...], result_virtual_device=...) {
+ * let %a = on_device(..., virtual_device=..., is_fixed=True)
+ * @f(%a, device_copy(on_device(..., virtual_device=..., is_fixed=True),
+ * src_virtual_device=..., dst_virtual_device=...))
* }
* \endcode
*/
@@ -961,19 +964,21 @@ class DeviceCapturer : public ExprMutator {
ICHECK(func_type_node);
ICHECK_EQ(func_domain->function_arity(), func_type_node->arg_types.size());
- std::vector<SEScope> arg_and_result_se_scopes;
- arg_and_result_se_scopes.reserve(func_type_node->arg_types.size() + 1);
+ std::vector<VirtualDevice> arg_and_result_virtual_devices;
+ arg_and_result_virtual_devices.reserve(func_type_node->arg_types.size() + 1);
for (size_t i = 0; i < func_type_node->arg_types.size(); ++i) {
- SEScope param_se_scope = domains_->ResultSEScope(func_domain->function_param(i));
- VLOG(2) << "param_se_scope[" << i << "] = " << param_se_scope;
- arg_and_result_se_scopes.push_back(param_se_scope);
+ VirtualDevice param_virtual_device =
+ domains_->ResultVirtualDevice(func_domain->function_param(i));
+ VLOG(2) << "param_virtual_device[" << i << "] = " << param_virtual_device;
+ arg_and_result_virtual_devices.push_back(param_virtual_device);
}
- SEScope ret_se_scope = domains_->ResultSEScope(func_domain->function_result());
- VLOG(2) << "ret_se_scope = " << ret_se_scope;
- arg_and_result_se_scopes.push_back(ret_se_scope);
+ VirtualDevice ret_virtual_device =
+ domains_->ResultVirtualDevice(func_domain->function_result());
+ VLOG(2) << "ret_virtual_device = " << ret_virtual_device;
+ arg_and_result_virtual_devices.push_back(ret_virtual_device);
return tir::ApplyPrimFuncArgAndResultConstraints(prim_func, GetRef<FuncType>(func_type_node),
- arg_and_result_se_scopes);
+ arg_and_result_virtual_devices);
}
// Nothing interesting for VarNode, ConstantNode, GlobalVarNode, OpNode and ConstructorNode
@@ -1002,26 +1007,28 @@ class DeviceCapturer : public ExprMutator {
// Gather the parameter and result device types for the function attributes.
ICHECK_EQ(func_domain->function_arity(), function_node->params.size());
- SEScope result_se_scope = domains_->ResultSEScope(func_domain);
- ICHECK(!result_se_scope->IsFullyUnconstrained());
- Array<SEScope> param_se_scopes;
- param_se_scopes.reserve(function_node->params.size());
+ VirtualDevice result_virtual_device = domains_->ResultVirtualDevice(func_domain);
+ ICHECK(!result_virtual_device->IsFullyUnconstrained());
+ Array<VirtualDevice> param_virtual_devices;
+ param_virtual_devices.reserve(function_node->params.size());
for (size_t i = 0; i < function_node->params.size(); ++i) {
- SEScope param_se_scope = domains_->ResultSEScope(func_domain->function_param(i));
- ICHECK(!param_se_scope->IsFullyUnconstrained());
- param_se_scopes.push_back(param_se_scope);
+ VirtualDevice param_virtual_device =
+ domains_->ResultVirtualDevice(func_domain->function_param(i));
+ ICHECK(!param_virtual_device->IsFullyUnconstrained());
+ param_virtual_devices.push_back(param_virtual_device);
}
// Rewrite the body. Note that the body may have begun with an "on_device" so
// be prepared to insert a "device_copy".
Expr body = VisitChild(
- /*lexical_se_scope=*/result_se_scope,
- /*expected_se_scope=*/result_se_scope,
- /*child_se_scope=*/GetSEScope(function_node->body), function_node->body);
+ /*lexical_virtual_device=*/result_virtual_device,
+ /*expected_virtual_device=*/result_virtual_device,
+ /*child_virtual_device=*/GetVirtualDevice(function_node->body), function_node->body);
Function func = WithFields(GetRef<Function>(function_node), std::move(function_node->params),
std::move(body));
- return FunctionOnDevice(func, std::move(param_se_scopes), std::move(result_se_scope));
+ return FunctionOnDevice(func, std::move(param_virtual_devices),
+ std::move(result_virtual_device));
}
Expr VisitExpr_(const CallNode* call_node) final {
@@ -1031,7 +1038,7 @@ class DeviceCapturer : public ExprMutator {
// (However we'll preserve the form in the result below.)
auto vanilla_call = GetAnyCall(call_node);
- SEScope call_se_scope = GetSEScope(call);
+ VirtualDevice call_virtual_device = GetVirtualDevice(call);
auto on_device_props = GetOnDeviceProps(call_node);
if (on_device_props.body.defined()) {
@@ -1042,17 +1049,19 @@ class DeviceCapturer : public ExprMutator {
DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node);
if (device_copy_props.body.defined()) {
- SEScope src_se_scope = domains_->config()->CanonicalSEScope(device_copy_props.src_se_scope);
- SEScope dst_se_scope = domains_->config()->CanonicalSEScope(device_copy_props.dst_se_scope);
- ICHECK_EQ(call_se_scope, dst_se_scope);
- if (src_se_scope == dst_se_scope) {
+ VirtualDevice src_virtual_device =
+ domains_->config()->CanonicalVirtualDevice(device_copy_props.src_virtual_device);
+ VirtualDevice dst_virtual_device =
+ domains_->config()->CanonicalVirtualDevice(device_copy_props.dst_virtual_device);
+ ICHECK_EQ(call_virtual_device, dst_virtual_device);
+ if (src_virtual_device == dst_virtual_device) {
// We can pinch out existing "device_copy" CallNodes if their source and destinations
// match.
return VisitExpr(device_copy_props.body);
} else {
- return VisitChild(/*lexical_se_scope=*/dst_se_scope,
- /*expected_se_scope=*/dst_se_scope,
- /*child_se_scope=*/src_se_scope, device_copy_props.body);
+ return VisitChild(/*lexical_virtual_device=*/dst_virtual_device,
+ /*expected_virtual_device=*/dst_virtual_device,
+ /*child_virtual_device=*/src_virtual_device, device_copy_props.body);
}
}
@@ -1060,16 +1069,17 @@ class DeviceCapturer : public ExprMutator {
auto func_domain = domains_->DomainForCallee(call); // higher-order
VLOG(2) << "considering call:" << std::endl
<< PrettyPrint(call) << std::endl
- << "in scope " << call_se_scope << " with function virtual devices:" << std::endl
+ << "in virtual device " << call_virtual_device
+ << " with function virtual devices:" << std::endl
<< domains_->ToString(func_domain);
- SEScope result_se_scope = domains_->ResultSEScope(func_domain);
- ICHECK(!result_se_scope->IsFullyUnconstrained());
+ VirtualDevice result_virtual_device = domains_->ResultVirtualDevice(func_domain);
+ ICHECK(!result_virtual_device->IsFullyUnconstrained());
// The callee is on the current device.
Expr op = VisitChild(
- /*lexical_se_scope=*/call_se_scope,
- /*expected_se_scope=*/call_se_scope,
- /*child_se_scope=*/result_se_scope, vanilla_call->op);
+ /*lexical_virtual_device=*/call_virtual_device,
+ /*expected_virtual_device=*/call_virtual_device,
+ /*child_virtual_device=*/result_virtual_device, vanilla_call->op);
// Each argument can be on the device for the corresponding function parameter. However if
// any of those differ from the overall call device then wrap them in an "on_device" to
@@ -1078,13 +1088,14 @@ class DeviceCapturer : public ExprMutator {
args.reserve(vanilla_call->args.size());
ICHECK_EQ(func_domain->function_arity(), vanilla_call->args.size());
for (size_t i = 0; i < vanilla_call->args.size(); ++i) {
- SEScope param_se_scope = domains_->ResultSEScope(func_domain->function_param(i));
- ICHECK(!param_se_scope->IsFullyUnconstrained())
+ VirtualDevice param_virtual_device =
+ domains_->ResultVirtualDevice(func_domain->function_param(i));
+ ICHECK(!param_virtual_device->IsFullyUnconstrained())
<< "for parameter " << i << " for call:" << std::endl
<< PrettyPrint(call);
- args.push_back(VisitChild(/*lexical_se_scope=*/call_se_scope,
- /*expected_se_scope=*/param_se_scope,
- /*child_se_scope=*/GetSEScope(vanilla_call->args[i]),
+ args.push_back(VisitChild(/*lexical_virtual_device=*/call_virtual_device,
+ /*expected_virtual_device=*/param_virtual_device,
+ /*child_virtual_device=*/GetVirtualDevice(vanilla_call->args[i]),
vanilla_call->args[i]));
}
@@ -1100,27 +1111,28 @@ class DeviceCapturer : public ExprMutator {
Expr VisitExpr_(const LetNode* let_node) final {
Expr expr = GetRef<Expr>(let_node);
// Iterate through chained lets, provided they all agree on their device type.
- SEScope let_se_scope = GetSEScope(expr);
+ VirtualDevice let_virtual_device = GetVirtualDevice(expr);
std::vector<std::tuple<Var, Expr, Span>> bindings;
while (const auto* inner_let_node = expr.as<LetNode>()) {
Expr inner_let = GetRef<Let>(inner_let_node);
- if (GetSEScope(inner_let) != let_se_scope) {
+ if (GetVirtualDevice(inner_let) != let_virtual_device) {
// We have a device transition which needs to be handled.
break;
}
// The let-bound value can be on a different device than the overall let.
- // By using the fully-unconstrained SEScope for the 'lexical' scope we'll force the let-bound
- // value to *always* be wrapped by an "on_device" (see introductory comment for motivation.)
- Expr value =
- VisitChild(/*lexical_se_scope=*/SEScope::FullyUnconstrained(),
- /*expected_se_scope=*/GetSEScope(inner_let_node->var),
- /*child_se_scope=*/GetSEScope(inner_let_node->value), inner_let_node->value);
+ // By using the fully-unconstrained virtual device for the 'lexical' scope we'll force the
+ // let-bound value to *always* be wrapped by an "on_device" (see introductory comment for
+ // motivation.)
+ Expr value = VisitChild(/*lexical_virtual_device=*/VirtualDevice::FullyUnconstrained(),
+ /*expected_virtual_device=*/GetVirtualDevice(inner_let_node->var),
+ /*child_virtual_device=*/GetVirtualDevice(inner_let_node->value),
+ inner_let_node->value);
bindings.emplace_back(inner_let_node->var, value, inner_let_node->span);
expr = inner_let_node->body;
}
- Expr body = VisitChild(/*lexical_se_scope=*/let_se_scope,
- /*expected_se_scope=*/let_se_scope,
- /*child_se_scope=*/GetSEScope(expr), expr);
+ Expr body = VisitChild(/*lexical_virtual_device=*/let_virtual_device,
+ /*expected_virtual_device=*/let_virtual_device,
+ /*child_virtual_device=*/GetVirtualDevice(expr), expr);
for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) {
body = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), body,
/*span=*/std::get<2>(*itr));
@@ -1175,68 +1187,70 @@ class DeviceCapturer : public ExprMutator {
return WithFields(std::move(match), std::move(data), std::move(clauses));
}
- SEScope GetSEScope(const Expr& expr) {
+ VirtualDevice GetVirtualDevice(const Expr& expr) {
// Look through any "on_device" CallNodes, to mimic how we will be pinching them out.
OnDeviceProps props = GetOnDeviceProps(expr);
Expr true_expr = props.body.defined() ? props.body : expr;
ICHECK(domains_->contains(true_expr));
// If expr is higher order we'll return only the result domain's device.
- SEScope se_scope = domains_->ResultSEScope(domains_->DomainFor(true_expr));
- ICHECK(!se_scope->IsFullyUnconstrained())
- << "no SEScope was determined for expression:" << std::endl
+ VirtualDevice virtual_device = domains_->ResultVirtualDevice(domains_->DomainFor(true_expr));
+ ICHECK(!virtual_device->IsFullyUnconstrained())
+ << "no VirtualDevice was determined for expression:" << std::endl
<< PrettyPrint(true_expr);
- return std::move(se_scope);
+ return std::move(virtual_device);
}
/*!
- * \brief Reconcile the \p child_se_scope for \p child with both the \p expected_se_scope
- * (as required by the expression context the \p child is in) and the \p lexical_se_scope
- * (as a downstream transform would infer based only on lexically enclosing "on_device"
- * CallNodes and function attributes.) Generally \p lexical_se_scope and \p
- * expected_se_scope are the same by definition, but may differ in arguments to functions
+ * \brief Reconcile the \p child_virtual_device for \p child with both the \p
+ * expected_virtual_device (as required by the expression context the \p child is in) and the \p
+ * lexical_virtual_device (as a downstream transform would infer based only on lexically enclosing
+ * "on_device" CallNodes and function attributes.) Generally \p lexical_virtual_device and \p
+ * expected_virtual_device are the same by definition, but may differ in arguments to functions
* and let-bound expressions.
*
- * If \p child_se_scope differs from \p expected_se_scope, wrap it as:
+ * If \p child_virtual_device differs from \p expected_virtual_device, wrap it as:
* \code
- * device_copy(on_device(child', se_scope=child_se_scope),
- * src_dev_type=child_se_scope, dst_dev_type=expected_se_scope)
+ * device_copy(on_device(child', virtual_device=child_virtual_device),
+ * src_dev_type=child_virtual_device, dst_dev_type=expected_virtual_device)
* \endcode
* (where child is rewritten to child'). Note the pedantic spelling out of "on_device" on the
* child.
*
- * If \p expected_se_scope differs from \p lexical_se_scope, then (also) wrap
+ * If \p expected_virtual_device differs from \p lexical_virtual_device, then (also) wrap
* the expression as:
* \code
- * on_device(..., se_scope=expected_se_scope)
+ * on_device(..., virtual_device=expected_virtual_device)
* \endcode
*
   * TODO(mbs): There's no attempt at sharing here. Each usage of child's node could be wrapped
* by a "device_copy", even though those copies will generally all be to the same destination
* device.
*/
- Expr VisitChild(const SEScope& lexical_se_scope, const SEScope& expected_se_scope,
- const SEScope& child_se_scope, const Expr& child) {
- ICHECK(!expected_se_scope->IsFullyUnconstrained());
+ Expr VisitChild(const VirtualDevice& lexical_virtual_device,
+ const VirtualDevice& expected_virtual_device,
+ const VirtualDevice& child_virtual_device, const Expr& child) {
+ ICHECK(!expected_virtual_device->IsFullyUnconstrained());
if (child->IsInstance<OpNode>() || child->IsInstance<ConstructorNode>()) {
      // Primitive operators and constructors don't need to be rewritten and can have a
// different domain at each call site.
return child;
}
Expr result = VisitExpr(child);
- if (child_se_scope != expected_se_scope) {
- VLOG(2) << "creating " << DeviceCopyOp()->name << " from virtual device " << child_se_scope
- << " to virtual device " << expected_se_scope << " for:" << std::endl
+ if (child_virtual_device != expected_virtual_device) {
+ VLOG(2) << "creating " << DeviceCopyOp()->name << " from virtual device "
+ << child_virtual_device << " to virtual device " << expected_virtual_device
+ << " for:" << std::endl
<< PrettyPrint(result);
// Also wrap the child in an "on_device" so downstream transforms can track devices
// lexically.
- result = MaybeOnDeviceFixed(result, child_se_scope);
- result = DeviceCopy(result, child_se_scope, expected_se_scope);
+ result = MaybeOnDeviceFixed(result, child_virtual_device);
+ result = DeviceCopy(result, child_virtual_device, expected_virtual_device);
}
- if (expected_se_scope != lexical_se_scope) {
- VLOG(2) << "creating " << OnDeviceOp()->name << " for virtual device " << expected_se_scope
- << " for:" << std::endl
+ if (expected_virtual_device != lexical_virtual_device) {
+ VLOG(2) << "creating " << OnDeviceOp()->name << " for virtual device "
+ << expected_virtual_device << " for:" << std::endl
<< PrettyPrint(result);
- result = MaybeOnDeviceFixed(result, expected_se_scope);
+ result = MaybeOnDeviceFixed(result, expected_virtual_device);
}
return result;
}
@@ -1246,9 +1260,10 @@ class DeviceCapturer : public ExprMutator {
* is expected to be on the same device as the \p parent.
*/
Expr VisitChild(const Expr& parent, const Expr& child) {
- SEScope expected_se_scope = GetSEScope(parent);
- SEScope child_se_scope = GetSEScope(child);
- return VisitChild(expected_se_scope, expected_se_scope, child_se_scope, child);
+ VirtualDevice expected_virtual_device = GetVirtualDevice(parent);
+ VirtualDevice child_virtual_device = GetVirtualDevice(child);
+ return VisitChild(expected_virtual_device, expected_virtual_device, child_virtual_device,
+ child);
}
/*! \brief Module we are rewriting, so we can lookup global variables. */
@@ -1282,7 +1297,7 @@ tvm::transform::Pass PlanDevicesCore(CompilationConfig config) {
VLOG(3) << "Domains after defaulting: " << std::endl << domains->ToString();
// Insert "device_copy" and "on_device" CallNodes where needed to unambiguously capture
- // the above map, and attach additional "param_se_scopes" and "result_se_scope"
+ // the above map, and attach additional "param_virtual_devices" and "result_virtual_device"
// attributes to all function definitions.
return DeviceCapturer(mod, std::move(domains)).Capture();
},
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 831d28b..dd81957 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -41,8 +41,8 @@ namespace transform {
namespace {
/*!
* \brief Returns whether \p expr is a literal \p Constant, optionally wrapped by an "on_device"
- * annotation CallNode (which serves only to associate an \p SEScope to the constant and has no
- * operational effect).
+ * annotation CallNode (which serves only to associate a \p VirtualDevice with the constant and has
+ * no operational effect).
*/
bool IsSimpleConstant(const Expr& expr) {
return AsIgnoringOnDevice<ConstantNode>(expr) != nullptr;
@@ -86,19 +86,19 @@ class ConstantFolder : public MixedModeMutator {
// the variable.
//
// We need to retain any "on_device" annotation so that downstream 'device aware'
- // passes can still retrieve the \p SEScope for the constant in its new position(s). Eg:
- // def @f(..., result_se_scope=D) {
- // let %x = on_device(... something we eval to a constant..., se_scope=E)
+ // passes can still retrieve the virtual device for the constant in its new position(s). Eg:
+ // def @f(..., result_virtual_device=D) {
+ // let %x = on_device(... something we eval to a constant..., virtual_device=E)
// @f(..., %x, ...)
// }
- // Here the default scope is D, whereas the argument %x to @f is on E (and @f expects
- // that). No on_device annotation is required in the call according to the convention used
- // by the device-aware visitors.
+ // Here the default virtual device is D, whereas the argument %x to @f is on E (and @f
+ // expects that). No on_device annotation is required in the call according to the
+ // convention used by the device-aware visitors.
//
// However once we've inlined the constant we need to insert an on_device, again to
// respect the convention used by the device-aware visitors.
- // def @f(..., result_se_scope=D) {
- // @f(..., on_device(...the constant..., se_scope=E), ...)
+ // def @f(..., result_virtual_device=D) {
+ // @f(..., on_device(...the constant..., virtual_device=E), ...)
// }
VLOG(1) << "Replacing let-binding for " << op->var->name_hint()
<< " with constant:" << std::endl
@@ -214,7 +214,7 @@ class ConstantFolder : public MixedModeMutator {
Expr result = tuple_node->fields[tuple_get_item_node->index];
OnDeviceProps props = GetOnDeviceProps(post_tuple_get_item_node->tuple);
if (props.body.defined()) {
- // (on_device((x, y, z), se_scope=D).1 ==> on_device(y, se_scope=D)
+ // (on_device((x, y, z), virtual_device=D).1 ==> on_device(y, virtual_device=D)
return MaybeOnDeviceWithProps(result, props);
} else {
return result;
diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc
index 25827d5..74da99b 100644
--- a/src/relay/transforms/memory_alloc.cc
+++ b/src/relay/transforms/memory_alloc.cc
@@ -61,10 +61,10 @@ namespace relay {
class DialectRewriter : public transform::DeviceAwareExprMutator {
public:
- DialectRewriter(IRModule mod, SEScope host_se_scope)
+ DialectRewriter(IRModule mod, VirtualDevice host_virtual_device)
: transform::DeviceAwareExprMutator(mod),
mod_(std::move(mod)),
- host_se_scope_(std::move(host_se_scope)) {}
+ host_virtual_device_(std::move(host_virtual_device)) {}
Function Rewrite(const Function& expr) { return Downcast<Function>(Mutate(expr)); }
@@ -79,10 +79,10 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
for (auto field : tuple_node->fields) {
auto new_field = Mutate(field);
if (new_field->IsInstance<ConstantNode>()) {
- SEScope se_scope = GetSEScope(field);
- ICHECK(!se_scope->IsFullyUnconstrained());
+ VirtualDevice virtual_device = GetVirtualDevice(field);
+ ICHECK(!virtual_device->IsFullyUnconstrained());
Var const_var("const", Type(nullptr));
- new_field = scope.Push(const_var, MaybeOnDeviceFixed(new_field, se_scope));
+ new_field = scope.Push(const_var, MaybeOnDeviceFixed(new_field, virtual_device));
}
new_fields.push_back(new_field);
}
@@ -93,9 +93,9 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) final {
Expr new_value = Mutate(value);
- SEScope se_scope = GetSEScope(value);
- ICHECK(!se_scope->IsFullyUnconstrained());
- scopes_.back().Push(var, MaybeOnDeviceFixed(new_value, se_scope));
+ VirtualDevice virtual_device = GetVirtualDevice(value);
+ ICHECK(!virtual_device->IsFullyUnconstrained());
+ scopes_.back().Push(var, MaybeOnDeviceFixed(new_value, virtual_device));
// Since we always need a let block on which to bind sub-expressions the rewritten bindings
// are tracked in the current scopes. But return the rewritten binding anyway.
return {var, new_value};
@@ -132,8 +132,8 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
Call call = GetRef<Call>(call_node);
VLOG(1) << "converting lowered call to DPS:" << std::endl << PrettyPrint(call);
- SEScope se_scope = GetSEScope(call);
- ICHECK(!se_scope->IsFullyUnconstrained());
+ VirtualDevice virtual_device = GetVirtualDevice(call);
+ ICHECK(!virtual_device->IsFullyUnconstrained());
LetList& scope = scopes_.back();
std::vector<Expr> new_args;
@@ -171,19 +171,20 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
// by a companion shape function.
if (IsDynamic(ret_type)) {
return DynamicInvoke(&scope, call_lowered_props.lowered_func, ins, call_lowered_props.attrs,
- out_types, ret_type, se_scope);
+ out_types, ret_type, virtual_device);
}
// Handle ordinary primitive calls.
Array<Expr> outputs;
for (size_t i = 0; i < out_types.size(); ++i) {
- outputs.push_back(MakeStaticAllocation(&scope, out_types[i], se_scope, std::to_string(i)));
+ outputs.push_back(
+ MakeStaticAllocation(&scope, out_types[i], virtual_device, std::to_string(i)));
}
Tuple outs(outputs);
Expr invoke =
InvokeTVMOp(call_lowered_props.lowered_func, ins, outs,
Downcast<DictAttrs>(call_lowered_props.attrs.metadata.at("relay_attrs")));
- scope.Push(MaybeOnDeviceFixed(invoke, se_scope));
+ scope.Push(MaybeOnDeviceFixed(invoke, virtual_device));
return ToTupleType(ret_type, std::vector<Expr>(outputs.begin(), outputs.end()));
}
@@ -199,7 +200,8 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
/*! Returns an \p alloc_tensor call for a tensor of \p shape and \p dtype over \p storage. */
inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype,
Array<IndexExpr> assert_shape) {
- Expr offset = MaybeOnDeviceFixed(MakeConstantScalar(DataType::Int(64), 0), host_se_scope_);
+ Expr offset =
+ MaybeOnDeviceFixed(MakeConstantScalar(DataType::Int(64), 0), host_virtual_device_);
return tvm::relay::AllocTensor(storage, std::move(offset), std::move(shape), dtype,
assert_shape);
}
@@ -234,28 +236,28 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
}
// Allocate a tensor with a statically known shape.
- Var MakeStaticAllocation(LetList* scope, const TensorType& type, const SEScope& se_scope,
- String name_hint) {
+ Var MakeStaticAllocation(LetList* scope, const TensorType& type,
+ const VirtualDevice& virtual_device, String name_hint) {
std::vector<int64_t> int_shape;
for (auto it : type->shape) {
const auto* imm = it.as<IntImmNode>();
CHECK(imm) << "expect static int shape";
int_shape.push_back(imm->value);
}
- Expr shape = MaybeOnDeviceFixed(MakeConstant(int_shape), host_se_scope_);
- Expr size = MaybeOnDeviceFixed(ComputeStorage(type), host_se_scope_);
+ Expr shape = MaybeOnDeviceFixed(MakeConstant(int_shape), host_virtual_device_);
+ Expr size = MaybeOnDeviceFixed(ComputeStorage(type), host_virtual_device_);
// Alignment is directly captured in the instruction rather than calculated, so we
// don't want to wrap it with an "on_device".
Expr alignment = ComputeAlignment(type->dtype);
// Run type inference later to get the correct type.
Var var("storage_" + name_hint, Type(nullptr));
- Expr value = AllocStorage(size, alignment, se_scope, type->dtype);
- auto sto = scope->Push(var, MaybeOnDeviceFixed(value, se_scope));
+ Expr value = AllocStorage(size, alignment, virtual_device, type->dtype);
+ auto sto = scope->Push(var, MaybeOnDeviceFixed(value, virtual_device));
// TODO(@jroesch): There is a bug with typing based on the constant shape.
auto tensor = AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape);
Var tensor_var("tensor_" + name_hint, Type(nullptr));
- return scope->Push(tensor_var, MaybeOnDeviceFixed(tensor, se_scope));
+ return scope->Push(tensor_var, MaybeOnDeviceFixed(tensor, virtual_device));
}
/*!
@@ -294,21 +296,21 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
Expr sh_of = Mutate(ShapeOf(exprs[j]));
Var in_shape_var("in_shape_" + std::to_string(input_pos + j), Type(nullptr));
shape_func_ins.push_back(
- scope->Push(in_shape_var, MaybeOnDeviceFixed(sh_of, host_se_scope_)));
+ scope->Push(in_shape_var, MaybeOnDeviceFixed(sh_of, host_virtual_device_)));
input_pos++;
}
} else if (state == tec::kNeedInputData) {
auto new_arg = Mutate(arg); // already accounts for device
- SEScope arg_se_scope = GetSEScope(arg);
- ICHECK(!arg_se_scope->IsFullyUnconstrained());
+ VirtualDevice arg_virtual_device = GetVirtualDevice(arg);
+ ICHECK(!arg_virtual_device->IsFullyUnconstrained());
// The dynamic shape function is expecting its data on the host/CPU, so insert a
// device_copy otherwise. (We'll need to fuse & lower these copies in the same way
// we fuse & lower other operators we insert for, eg, dynamic tensor size calculation.)
- new_arg = MaybeDeviceCopy(MaybeOnDeviceFixed(new_arg, arg_se_scope), arg_se_scope,
- host_se_scope_);
+ new_arg = MaybeDeviceCopy(MaybeOnDeviceFixed(new_arg, arg_virtual_device),
+ arg_virtual_device, host_virtual_device_);
Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr));
shape_func_ins.push_back(
- scope->Push(in_shape_var, MaybeOnDeviceFixed(new_arg, host_se_scope_)));
+ scope->Push(in_shape_var, MaybeOnDeviceFixed(new_arg, host_virtual_device_)));
input_pos++;
} else {
// TODO(@jroesch): handle kNeedBoth
@@ -327,8 +329,8 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
ICHECK(tensor_type_node);
// Put the shape func on the host. This also ensures that everything between
// shape_of and shape_func is similarly on the host.
- Var alloc = MakeStaticAllocation(scope, GetRef<TensorType>(tensor_type_node), host_se_scope_,
- "out_shape_" + std::to_string(i));
+ Var alloc = MakeStaticAllocation(scope, GetRef<TensorType>(tensor_type_node),
+ host_virtual_device_, "out_shape_" + std::to_string(i));
out_shapes.push_back(alloc);
}
@@ -336,26 +338,27 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
auto shape_call = InvokeTVMOp(prim_fn_var, Tuple(shape_func_ins), Tuple(out_shapes),
Downcast<DictAttrs>(attrs.metadata.at("relay_attrs")));
Var shape_func_var("shape_func", Type(nullptr));
- scope->Push(shape_func_var, MaybeOnDeviceFixed(shape_call, host_se_scope_));
+ scope->Push(shape_func_var, MaybeOnDeviceFixed(shape_call, host_virtual_device_));
return out_shapes;
}
// Generate the code for invoking the TVM primitive \p func who's results have dynamic shapes.
Expr DynamicInvoke(LetList* scope, const Expr& func, const Tuple& ins,
const CallLoweredAttrs& attrs, const std::vector<TensorType>& out_types,
- const Type& ret_type, const SEScope& se_scope) {
+ const Type& ret_type, const VirtualDevice& virtual_device) {
Array<Expr> out_shapes = EmitShapeFunc(scope, ins, attrs);
std::vector<Var> storages;
CHECK_EQ(out_shapes.size(), out_types.size());
for (size_t i = 0; i < out_shapes.size(); ++i) {
auto out_shape = out_shapes[i];
auto out_type = out_types[i];
- auto size = MaybeOnDeviceFixed(ComputeStorageInRelay(out_shape, out_type), host_se_scope_);
+ auto size =
+ MaybeOnDeviceFixed(ComputeStorageInRelay(out_shape, out_type), host_virtual_device_);
// Alignment is directly captured in the instruction so don't wrap in "on_device".
auto alignment = ComputeAlignment(out_type->dtype);
Var sto_var("storage_" + std::to_string(i), Type(nullptr));
- auto val = AllocStorage(size, alignment, se_scope, out_type->dtype);
- storages.push_back(scope->Push(sto_var, MaybeOnDeviceFixed(val, se_scope)));
+ auto val = AllocStorage(size, alignment, virtual_device, out_type->dtype);
+ storages.push_back(scope->Push(sto_var, MaybeOnDeviceFixed(val, virtual_device)));
}
Array<Expr> outs;
@@ -365,13 +368,13 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
auto storage = storages[i];
auto alloc = AllocTensor(storage, out_shape, out_type->dtype, out_type->shape);
Var out_var("out_" + std::to_string(i), Type(nullptr));
- outs.push_back(scope->Push(out_var, MaybeOnDeviceFixed(alloc, se_scope)));
+ outs.push_back(scope->Push(out_var, MaybeOnDeviceFixed(alloc, virtual_device)));
}
Tuple tuple_outs(outs);
auto call =
InvokeTVMOp(func, ins, tuple_outs, Downcast<DictAttrs>(attrs.metadata.at("relay_attrs")));
- scope->Push(MaybeOnDeviceFixed(call, se_scope));
+ scope->Push(MaybeOnDeviceFixed(call, virtual_device));
return ToTupleType(ret_type,
std::vector<Expr>(tuple_outs->fields.begin(), tuple_outs->fields.end()));
}
@@ -395,7 +398,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
CHECK(imm) << "expect static int shape";
shape.push_back(imm->value);
}
- shape_expr = MaybeOnDeviceFixed(MakeConstant(shape), host_se_scope_);
+ shape_expr = MaybeOnDeviceFixed(MakeConstant(shape), host_virtual_device_);
}
return ReshapeTensor(ins->fields[0], shape_expr, ret_ty->shape);
}
@@ -404,7 +407,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
const Op& device_copy_op_ = Op::Get("device_copy");
runtime::DataType compute_dtype_ = runtime::DataType::Int(64);
IRModule mod_;
- SEScope host_se_scope_;
+ VirtualDevice host_virtual_device_;
std::vector<LetList> scopes_;
};
@@ -421,16 +424,16 @@ Pass ManifestAllocImportStorage() {
/*required=*/{});
}
-Pass ManifestAllocImpl(SEScope host_se_scope) {
- auto pass_func = [host_se_scope](Function func, IRModule mod, PassContext ctxt) {
- return DialectRewriter(mod, host_se_scope).Rewrite(func);
+Pass ManifestAllocImpl(VirtualDevice host_virtual_device) {
+ auto pass_func = [host_virtual_device](Function func, IRModule mod, PassContext ctxt) {
+ return DialectRewriter(mod, host_virtual_device).Rewrite(func);
};
return CreateFunctionPass(pass_func, 0, "ManifestAllocImpl", {});
}
-Pass ManifestAlloc(SEScope host_se_scope) {
+Pass ManifestAlloc(VirtualDevice cpu_virtual_device) {
std::vector<Pass> passes = {ManifestAllocImportStorage(), InferType(),
- ManifestAllocImpl(std::move(host_se_scope)), InferType()};
+ ManifestAllocImpl(std::move(cpu_virtual_device)), InferType()};
return Sequential(passes, "ManifestAlloc");
}
diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc
index 741de6d..321839d 100644
--- a/src/relay/transforms/to_a_normal_form.cc
+++ b/src/relay/transforms/to_a_normal_form.cc
@@ -211,14 +211,14 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)>, private transform::Lexi
}
Expr Atomic(const Expr& e, const Var& v) {
- Expr annotated_expr = MaybeOnDeviceFixed(e, GetSEScope(e));
+ Expr annotated_expr = MaybeOnDeviceFixed(e, GetVirtualDevice(e));
return v.defined() ? GetScope(e)->let_list->Push(v, annotated_expr) : annotated_expr;
}
// Bind expression `now` to var `v` if the original expression is in the include set, or if
// v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly
Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
- Expr annotated_expr = MaybeOnDeviceFixed(now, GetSEScope(orig));
+ Expr annotated_expr = MaybeOnDeviceFixed(now, GetVirtualDevice(orig));
Var var = v.defined() ? v : Var::GenSym();
bool not_included = include_set_ && include_set_->find(orig) == include_set_->end();
if (!v.defined() && not_included) {
@@ -232,10 +232,10 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)>, private transform::Lexi
OnDeviceProps props = GetOnDeviceProps(c);
if (props.body.defined() && props.is_fixed()) {
// Keep track of expression device type for lexically enclosing sub-expressions.
- PushSEScope(props.se_scope);
+ PushVirtualDevice(props.virtual_device);
Expr body = VisitExpr(props.body, v);
// We are done with this sub-expression.
- PopSEScope();
+ PopVirtualDevice();
// Preserve the "on_device" annotations.
return OnDeviceWithProps(body, props);
}
@@ -293,9 +293,9 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)>, private transform::Lexi
} else {
// Keep track of expression and bound variable device types for lexically enclosing
// sub-expressions.
- PushSEScope(GetFunctionResultSEScope(f));
+ PushVirtualDevice(GetFunctionResultVirtualDevice(f));
for (size_t i = 0; i < f->params.size(); ++i) {
- PushBoundVar(f->params[i], GetFunctionParamSEScope(f, i));
+ PushBoundVar(f->params[i], GetFunctionParamVirtualDevice(f, i));
}
EnterFunctionBody();
ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type,
@@ -305,7 +305,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)>, private transform::Lexi
for (size_t i = 0; i < f->params.size(); ++i) {
PopBoundVar(f->params[i]);
}
- PopSEScope();
+ PopVirtualDevice();
}
if (function_nesting() == 0) {
ICHECK(!v.defined());
@@ -320,7 +320,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)>, private transform::Lexi
Expr VisitExpr_(const LetNode* l, const Var& v) final {
Expr e = GetRef<Expr>(l);
// Keep track of bound variable device types for lexically enclosing sub-expressions.
- PushBoundVar(l->var, GetSEScope(l->value));
+ PushBoundVar(l->var, GetVirtualDevice(l->value));
VisitExpr(l->value, l->var);
Expr ret = GetSubScope(e, 0)->let_list->Get(VisitExpr(l->body));
// We are done with these sub-expressions.
diff --git a/src/target/compilation_config.cc b/src/target/compilation_config.cc
index e9c4daf..0401eeb 100644
--- a/src/target/compilation_config.cc
+++ b/src/target/compilation_config.cc
@@ -33,26 +33,27 @@ void CompilationConfigNode::VisitAttrs(AttrVisitor* v) {
v->Visit("legacy_target_map", &legacy_target_map);
v->Visit("host_target", &host_target);
v->Visit("primitive_targets", &primitive_targets);
- v->Visit("default_primitive_se_scope", &default_primitive_se_scope);
- v->Visit("host_se_scope", &host_se_scope);
+ v->Visit("default_primitive_virtual_device", &default_primitive_virtual_device);
+ v->Visit("host_virtual_device", &host_virtual_device);
v->Visit("optional_homogenous_target", &optional_homogeneous_target);
- // NOTE: The se_scope_cache_ is not accessible via FFI.
+ // NOTE: The virtual_device_cache_ is not accessible via FFI.
}
-SEScope CompilationConfigNode::CanonicalSEScope(const SEScope& se_scope) const {
- if (se_scope->target.defined()) {
- return se_scope_cache_.Unique(se_scope);
+VirtualDevice CompilationConfigNode::CanonicalVirtualDevice(
+ const VirtualDevice& virtual_device) const {
+ if (virtual_device->target.defined()) {
+ return virtual_device_cache_.Unique(virtual_device);
}
- DLDeviceType device_type = se_scope->device_type();
+ DLDeviceType device_type = virtual_device->device_type();
// TODO(mbs): Proper diagnostics.
CHECK(device_type != kInvalidDeviceType)
- << "SEScope annotations must include at least a device_type";
- Target target = FindPrimitiveTargetOrFail(se_scope->device_type());
- return se_scope_cache_.Unique(
- SEScope(device_type, se_scope->virtual_device_id, target, se_scope->memory_scope));
+ << "VirtualDevice annotations must include at least a device_type";
+ Target target = FindPrimitiveTargetOrFail(virtual_device->device_type());
+ return virtual_device_cache_.Unique(VirtualDevice(device_type, virtual_device->virtual_device_id,
+ target, virtual_device->memory_scope));
}
-void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContext& pass_ctx) {
+void CompilationConfigNode::EstablishDefaultVirtualDevices(const transform::PassContext& pass_ctx) {
//
// Gather the hints as to what our default device type for the 'host' should be, and
// create an appropriate target if we don't already have one.
@@ -105,9 +106,10 @@ void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContex
}
//
- // Establish the host SEScope.
+ // Establish the host VirtualDevice.
//
- host_se_scope = se_scope_cache_.Unique(SEScope(host_device_type,
+ host_virtual_device =
+ virtual_device_cache_.Unique(VirtualDevice(host_device_type,
/*virtual_device_id=*/0, host_target));
//
@@ -149,11 +151,12 @@ void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContex
}
//
- // Establish the default primitive SEScope, choosing a known Target to match the device type.
+ // Establish the default primitive VirtualDevice, choosing a known Target to match the device
+ // type.
//
- default_primitive_se_scope = se_scope_cache_.Unique(
- SEScope(default_primitive_device_type,
- /*virtual_device_id=*/0, FindPrimitiveTargetOrFail(default_primitive_device_type)));
+ default_primitive_virtual_device = virtual_device_cache_.Unique(VirtualDevice(
+ default_primitive_device_type,
+ /*virtual_device_id=*/0, FindPrimitiveTargetOrFail(default_primitive_device_type)));
}
/* static */ Target CompilationConfigNode::MakeDefaultTarget(DLDeviceType device_type) {
@@ -205,7 +208,7 @@ CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx,
// Complete the targets vector and establish default scopes. After this primitive_targets will
// contain the definitive list of all required targets, target_host will be defined, and
// all primitive targets will have host target_host.
- node->EstablishDefaultSEScopes(pass_ctx);
+ node->EstablishDefaultVirtualDevices(pass_ctx);
// LEGACY: Reconstruct the target map from all the primitive targets.
// Note that we require pointer equality between targets in legacy_target_map and
@@ -214,8 +217,8 @@ CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx,
node->legacy_target_map.Set(Integer(primitive_target->kind->device_type), primitive_target);
}
- ICHECK(node->default_primitive_se_scope->target.defined());
- ICHECK(node->host_se_scope->target.defined());
+ ICHECK(node->default_primitive_virtual_device->target.defined());
+ ICHECK(node->host_virtual_device->target.defined());
ICHECK_GT(node->primitive_targets.size(), 0U);
// Legacy: Some passes only support homogenous compilation and expect the target to be
@@ -227,8 +230,8 @@ CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx,
DLOG(INFO) << "Target " << target->ToDebugString() << " of device type "
<< target->kind->device_type << " is available for primitives";
}
- DLOG(INFO) << "Using default primitive scope " << node->default_primitive_se_scope;
- DLOG(INFO) << "Using host scope " << node->host_se_scope;
+ DLOG(INFO) << "Using default primitive virtual device " << node->default_primitive_virtual_device;
+ DLOG(INFO) << "Using host virtual device " << node->host_virtual_device;
data_ = std::move(node);
}
diff --git a/src/target/se_scope.cc b/src/target/virtual_device.cc
similarity index 71%
rename from src/target/se_scope.cc
rename to src/target/virtual_device.cc
index 8e6c6fe..cde58d3 100644
--- a/src/target/se_scope.cc
+++ b/src/target/virtual_device.cc
@@ -18,21 +18,22 @@
*/
/*!
- * \file tvm/target/se_scope.cc
- * \brief Implementation of \p SEScope for representing a Storage or Execution scope.
+ * \file tvm/target/virtual_device.cc
+ * \brief A compile time representation for where data is to be stored at runtime, and how to
+ * compile code to compute it.
*/
#include <tvm/node/reflection.h>
#include <tvm/runtime/device_api.h>
-#include <tvm/target/se_scope.h>
+#include <tvm/target/virtual_device.h>
namespace tvm {
-TVM_REGISTER_NODE_TYPE(SEScopeNode);
+TVM_REGISTER_NODE_TYPE(VirtualDeviceNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<SEScopeNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = ref.as<SEScopeNode>();
- p->stream << "SEScope(";
+ .set_dispatch<VirtualDeviceNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = ref.as<VirtualDeviceNode>();
+ p->stream << "VirtualDevice(";
if (node->IsFullyUnconstrained()) {
p->stream << "?";
} else {
@@ -65,12 +66,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ")";
});
-SEScope::SEScope(DLDeviceType device_type, int virtual_device_id, Target target,
- MemoryScope memory_scope) {
+VirtualDevice::VirtualDevice(DLDeviceType device_type, int virtual_device_id, Target target,
+ MemoryScope memory_scope) {
ICHECK(!target.defined() || device_type == target->kind->device_type)
<< "target " << target->ToDebugString() << " has device type " << target->kind->device_type
- << " but scope has device type " << device_type;
- auto node = make_object<SEScopeNode>();
+ << " but virtual device has device type " << device_type;
+ auto node = make_object<VirtualDeviceNode>();
node->device_type_int = device_type;
node->virtual_device_id = virtual_device_id;
node->target = std::move(target);
@@ -78,13 +79,13 @@ SEScope::SEScope(DLDeviceType device_type, int virtual_device_id, Target target,
data_ = std::move(node);
}
-/* static */ SEScope SEScope::FullyUnconstrained() {
- static const SEScope unconstrained{};
+/* static */ VirtualDevice VirtualDevice::FullyUnconstrained() {
+ static const VirtualDevice unconstrained{};
return unconstrained;
}
/* static */
-Optional<SEScope> SEScope::Join(const SEScope& lhs, const SEScope& rhs) {
+Optional<VirtualDevice> VirtualDevice::Join(const VirtualDevice& lhs, const VirtualDevice& rhs) {
if (lhs == rhs) {
return lhs;
}
@@ -124,11 +125,12 @@ Optional<SEScope> SEScope::Join(const SEScope& lhs, const SEScope& rhs) {
} else {
joined_memory_scope = rhs->memory_scope;
}
- return SEScope(joined_device_type, joined_virtual_device_id, joined_target, joined_memory_scope);
+ return VirtualDevice(joined_device_type, joined_virtual_device_id, joined_target,
+ joined_memory_scope);
}
/* static */
-SEScope SEScope::Default(const SEScope& lhs, const SEScope& rhs) {
+VirtualDevice VirtualDevice::Default(const VirtualDevice& lhs, const VirtualDevice& rhs) {
if (lhs == rhs) {
return lhs;
}
@@ -160,13 +162,14 @@ SEScope SEScope::Default(const SEScope& lhs, const SEScope& rhs) {
} else {
defaulted_memory_scope = rhs->memory_scope;
}
- return SEScope(defaulted_device_type, defaulted_virtual_device_id, defaulted_target,
- defaulted_memory_scope);
+ return VirtualDevice(defaulted_device_type, defaulted_virtual_device_id, defaulted_target,
+ defaulted_memory_scope);
}
-SEScope SEScopeCache::Make(DLDeviceType device_type, int virtual_device_id, Target target,
- MemoryScope memory_scope) {
- SEScope prototype(device_type, virtual_device_id, std::move(target), std::move(memory_scope));
+VirtualDevice VirtualDeviceCache::Make(DLDeviceType device_type, int virtual_device_id,
+ Target target, MemoryScope memory_scope) {
+ VirtualDevice prototype(device_type, virtual_device_id, std::move(target),
+ std::move(memory_scope));
auto itr = cache_.find(prototype);
if (itr == cache_.end()) {
cache_.emplace(prototype);
@@ -180,11 +183,12 @@ SEScope SEScopeCache::Make(DLDeviceType device_type, int virtual_device_id, Targ
}
}
-SEScope SEScopeCache::Unique(const SEScope& scope) {
- return Make(scope->device_type(), scope->virtual_device_id, scope->target, scope->memory_scope);
+VirtualDevice VirtualDeviceCache::Unique(const VirtualDevice& virtual_device) {
+ return Make(virtual_device->device_type(), virtual_device->virtual_device_id,
+ virtual_device->target, virtual_device->memory_scope);
}
-TVM_REGISTER_GLOBAL("target.SEScope_ForDeviceTargetAndMemoryScope")
- .set_body_typed(SEScope::ForDeviceTargetAndMemoryScope);
+TVM_REGISTER_GLOBAL("target.VirtualDevice_ForDeviceTargetAndMemoryScope")
+ .set_body_typed(VirtualDevice::ForDeviceTargetAndMemoryScope);
} // namespace tvm
diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc
index 8412cb8..ae2b950 100644
--- a/src/tir/analysis/device_constraint_utils.cc
+++ b/src/tir/analysis/device_constraint_utils.cc
@@ -32,7 +32,7 @@
#include "./device_constraint_utils.h"
#include <tvm/relay/attrs/memory.h>
-#include <tvm/target/se_scope.h>
+#include <tvm/target/virtual_device.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
@@ -104,11 +104,11 @@ void CheckNoRemainingPointerParams(const tir::PrimFunc& prim_func,
* using \p prim_func parameters at or after \p *current_primfunc_param_index. Currently
* only memory scope is extracted. Fails if constraints are not consistent, ie \p type is a tuple
* type and the \p prim_func is attempting to map different fields of that tuple to different memory
- * scopes. Returns the fully unconstrained \p SEScope if no memory scopes constraints arise from
- * the \p prim_func, ie all storage scope strings in pointer types are empty.
+ * scopes. Returns the fully unconstrained \p VirtualDevice if no memory scopes constraints arise
+ * from the \p prim_func, ie all storage scope strings in pointer types are empty.
*/
-SEScope ConsistentParamConstraint(const tir::PrimFunc& prim_func, const Type& type,
- size_t* current_primfunc_param_index) {
+VirtualDevice ConsistentParamConstraint(const tir::PrimFunc& prim_func, const Type& type,
+ size_t* current_primfunc_param_index) {
std::string memory_scope; // default empty => no constraint
for (size_t i = 0; i < relay::FlattenTupleType(type).size(); ++i) {
std::pair<tir::Var, tir::Buffer> kv = FindPointerParam(prim_func, current_primfunc_param_index);
@@ -120,25 +120,26 @@ SEScope ConsistentParamConstraint(const tir::PrimFunc& prim_func, const Type& ty
} else if (buffer_memory_scope.empty()) {
// No constraint.
} else {
- // Tuples must be homogenous on their SEScope and thus memory scope.
+ // Tuples must be homogeneous on their VirtualDevice and thus memory scope.
ICHECK_EQ(buffer_memory_scope, memory_scope);
}
++*current_primfunc_param_index;
}
- return SEScope::ForMemoryScope(memory_scope);
+ return VirtualDevice::ForMemoryScope(memory_scope);
}
/*!
* \brief Insert into param_constraints an entry for each parameter of \p prim_func starting from
* \p *current_primfunc_param_index for the flattened form of the Relay parameters of \p type. Each
- * entry maps to \p se_scope.
+ * entry maps to \p virtual_device.
*/
-void InsertParamConstraints(const tir::PrimFunc& prim_func, const Type& type,
- const SEScope& se_scope, size_t* current_primfunc_param_index,
- std::unordered_map<const tir::VarNode*, SEScope>* param_constraints) {
+void InsertParamConstraints(
+ const tir::PrimFunc& prim_func, const Type& type, const VirtualDevice& virtual_device,
+ size_t* current_primfunc_param_index,
+ std::unordered_map<const tir::VarNode*, VirtualDevice>* param_constraints) {
for (size_t i = 0; i < relay::FlattenTupleType(type).size(); ++i) {
std::pair<tir::Var, tir::Buffer> kv = FindPointerParam(prim_func, current_primfunc_param_index);
- param_constraints->emplace(kv.first.get(), se_scope);
+ param_constraints->emplace(kv.first.get(), virtual_device);
++*current_primfunc_param_index;
}
}
@@ -186,22 +187,22 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
* memory scopes needed to change.
*/
PrimFunc Rewrite(const PrimFunc& prim_func, const FuncType& relay_func_type,
- const Array<SEScope>& arg_and_result_se_scopes) {
+ const Array<VirtualDevice>& arg_and_result_virtual_devices) {
size_t current_primfunc_param_index = 0;
- std::unordered_map<const tir::VarNode*, SEScope> param_constraints;
+ std::unordered_map<const tir::VarNode*, VirtualDevice> param_constraints;
// For each Relay function parameter...
for (size_t i = 0; i < relay_func_type->arg_types.size(); ++i) {
const Type& param_type = relay_func_type->arg_types[i];
- const SEScope& param_se_scope = arg_and_result_se_scopes[i];
- InsertParamConstraints(prim_func, param_type, param_se_scope, &current_primfunc_param_index,
- &param_constraints);
+ const VirtualDevice& param_virtual_device = arg_and_result_virtual_devices[i];
+ InsertParamConstraints(prim_func, param_type, param_virtual_device,
+ &current_primfunc_param_index, &param_constraints);
}
// For the Relay function result...
const Type& ret_type = relay_func_type->ret_type;
- const SEScope& ret_se_scope = arg_and_result_se_scopes.back();
- InsertParamConstraints(prim_func, ret_type, ret_se_scope, &current_primfunc_param_index,
+ const VirtualDevice& ret_virtual_device = arg_and_result_virtual_devices.back();
+ InsertParamConstraints(prim_func, ret_type, ret_virtual_device, &current_primfunc_param_index,
&param_constraints);
// Make sure we accounted for all prim_func parameters.
@@ -214,10 +215,10 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
// For each constrained parameter...
for (const auto& kv : param_constraints) {
const tir::Var param = GetRef<tir::Var>(kv.first);
- const SEScope& se_scope = kv.second;
+ const VirtualDevice& virtual_device = kv.second;
const tir::Buffer& buffer = prim_func->buffer_map[param];
// Rewrite the buffer to account for constraint.
- const Buffer new_buffer = RewriteBuffer(buffer, se_scope);
+ const Buffer new_buffer = RewriteBuffer(buffer, virtual_device);
if (!new_buffer.same_as(buffer)) {
any_change = true;
}
@@ -357,10 +358,10 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
BufferRegion new_source = VisitItem(match_buffer_region_node->source.get());
// The buffer field however is a definitional occurrence, aliased on top of the source.
// Transfer any memory scope from the source to the destination.
- Optional<SEScope> opt_se_scope = GetBufferConstraint(new_source->buffer);
+ Optional<VirtualDevice> opt_virtual_device = GetBufferConstraint(new_source->buffer);
tir::Buffer new_buffer;
- if (opt_se_scope.defined()) {
- new_buffer = RewriteBuffer(match_buffer_region_node->buffer, opt_se_scope.value());
+ if (opt_virtual_device.defined()) {
+ new_buffer = RewriteBuffer(match_buffer_region_node->buffer, opt_virtual_device.value());
} else {
new_buffer = match_buffer_region_node->buffer;
}
@@ -407,21 +408,21 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
}
/*!
- * \brief Rewrites \p buffer so as to follow the constraints in \p se_scope
+ * \brief Rewrites \p buffer so as to follow the constraints in \p virtual_device
* (currently just memory scope).
*
* Updates both the var_subst_ and buffer_subst_ to capture the rewrite, but
* also returns the new buffer.
*/
- Buffer RewriteBuffer(const Buffer& buffer, const SEScope& se_scope) {
+ Buffer RewriteBuffer(const Buffer& buffer, const VirtualDevice& virtual_device) {
ICHECK(buffer->data->type_annotation.defined());
const auto* pointer_type_node = buffer->data->type_annotation.as<PointerTypeNode>();
ICHECK(pointer_type_node);
- if (pointer_type_node->storage_scope == se_scope->memory_scope) {
+ if (pointer_type_node->storage_scope == virtual_device->memory_scope) {
// No change.
return buffer;
}
- PointerType new_pointer_type(pointer_type_node->element_type, se_scope->memory_scope);
+ PointerType new_pointer_type(pointer_type_node->element_type, virtual_device->memory_scope);
Var new_data(buffer->data->name_hint, new_pointer_type, buffer->data->span);
var_subst_.emplace(buffer->data.get(), new_data);
Buffer new_buffer(new_data, buffer->dtype, buffer->shape, buffer->strides, buffer->elem_offset,
@@ -432,14 +433,15 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
}
/*!
- * \brief Returns the SEScope capturing any memory scope in \p buffer. Returns nullptr if
+ * \brief Returns the VirtualDevice capturing any memory scope in \p buffer. Returns nullptr if
* buffer's data var does not have a type annotation of \p PointerType. Returns the fully
- * unconstrained \p SEScope if no memory scope is given.
+ * unconstrained \p VirtualDevice if no memory scope is given.
*/
- static Optional<SEScope> GetBufferConstraint(const tir::Buffer& buffer) {
+ static Optional<VirtualDevice> GetBufferConstraint(const tir::Buffer& buffer) {
const auto* pointer_type_node = PointerInBuffer(buffer);
- return pointer_type_node == nullptr ? Optional<SEScope>()
- : SEScope::ForMemoryScope(pointer_type_node->storage_scope);
+ return pointer_type_node == nullptr
+ ? Optional<VirtualDevice>()
+ : VirtualDevice::ForMemoryScope(pointer_type_node->storage_scope);
}
/*!
@@ -455,59 +457,60 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
} // namespace
-Array<SEScope> GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func,
- const FuncType& relay_func_type) {
+Array<VirtualDevice> GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func,
+ const FuncType& relay_func_type) {
// Build the implied domain (in terms of the function's Relay type) implied by any memory scope
* constraints in the function's buffers, for both arguments and results.
- Array<SEScope> se_scopes;
- se_scopes.reserve(relay_func_type->arg_types.size() + 1);
+ Array<VirtualDevice> virtual_devices;
+ virtual_devices.reserve(relay_func_type->arg_types.size() + 1);
// For each Relay function parameter...
size_t current_primfunc_param_index = 0;
for (const auto& param_type : relay_func_type->arg_types) {
- SEScope param_se_scope =
+ VirtualDevice param_virtual_device =
ConsistentParamConstraint(prim_func, param_type, &current_primfunc_param_index);
- se_scopes.push_back(param_se_scope);
+ virtual_devices.push_back(param_virtual_device);
}
// For the Relay function result...
const Type& ret_type = relay_func_type->ret_type;
- SEScope ret_se_scope =
+ VirtualDevice ret_virtual_device =
ConsistentParamConstraint(prim_func, ret_type, &current_primfunc_param_index);
- se_scopes.push_back(ret_se_scope);
+ virtual_devices.push_back(ret_virtual_device);
// Make sure all parameters of the prim_func have been accounted for.
CheckNoRemainingPointerParams(prim_func, &current_primfunc_param_index);
- return se_scopes;
+ return virtual_devices;
}
TVM_REGISTER_GLOBAL("tir.analysis.GetPrimFuncArgAndResultMemoryConstraints")
.set_body_typed([](const PrimFunc& prim_func, const FuncType& relay_func_type) {
Array<String> memory_scopes;
memory_scopes.reserve(relay_func_type->type_params.size() + 1);
- for (const auto& se_scope : GetPrimFuncArgAndResultConstraints(prim_func, relay_func_type)) {
- memory_scopes.push_back(se_scope->memory_scope);
+ for (const auto& virtual_device :
+ GetPrimFuncArgAndResultConstraints(prim_func, relay_func_type)) {
+ memory_scopes.push_back(virtual_device->memory_scope);
}
return memory_scopes;
});
-PrimFunc ApplyPrimFuncArgAndResultConstraints(const PrimFunc& prim_func,
- const FuncType& relay_func_type,
- const Array<SEScope>& arg_and_result_se_scopes) {
+PrimFunc ApplyPrimFuncArgAndResultConstraints(
+ const PrimFunc& prim_func, const FuncType& relay_func_type,
+ const Array<VirtualDevice>& arg_and_result_virtual_devices) {
return ApplyDeviceConstraintsMutator().Rewrite(prim_func, relay_func_type,
- arg_and_result_se_scopes);
+ arg_and_result_virtual_devices);
}
TVM_REGISTER_GLOBAL("tir.analysis.ApplyPrimFuncArgAndResultMemoryConstraints")
.set_body_typed([](const PrimFunc& prim_func, const FuncType& relay_func_type,
const Array<String>& arg_and_result_memory_scopes) {
- Array<SEScope> se_scopes;
- se_scopes.reserve(arg_and_result_memory_scopes.size());
+ Array<VirtualDevice> virtual_devices;
+ virtual_devices.reserve(arg_and_result_memory_scopes.size());
for (const auto& memory_scope : arg_and_result_memory_scopes) {
- se_scopes.push_back(SEScope::ForMemoryScope(memory_scope));
+ virtual_devices.push_back(VirtualDevice::ForMemoryScope(memory_scope));
}
- return ApplyPrimFuncArgAndResultConstraints(prim_func, relay_func_type, se_scopes);
+ return ApplyPrimFuncArgAndResultConstraints(prim_func, relay_func_type, virtual_devices);
});
} // namespace tir
diff --git a/src/tir/analysis/device_constraint_utils.h b/src/tir/analysis/device_constraint_utils.h
index be0f199..717bf52 100644
--- a/src/tir/analysis/device_constraint_utils.h
+++ b/src/tir/analysis/device_constraint_utils.h
@@ -23,21 +23,21 @@
* parameters.
*
* These utilities are used by the \p PlanDevices pass to extract memory (aka 'storage') scope
- * information from \p PrimFuncs and convert them back into \p SEScope form w.r.t. the original
- * Relay type of the \p PrimFunc (ie before flattening of tuple arguments/results and conversion
- * to destination-passing style aka DPS).
+ * information from \p PrimFuncs and convert them back into \p VirtualDevice form w.r.t. the
+ * original Relay type of the \p PrimFunc (ie before flattening of tuple arguments/results and
+ * conversion to destination-passing style aka DPS).
*
* A utility is also supplied to go the other way: impose memory scopes on \p PrimFunc parameters.
* However that's still in EXPERIMENTAL form.
*
* We may extend these utilities to also gather/apply layout information should we add that to
- * \p SEScope.
+ * \p VirtualDevice.
*/
#ifndef TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_
#define TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_
-#include <tvm/target/se_scope.h>
+#include <tvm/target/virtual_device.h>
#include <tvm/tir/function.h>
namespace tvm {
@@ -71,26 +71,26 @@ namespace tir {
*/
/*!
- * \brief Returns the \p SEScopes capturing the memory (aka storage) scope constraints for all the
- * arguments and result of \p prim_func. However the result will be w.r.t. the \p prim_func's
+ * \brief Returns the \p VirtualDevices capturing the memory (aka storage) scope constraints for all
+ * the arguments and result of \p prim_func. However the result will be w.r.t. the \p prim_func's
* representation as a Relay \p Function of \p relay_func_type_ before lowering and conversion to
* DPS.
*/
-Array<SEScope> GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func,
- const FuncType& relay_func_type);
+Array<VirtualDevice> GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func,
+ const FuncType& relay_func_type);
/*
* \brief Returns \p prim_func written to capture the memory (aka storage) scope constraints
- * for each of the \p prim_func's parameters given by \p arg_and_result_se_scopes. However,
- * \p arg_and_result_se_scopes should be w.r.t. the \p prim_func's representation as a Relay
+ * for each of the \p prim_func's parameters given by \p arg_and_result_virtual_devices. However,
+ * \p arg_and_result_virtual_devices should be w.r.t. the \p prim_func's representation as a Relay
* \p Function of \p relay_func_type before lowering and conversion to DPS.
*
* CAUTION: This is experimental. The resulting \p PrimFunc may not have fully accounted for all
* new memory scopes.
*/
-PrimFunc ApplyPrimFuncArgAndResultConstraints(const PrimFunc& prim_func,
- const FuncType& relay_func_type,
- const Array<SEScope>& arg_and_result_se_scopes);
+PrimFunc ApplyPrimFuncArgAndResultConstraints(
+ const PrimFunc& prim_func, const FuncType& relay_func_type,
+ const Array<VirtualDevice>& arg_and_result_virtual_devices);
} // namespace tir
} // namespace tvm
diff --git a/tests/cpp/relay/op/memory/on_device_test.cc b/tests/cpp/relay/op/memory/on_device_test.cc
index 45d4f88..6f0a0b0 100644
--- a/tests/cpp/relay/op/memory/on_device_test.cc
+++ b/tests/cpp/relay/op/memory/on_device_test.cc
@@ -30,22 +30,22 @@ TEST(OnDeviceOp, Name) { EXPECT_EQ(OnDeviceOp()->name, "on_device"); }
TEST(OnDevice, Default) {
Var body("x", {});
- SEScope se_scope = SEScope::ForDeviceType(kDLCPU, 3);
- Call call = OnDevice(body, se_scope);
+ VirtualDevice virtual_device = VirtualDevice::ForDeviceType(kDLCPU, 3);
+ Call call = OnDevice(body, virtual_device);
EXPECT_EQ(call->op, OnDeviceOp());
EXPECT_EQ(call->args.size(), 1);
EXPECT_EQ(call->args[0], body);
const auto* attrs = call->attrs.as<OnDeviceAttrs>();
ASSERT_TRUE(attrs != nullptr);
- EXPECT_EQ(attrs->se_scope, se_scope);
+ EXPECT_EQ(attrs->virtual_device, virtual_device);
EXPECT_FALSE(attrs->constrain_result);
EXPECT_TRUE(attrs->constrain_body);
}
TEST(OnDevice, Fixed) {
Var body("x", {});
- SEScope se_scope = SEScope::ForDeviceType(kDLCPU, 3);
- Call call = OnDevice(body, se_scope, /*constrain_result=*/true);
+ VirtualDevice virtual_device = VirtualDevice::ForDeviceType(kDLCPU, 3);
+ Call call = OnDevice(body, virtual_device, /*constrain_result=*/true);
const auto* attrs = call->attrs.as<OnDeviceAttrs>();
ASSERT_TRUE(attrs != nullptr);
EXPECT_TRUE(attrs->constrain_result);
@@ -54,8 +54,8 @@ TEST(OnDevice, Fixed) {
TEST(OnDevice, Free) {
Var body("x", {});
- SEScope se_scope = SEScope::ForDeviceType(kDLCPU, 3);
- Call call = OnDevice(body, se_scope, /*constrain_result=*/false, /*constrain_body=*/false);
+ VirtualDevice virtual_device = VirtualDevice::ForDeviceType(kDLCPU, 3);
+ Call call = OnDevice(body, virtual_device, /*constrain_result=*/false, /*constrain_body=*/false);
const auto* attrs = call->attrs.as<OnDeviceAttrs>();
ASSERT_TRUE(attrs != nullptr);
EXPECT_FALSE(attrs->constrain_result);
@@ -64,23 +64,23 @@ TEST(OnDevice, Free) {
TEST(GetOnDeviceProps, Correct) {
Var body("x", {});
- SEScope se_scope = SEScope::ForDeviceType(kDLCPU, 3);
- Call call = OnDevice(body, se_scope, /*constrain_result=*/true, /*constrain_body=*/false);
+ VirtualDevice virtual_device = VirtualDevice::ForDeviceType(kDLCPU, 3);
+ Call call = OnDevice(body, virtual_device, /*constrain_result=*/true, /*constrain_body=*/false);
OnDeviceProps props = GetOnDeviceProps(call);
ASSERT_TRUE(props.body.defined());
- ASSERT_EQ(props.se_scope, se_scope);
+ ASSERT_EQ(props.virtual_device, virtual_device);
ASSERT_TRUE(props.constrain_result);
ASSERT_FALSE(props.constrain_body);
}
TEST(MaybeOnDevice, Wrapped) {
- SEScope se_scope = SEScope::ForDeviceType(kDLCPU, 3);
+ VirtualDevice virtual_device = VirtualDevice::ForDeviceType(kDLCPU, 3);
Var body("x", {});
- Call inner = OnDevice(body, se_scope);
- Call outer = OnDevice(inner, se_scope);
+ Call inner = OnDevice(body, virtual_device);
+ Call outer = OnDevice(inner, virtual_device);
OnDeviceProps props = GetOnDeviceProps(outer);
ASSERT_TRUE(props.body.defined());
- ASSERT_EQ(props.se_scope, se_scope);
+ ASSERT_EQ(props.virtual_device, virtual_device);
ASSERT_FALSE(props.constrain_result);
ASSERT_TRUE(props.constrain_body);
}
diff --git a/tests/cpp/relay/transforms/device_domains_test.cc b/tests/cpp/relay/transforms/device_domains_test.cc
index 5df7984..7314f64 100644
--- a/tests/cpp/relay/transforms/device_domains_test.cc
+++ b/tests/cpp/relay/transforms/device_domains_test.cc
@@ -45,8 +45,8 @@ IRModule TestModule() {
}
TEST(DeviceDomains, SmokeTest) {
- SEScope cpu = SEScope::ForDeviceType(kDLCPU);
- SEScope cuda = SEScope::ForDeviceType(kDLCUDA);
+ VirtualDevice cpu = VirtualDevice::ForDeviceType(kDLCPU);
+ VirtualDevice cuda = VirtualDevice::ForDeviceType(kDLCUDA);
TargetMap target_map;
target_map.Set(Integer(static_cast<int>(kDLCPU)), Target("llvm"));
target_map.Set(Integer(static_cast<int>(kDLCUDA)), Target("cuda"));
@@ -66,11 +66,11 @@ TEST(DeviceDomains, SmokeTest) {
arg_and_results.push_back(result_domain);
DeviceDomainPtr implied_add_domain = domains.MakeHigherOrderDomain(std::move(arg_and_results));
EXPECT_FALSE(domains.UnifyOrNull(actual_add_domain, implied_add_domain) == nullptr);
- EXPECT_FALSE(domains.UnifyOrNull(
- x_domain, domains.ForSEScope(f->params[0]->checked_type(), cuda)) == nullptr);
+ EXPECT_FALSE(domains.UnifyOrNull(x_domain, domains.ForVirtualDevice(f->params[0]->checked_type(),
+ cuda)) == nullptr);
- EXPECT_EQ(domains.ResultSEScope(y_domain), config->CanonicalSEScope(cuda));
- EXPECT_EQ(domains.ResultSEScope(result_domain), config->CanonicalSEScope(cuda));
+ EXPECT_EQ(domains.ResultVirtualDevice(y_domain), config->CanonicalVirtualDevice(cuda));
+ EXPECT_EQ(domains.ResultVirtualDevice(result_domain), config->CanonicalVirtualDevice(cuda));
}
} // namespace
diff --git a/tests/cpp/target/compilation_config_test.cc b/tests/cpp/target/compilation_config_test.cc
index 31b9368..2b1041b 100644
--- a/tests/cpp/target/compilation_config_test.cc
+++ b/tests/cpp/target/compilation_config_test.cc
@@ -48,9 +48,9 @@ TEST(CompilationConfig, Constructor_Homogeneous_FallbackCPUHost) {
legacy_target_map.Set(Integer(static_cast<int>(kDLCUDA)), cuda_target);
CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{});
- SEScope expected_default_primitive_se_scope(kDLCUDA, 0,
- Target::WithHost(cuda_target, host_target));
- SEScope expected_host_se_scope(kDLCPU, 0, host_target);
+ VirtualDevice expected_default_primitive_virtual_device(
+ kDLCUDA, 0, Target::WithHost(cuda_target, host_target));
+ VirtualDevice expected_host_virtual_device(kDLCPU, 0, host_target);
ASSERT_EQ(config->legacy_target_map.size(), 1);
EXPECT_TRUE(StructuralEqual()((*config->legacy_target_map.begin()).second,
@@ -60,9 +60,9 @@ TEST(CompilationConfig, Constructor_Homogeneous_FallbackCPUHost) {
ASSERT_EQ(config->primitive_targets.size(), 1);
EXPECT_TRUE(
StructuralEqual()(config->primitive_targets[0], Target::WithHost(cuda_target, host_target)));
- EXPECT_TRUE(
- StructuralEqual()(config->default_primitive_se_scope, expected_default_primitive_se_scope));
- EXPECT_TRUE(StructuralEqual()(config->host_se_scope, expected_host_se_scope));
+ EXPECT_TRUE(StructuralEqual()(config->default_primitive_virtual_device,
+ expected_default_primitive_virtual_device));
+ EXPECT_TRUE(StructuralEqual()(config->host_virtual_device, expected_host_virtual_device));
ASSERT_TRUE(config->optional_homogeneous_target.defined());
EXPECT_TRUE(StructuralEqual()(config->optional_homogeneous_target,
Target::WithHost(cuda_target, host_target)));
@@ -107,9 +107,9 @@ TEST(CompilationConfig, Constructor_Hetrogeneous_FallbackCPUHost) {
Target::WithHost(cuda_target, host_target));
CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{});
- SEScope expected_default_primitive_se_scope(kDLCUDA, 0,
- Target::WithHost(cuda_target, host_target));
- SEScope expected_host_se_scope(kDLCPU, 0, host_target);
+ VirtualDevice expected_default_primitive_virtual_device(
+ kDLCUDA, 0, Target::WithHost(cuda_target, host_target));
+ VirtualDevice expected_host_virtual_device(kDLCPU, 0, host_target);
ASSERT_EQ(config->legacy_target_map.size(), 2);
for (const auto& pair : config->legacy_target_map) {
@@ -121,9 +121,9 @@ TEST(CompilationConfig, Constructor_Hetrogeneous_FallbackCPUHost) {
}
EXPECT_TRUE(config->host_target.defined());
EXPECT_TRUE(StructuralEqual()(config->host_target, host_target));
- EXPECT_TRUE(
- StructuralEqual()(config->default_primitive_se_scope, expected_default_primitive_se_scope));
- EXPECT_TRUE(StructuralEqual()(config->host_se_scope, expected_host_se_scope));
+ EXPECT_TRUE(StructuralEqual()(config->default_primitive_virtual_device,
+ expected_default_primitive_virtual_device));
+ EXPECT_TRUE(StructuralEqual()(config->host_virtual_device, expected_host_virtual_device));
EXPECT_FALSE(config->optional_homogeneous_target.defined());
}
@@ -140,9 +140,9 @@ TEST(CompilationConfig, Constructor_Hetrogeneous_ExplicitHost) {
Target::WithHost(cuda_target, host_target));
CompilationConfig config(pass_ctx, legacy_target_map, host_target);
- SEScope expected_default_primitive_se_scope(kDLCUDA, 0,
- Target::WithHost(cuda_target, host_target));
- SEScope expected_host_se_scope(kDLCPU, 0, host_target);
+ VirtualDevice expected_default_primitive_virtual_device(
+ kDLCUDA, 0, Target::WithHost(cuda_target, host_target));
+ VirtualDevice expected_host_virtual_device(kDLCPU, 0, host_target);
ASSERT_EQ(config->legacy_target_map.size(), 2);
for (const auto& pair : config->legacy_target_map) {
@@ -155,9 +155,9 @@ TEST(CompilationConfig, Constructor_Hetrogeneous_ExplicitHost) {
EXPECT_TRUE(config->host_target.defined());
EXPECT_TRUE(StructuralEqual()(config->host_target, host_target));
ASSERT_EQ(config->primitive_targets.size(), 2);
- EXPECT_TRUE(
- StructuralEqual()(config->default_primitive_se_scope, expected_default_primitive_se_scope));
- EXPECT_TRUE(StructuralEqual()(config->host_se_scope, expected_host_se_scope));
+ EXPECT_TRUE(StructuralEqual()(config->default_primitive_virtual_device,
+ expected_default_primitive_virtual_device));
+ EXPECT_TRUE(StructuralEqual()(config->host_virtual_device, expected_host_virtual_device));
EXPECT_FALSE(config->optional_homogeneous_target.defined());
}
@@ -188,40 +188,40 @@ TEST(CompilationConfig, Constructor_DefaultNoMatchingPrimitiveTarget) {
CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}));
}
-TEST(CompilationConfig, CanonicalSEScope) {
+TEST(CompilationConfig, CanonicalVirtualDevice) {
Target host_target = TestDefaultCpuTarget();
Target cuda_target = TestCudaTarget();
Target cpu_target = TestCpuTarget();
CompilationConfig config = TestCompilationConfig();
{
- SEScope in = SEScope(kDLCPU);
- SEScope actual = config->CanonicalSEScope(in);
+ VirtualDevice in = VirtualDevice(kDLCPU);
+ VirtualDevice actual = config->CanonicalVirtualDevice(in);
ASSERT_TRUE(actual->target.defined());
EXPECT_TRUE(StructuralEqual()(actual->target, Target::WithHost(cpu_target, host_target)));
- EXPECT_EQ(config->CanonicalSEScope(in), actual);
+ EXPECT_EQ(config->CanonicalVirtualDevice(in), actual);
}
{
- SEScope in = SEScope(kDLCUDA);
- SEScope actual = config->CanonicalSEScope(in);
+ VirtualDevice in = VirtualDevice(kDLCUDA);
+ VirtualDevice actual = config->CanonicalVirtualDevice(in);
ASSERT_TRUE(actual->target.defined());
EXPECT_TRUE(StructuralEqual()(actual->target, Target::WithHost(cuda_target, host_target)));
- EXPECT_EQ(config->CanonicalSEScope(in), actual);
+ EXPECT_EQ(config->CanonicalVirtualDevice(in), actual);
}
}
-TEST(CompilationConfig, CanonicalSEScope_NoDevice) {
+TEST(CompilationConfig, CanonicalVirtualDevice_NoDevice) {
CompilationConfig config = TestCompilationConfig();
- SEScope fully_unconstrained;
- EXPECT_ANY_THROW(config->CanonicalSEScope(fully_unconstrained));
- SEScope missing_device(kInvalidDeviceType, 3, {}, "local");
- EXPECT_ANY_THROW(config->CanonicalSEScope(missing_device));
+ VirtualDevice fully_unconstrained;
+ EXPECT_ANY_THROW(config->CanonicalVirtualDevice(fully_unconstrained));
+ VirtualDevice missing_device(kInvalidDeviceType, 3, {}, "local");
+ EXPECT_ANY_THROW(config->CanonicalVirtualDevice(missing_device));
}
-TEST(CompilationConfig, CanonicalSEScope_NoMatchingTarget) {
+TEST(CompilationConfig, CanonicalVirtualDevice_NoMatchingTarget) {
CompilationConfig config = TestCompilationConfig();
- SEScope no_such_target(kDLMetal);
- EXPECT_ANY_THROW(config->CanonicalSEScope(no_such_target));
+ VirtualDevice no_such_target(kDLMetal);
+ EXPECT_ANY_THROW(config->CanonicalVirtualDevice(no_such_target));
}
} // namespace
diff --git a/tests/cpp/target/se_scope_test.cc b/tests/cpp/target/se_scope_test.cc
deleted file mode 100644
index 166ba46..0000000
--- a/tests/cpp/target/se_scope_test.cc
+++ /dev/null
@@ -1,119 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <gtest/gtest.h>
-#include <tvm/target/se_scope.h>
-#include <tvm/target/target.h>
-
-namespace tvm {
-namespace {
-
-TEST(SEScope, Join_Defined) {
- {
- Target target_a = Target("cuda");
- SEScope lhs = SEScope(kDLCUDA, 3);
- SEScope rhs = SEScope(kDLCUDA, -1, target_a, "global");
- Optional<SEScope> actual = SEScope::Join(lhs, rhs);
- EXPECT_TRUE(actual.operator bool());
- SEScope expected = SEScope(kDLCUDA, 3, target_a, "global");
- EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
- }
- {
- Target target_a = Target("cuda");
- SEScope lhs = SEScope(kDLCUDA, -1, target_a, "global");
- SEScope rhs = SEScope(kDLCUDA, 3);
- Optional<SEScope> actual = SEScope::Join(lhs, rhs);
- EXPECT_TRUE(actual.operator bool());
- SEScope expected = SEScope(kDLCUDA, 3, target_a, "global");
- EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
- }
- {
- Target target_a = Target("cuda");
- SEScope lhs = SEScope(kDLCUDA);
- SEScope rhs = SEScope(kDLCUDA, 2, target_a);
- Optional<SEScope> actual = SEScope::Join(lhs, rhs);
- EXPECT_TRUE(actual.operator bool());
- SEScope expected = SEScope(kDLCUDA, 2, target_a);
- EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
- }
- {
- Target target_a = Target("cuda");
- SEScope lhs = SEScope();
- SEScope rhs = SEScope(kDLCUDA, 3, target_a, "global");
- Optional<SEScope> actual = SEScope::Join(lhs, rhs);
- EXPECT_TRUE(actual.operator bool());
- SEScope expected = rhs;
- EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
- }
-}
-
-TEST(SEScope, Join_Undefined) {
- {
- SEScope lhs = SEScope(kDLCUDA);
- SEScope rhs = SEScope(kDLCPU);
- Optional<SEScope> actual = SEScope::Join(lhs, rhs);
- EXPECT_FALSE(actual);
- }
- {
- SEScope lhs = SEScope(kDLCUDA, 3);
- SEScope rhs = SEScope(kDLCUDA, 4);
- Optional<SEScope> actual = SEScope::Join(lhs, rhs);
- EXPECT_FALSE(actual);
- }
- {
- SEScope lhs = SEScope(kDLCUDA, 3, Target("cuda"));
- SEScope rhs = SEScope(kDLCUDA, 3, Target("cuda"));
- Optional<SEScope> actual = SEScope::Join(lhs, rhs);
- EXPECT_FALSE(actual);
- }
- {
- SEScope lhs = SEScope(kDLCUDA, 3, Target("cuda"), "local");
- SEScope rhs = SEScope(kDLCUDA, 3, Target("cuda"), "global");
- Optional<SEScope> actual = SEScope::Join(lhs, rhs);
- EXPECT_FALSE(actual);
- }
-}
-
-TEST(SEScope, Default) {
- Target target_a = Target("cuda");
- SEScope lhs = SEScope(kDLCUDA, -1, Target(), "global");
- SEScope rhs = SEScope(kDLCUDA, 3, target_a, "local");
- SEScope actual = SEScope::Default(lhs, rhs);
- SEScope expected = SEScope(kDLCUDA, 3, target_a, "global");
- EXPECT_TRUE(StructuralEqual()(actual, expected));
-}
-
-TEST(SEScope, Constructor_Invalid) { EXPECT_ANY_THROW(SEScope(kDLCPU, -1, Target("cuda"))); }
-
-TEST(SEScopeCache, Memoized) {
- SEScopeCache cache;
- Target target_a = Target("cuda");
- Target target_b = Target("llvm");
- SEScope se_scope_a = cache.Make(kDLCUDA, 3, target_a, "local");
- SEScope se_scope_b = cache.Make(kDLCPU, 1, target_b, "global");
-
- EXPECT_EQ(cache.Make(kDLCUDA, 3, target_a, "local"), se_scope_a);
- EXPECT_EQ(cache.Make(kDLCPU, 1, target_b, "global"), se_scope_b);
- EXPECT_NE(cache.Make(kDLCUDA, 2, target_a, "local"), se_scope_a);
- EXPECT_NE(cache.Make(kDLCPU, 3, target_b, "local"), se_scope_a);
- EXPECT_NE(cache.Make(kDLCUDA, 3, target_a, "global"), se_scope_a);
-}
-
-} // namespace
-} // namespace tvm
diff --git a/tests/cpp/target/virtual_device_test.cc b/tests/cpp/target/virtual_device_test.cc
new file mode 100644
index 0000000..35e0787
--- /dev/null
+++ b/tests/cpp/target/virtual_device_test.cc
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <tvm/target/target.h>
+#include <tvm/target/virtual_device.h>
+
+namespace tvm {
+namespace {
+
+TEST(VirtualDevice, Join_Defined) {
+ {
+ Target target_a = Target("cuda");
+ VirtualDevice lhs = VirtualDevice(kDLCUDA, 3);
+ VirtualDevice rhs = VirtualDevice(kDLCUDA, -1, target_a, "global");
+ Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+ EXPECT_TRUE(actual.operator bool());
+ VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global");
+ EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
+ }
+ {
+ Target target_a = Target("cuda");
+ VirtualDevice lhs = VirtualDevice(kDLCUDA, -1, target_a, "global");
+ VirtualDevice rhs = VirtualDevice(kDLCUDA, 3);
+ Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+ EXPECT_TRUE(actual.operator bool());
+ VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global");
+ EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
+ }
+ {
+ Target target_a = Target("cuda");
+ VirtualDevice lhs = VirtualDevice(kDLCUDA);
+ VirtualDevice rhs = VirtualDevice(kDLCUDA, 2, target_a);
+ Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+ EXPECT_TRUE(actual.operator bool());
+ VirtualDevice expected = VirtualDevice(kDLCUDA, 2, target_a);
+ EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
+ }
+ {
+ Target target_a = Target("cuda");
+ VirtualDevice lhs = VirtualDevice();
+ VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, target_a, "global");
+ Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+ EXPECT_TRUE(actual.operator bool());
+ VirtualDevice expected = rhs;
+ EXPECT_TRUE(StructuralEqual()(actual.value(), expected));
+ }
+}
+
+TEST(VirtualDevice, Join_Undefined) {
+ {
+ VirtualDevice lhs = VirtualDevice(kDLCUDA);
+ VirtualDevice rhs = VirtualDevice(kDLCPU);
+ Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+ EXPECT_FALSE(actual);
+ }
+ {
+ VirtualDevice lhs = VirtualDevice(kDLCUDA, 3);
+ VirtualDevice rhs = VirtualDevice(kDLCUDA, 4);
+ Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+ EXPECT_FALSE(actual);
+ }
+ {
+ VirtualDevice lhs = VirtualDevice(kDLCUDA, 3, Target("cuda"));
+ VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, Target("cuda"));
+ Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+ EXPECT_FALSE(actual);
+ }
+ {
+ VirtualDevice lhs = VirtualDevice(kDLCUDA, 3, Target("cuda"), "local");
+ VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, Target("cuda"), "global");
+ Optional<VirtualDevice> actual = VirtualDevice::Join(lhs, rhs);
+ EXPECT_FALSE(actual);
+ }
+}
+
+TEST(VirtualDevice, Default) {
+ Target target_a = Target("cuda");
+ VirtualDevice lhs = VirtualDevice(kDLCUDA, -1, Target(), "global");
+ VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, target_a, "local");
+ VirtualDevice actual = VirtualDevice::Default(lhs, rhs);
+ VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global");
+ EXPECT_TRUE(StructuralEqual()(actual, expected));
+}
+
+TEST(VirtualDevice, Constructor_Invalid) {
+ EXPECT_ANY_THROW(VirtualDevice(kDLCPU, -1, Target("cuda")));
+}
+
+TEST(VirtualDeviceCache, Memoized) {
+ VirtualDeviceCache cache;
+ Target target_a = Target("cuda");
+ Target target_b = Target("llvm");
+ VirtualDevice virtual_device_a = cache.Make(kDLCUDA, 3, target_a, "local");
+ VirtualDevice virtual_device_b = cache.Make(kDLCPU, 1, target_b, "global");
+
+ EXPECT_EQ(cache.Make(kDLCUDA, 3, target_a, "local"), virtual_device_a);
+ EXPECT_EQ(cache.Make(kDLCPU, 1, target_b, "global"), virtual_device_b);
+ EXPECT_NE(cache.Make(kDLCUDA, 2, target_a, "local"), virtual_device_a);
+ EXPECT_NE(cache.Make(kDLCPU, 3, target_b, "local"), virtual_device_a);
+ EXPECT_NE(cache.Make(kDLCUDA, 3, target_a, "global"), virtual_device_a);
+}
+
+} // namespace
+} // namespace tvm
diff --git a/tests/python/relay/op/annotation/test_annotation.py b/tests/python/relay/op/annotation/test_annotation.py
index 5ad2a59..2352821 100644
--- a/tests/python/relay/op/annotation/test_annotation.py
+++ b/tests/python/relay/op/annotation/test_annotation.py
@@ -26,10 +26,10 @@ def test_on_device_via_string():
assert isinstance(call, relay.Call)
assert len(call.args) == 1
assert call.args[0] == x
- assert call.attrs.se_scope.device_type_int == 2 # ie kDLCUDA
- assert call.attrs.se_scope.virtual_device_id == 0
- assert call.attrs.se_scope.target is None
- assert call.attrs.se_scope.memory_scope == ""
+ assert call.attrs.virtual_device.device_type_int == 2 # ie kDLCUDA
+ assert call.attrs.virtual_device.virtual_device_id == 0
+ assert call.attrs.virtual_device.target is None
+ assert call.attrs.virtual_device.memory_scope == ""
assert call.attrs.constrain_body
assert not call.attrs.constrain_result
@@ -37,7 +37,7 @@ def test_on_device_via_string():
def test_on_device_via_device():
x = relay.Var("x")
call = relay.annotation.on_device(x, tvm.device("cpu"))
- assert call.attrs.se_scope.device_type_int == 1 # ie kDLCPU
+ assert call.attrs.virtual_device.device_type_int == 1 # ie kDLCPU
def test_on_device_invalid_device():
@@ -48,7 +48,7 @@ def test_on_device_invalid_device():
def test_on_device_fixed():
x = relay.Var("x")
call = relay.annotation.on_device(x, "cuda", constrain_result=True)
- assert call.attrs.se_scope.device_type_int == 2 # ie kDLCUDA
+ assert call.attrs.virtual_device.device_type_int == 2 # ie kDLCUDA
assert call.attrs.constrain_body
assert call.attrs.constrain_result
@@ -56,7 +56,7 @@ def test_on_device_fixed():
def test_on_device_free():
x = relay.Var("x")
call = relay.annotation.on_device(x, "cuda", constrain_result=False, constrain_body=False)
- assert call.attrs.se_scope.device_type_int == -1 # ie kInvalidDeviceType
+ assert call.attrs.virtual_device.device_type_int == -1 # ie kInvalidDeviceType
assert not call.attrs.constrain_body
assert not call.attrs.constrain_result
@@ -67,10 +67,10 @@ def test_function_on_device():
f = relay.Function([x, y], relay.add(x, y))
func = relay.annotation.function_on_device(f, ["cpu", "cuda"], "cuda")
assert isinstance(func, relay.Function)
- assert len(func.attrs["param_se_scopes"]) == 2
- assert func.attrs["param_se_scopes"][0].device_type_int == 1 # ie kDLCPU
- assert func.attrs["param_se_scopes"][1].device_type_int == 2 # ie kDLCUDA
- assert func.attrs["result_se_scope"].device_type_int == 2 # ie KDLCUDA
+ assert len(func.attrs["param_virtual_devices"]) == 2
+ assert func.attrs["param_virtual_devices"][0].device_type_int == 1 # ie kDLCPU
+ assert func.attrs["param_virtual_devices"][1].device_type_int == 2 # ie kDLCUDA
+ assert func.attrs["result_virtual_device"].device_type_int == 2 # ie kDLCUDA
if __name__ == "__main__":
diff --git a/tests/python/relay/op/test_tensor.py b/tests/python/relay/op/test_tensor.py
index 4d2c176..2d561cf 100644
--- a/tests/python/relay/op/test_tensor.py
+++ b/tests/python/relay/op/test_tensor.py
@@ -26,14 +26,14 @@ def test_device_copy_via_string():
assert isinstance(call, relay.Call)
assert len(call.args) == 1
assert call.args[0] == x
- assert call.attrs.src_se_scope.device_type_int == 2 # ie kDLCUDA
- assert call.attrs.src_se_scope.virtual_device_id == 0
- assert call.attrs.src_se_scope.target is None
- assert call.attrs.src_se_scope.memory_scope == ""
- assert call.attrs.dst_se_scope.device_type_int == 1 # ie kDLCPU
- assert call.attrs.dst_se_scope.virtual_device_id == 0
- assert call.attrs.dst_se_scope.target is None
- assert call.attrs.dst_se_scope.memory_scope == ""
+ assert call.attrs.src_virtual_device.device_type_int == 2 # ie kDLCUDA
+ assert call.attrs.src_virtual_device.virtual_device_id == 0
+ assert call.attrs.src_virtual_device.target is None
+ assert call.attrs.src_virtual_device.memory_scope == ""
+ assert call.attrs.dst_virtual_device.device_type_int == 1 # ie kDLCPU
+ assert call.attrs.dst_virtual_device.virtual_device_id == 0
+ assert call.attrs.dst_virtual_device.target is None
+ assert call.attrs.dst_virtual_device.memory_scope == ""
def test_device_copy_via_device():
@@ -42,8 +42,8 @@ def test_device_copy_via_device():
assert isinstance(call, relay.Call)
assert len(call.args) == 1
assert call.args[0] == x
- assert call.attrs.src_se_scope.device_type_int == 2 # ie kDLCUDA
- assert call.attrs.dst_se_scope.device_type_int == 1 # ie kDLCPU
+ assert call.attrs.src_virtual_device.device_type_int == 2 # ie kDLCUDA
+ assert call.attrs.dst_virtual_device.device_type_int == 1 # ie kDLCPU
if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py
index 3893da4..bc19bcd 100644
--- a/tests/python/relay/test_pass_dead_code_elimination.py
+++ b/tests/python/relay/test_pass_dead_code_elimination.py
@@ -19,8 +19,8 @@ from tvm.relay import Function, transform
from tvm.relay.testing import inception_v3
import pytest
-cpu_scope = tvm.target.make_se_scope(tvm.cpu(), tvm.target.Target("llvm"))
-metatable = {"SEScope": [cpu_scope]}
+cpu_scope = tvm.target.make_virtual_device(tvm.cpu(), tvm.target.Target("llvm"))
+metatable = {"VirtualDevice": [cpu_scope]}
core = tvm.IRModule()
core.import_from_std("core.rly")
@@ -234,7 +234,7 @@ def test_impure_op():
def @main() {
let %size: int64 = cast(1024, dtype="int64");
let %alignment: int64 = cast(64, dtype="int64");
- let %x = memory.alloc_storage(%size, %alignment, se_scope=meta[SEScope][0]);
+ let %x = memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][0]);
0
}
""",
@@ -249,7 +249,7 @@ def test_impure_op():
def @main() {
let %x = memory.alloc_storage(cast(1024, dtype="int64"),
cast(64, dtype="int64"),
- se_scope=meta[SEScope][0]);
+ virtual_device=meta[VirtualDevice][0]);
0
}
""",
@@ -271,7 +271,7 @@ def test_impure_func():
def @f() -> int {
let %size: int64 = cast(1024, dtype="int64");
let %alignment: int64 = cast(64, dtype="int64");
- let %x = memory.alloc_storage(%size, %alignment, se_scope=meta[SEScope][0]);
+ let %x = memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][0]);
0
}
def @main() -> int {
@@ -290,7 +290,7 @@ def test_impure_func():
def @f() -> int {
let %x = memory.alloc_storage(cast(1024, dtype="int64"),
cast(64, dtype="int64"),
- se_scope=meta[SEScope][0]);
+ virtual_device=meta[VirtualDevice][0]);
0
}
def @main() -> int {
diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py
index ee9cfc9..82e40af 100644
--- a/tests/python/relay/test_pass_plan_devices.py
+++ b/tests/python/relay/test_pass_plan_devices.py
@@ -41,13 +41,13 @@ TARGETS = {
tvm.tir.IntImm("int32", GPU_DEVICE.device_type): GPU_TARGET,
}
-HOST = tvm.target.make_se_scope(HOST_DEVICE, HOST_TARGET) # device_type=1
-CPU = tvm.target.make_se_scope(CPU_DEVICE, CPU_TARGET) # device_type=1
-GPU = tvm.target.make_se_scope(GPU_DEVICE, GPU_TARGET) # device_type=2
+HOST = tvm.target.make_virtual_device(HOST_DEVICE, HOST_TARGET) # device_type=1
+CPU = tvm.target.make_virtual_device(CPU_DEVICE, CPU_TARGET) # device_type=1
+GPU = tvm.target.make_virtual_device(GPU_DEVICE, GPU_TARGET) # device_type=2
DEFAULT = GPU
-CPU_SCOPE_A = tvm.target.make_se_scope(CPU_DEVICE, CPU_TARGET, memory_scope="scopeA")
-CPU_SCOPE_B = tvm.target.make_se_scope(CPU_DEVICE, CPU_TARGET, memory_scope="scopeB")
+CPU_SCOPE_A = tvm.target.make_virtual_device(CPU_DEVICE, CPU_TARGET, memory_scope="scopeA")
+CPU_SCOPE_B = tvm.target.make_virtual_device(CPU_DEVICE, CPU_TARGET, memory_scope="scopeB")
CTXT = tvm.transform.PassContext(config={"relay.fallback_device_type": DEFAULT.device_type_int})
@@ -109,7 +109,7 @@ def exercise(in_mod: tvm.IRModule, expected_mod: tvm.IRModule, reference_func, a
def test_plain():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
# Everything defaults to GPU
def input():
@@ -134,8 +134,8 @@ def test_plain():
#[version = "0.0.5"]
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
%c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][1], meta[SEScope][1], meta[SEScope][1], meta[SEScope][1]],
- result_se_scope=meta[SEScope][1]) {
+ param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][1], meta[VirtualDevice][1], meta[VirtualDevice][1]],
+ result_virtual_device=meta[VirtualDevice][1]) {
%0 = add(%a, %b);
%1 = add(%c, %d);
subtract(%0, %1)
@@ -153,7 +153,7 @@ def test_plain():
def test_left_add_on_cpu():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
# Force some args to be on CPU, rest default to GPU.
def input():
@@ -163,7 +163,7 @@ def test_left_add_on_cpu():
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
%c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) {
%0 = add(%a, %b);
- %1 = on_device(%0, se_scope=meta[SEScope][0]);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
%2 = add(%c, %d);
subtract(%1, %2)
}
@@ -179,11 +179,11 @@ def test_left_add_on_cpu():
#[version = "0.0.5"]
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
%c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]],
- result_se_scope=meta[SEScope][1]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][1]],
+ result_virtual_device=meta[VirtualDevice][1]) {
%0 = add(%a, %b);
- %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True);
- %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
+ %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
%3 = add(%c, %d);
subtract(%2, %3)
}
@@ -200,7 +200,7 @@ def test_left_add_on_cpu():
def test_left_add_on_cpu_via_copy():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
# As for test_left_add_on_cpu, but with an explicit device_copy.
def input():
@@ -210,7 +210,7 @@ def test_left_add_on_cpu_via_copy():
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
%c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) {
%0 = add(%a, %b);
- %1 = device_copy(%0, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ %1 = device_copy(%0, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
%2 = add(%c, %d);
subtract(%1, %2)
}
@@ -226,11 +226,11 @@ def test_left_add_on_cpu_via_copy():
#[version = "0.0.5"]
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
%c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]],
- result_se_scope=meta[SEScope][1]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][1]],
+ result_virtual_device=meta[VirtualDevice][1]) {
%0 = add(%a, %b);
- %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True);
- %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
+ %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
%3 = add(%c, %d);
subtract(%2, %3)
}
@@ -247,7 +247,7 @@ def test_left_add_on_cpu_via_copy():
def test_both_adds_on_cpu():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
def input():
return tvm.parser.parse(
@@ -257,8 +257,8 @@ def test_both_adds_on_cpu():
%c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) {
%0 = add(%a, %b);
%1 = add(%c, %d);
- %2 = on_device(%0, se_scope=meta[SEScope][0]);
- %3 = on_device(%1, se_scope=meta[SEScope][0]);
+ %2 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
+ %3 = on_device(%1, virtual_device=meta[VirtualDevice][0]);
subtract(%2, %3)
}
""",
@@ -273,14 +273,14 @@ def test_both_adds_on_cpu():
#[version = "0.0.5"]
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
%c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]],
- result_se_scope=meta[SEScope][1]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0]],
+ result_virtual_device=meta[VirtualDevice][1]) {
%0 = add(%a, %b);
- %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
%2 = add(%c, %d);
- %3 = on_device(%2, se_scope=meta[SEScope][0], constrain_result=True);
- %4 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
- %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ %3 = on_device(%2, virtual_device=meta[VirtualDevice][0], constrain_result=True);
+ %4 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
+ %5 = device_copy(%3, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
subtract(%4, %5)
}
""",
@@ -296,7 +296,7 @@ def test_both_adds_on_cpu():
def test_sharing():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
# The same add sub-expression is annotated twice.
def input():
@@ -305,8 +305,8 @@ def test_sharing():
#[version = "0.0.5"]
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) {
%0 = add(%a, %b);
- %1 = on_device(%0, se_scope=meta[SEScope][0]);
- %2 = on_device(%0, se_scope=meta[SEScope][0]);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
+ %2 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
subtract(%1, %2)
}
""",
@@ -320,12 +320,12 @@ def test_sharing():
"""
#[version = "0.0.5"]
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) {
%0 = add(%a, %b);
- %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True);
- %2 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True);
- %3 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
- %4 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
+ %2 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
+ %3 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
+ %4 = device_copy(%2, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
subtract(%3, %4)
}
""",
@@ -342,7 +342,7 @@ def test_sharing():
def test_let_on_cpu():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
# The device for a let-bound expression can flow from uses of the let-bound var.
def input():
@@ -353,7 +353,7 @@ def test_let_on_cpu():
%c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) {
let %l = add(%a, %b);
let %r = add(%c, %d);
- %0 = on_device(%l, se_scope=meta[SEScope][0]);
+ %0 = on_device(%l, virtual_device=meta[VirtualDevice][0]);
subtract(%0, %r)
}
""",
@@ -368,12 +368,12 @@ def test_let_on_cpu():
#[version = "0.0.5"]
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
%c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]],
- result_se_scope=meta[SEScope][1]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][1]],
+ result_virtual_device=meta[VirtualDevice][1]) {
%0 = add(%a, %b);
- let %l = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True);
- let %r = on_device(add(%c, %d), se_scope=meta[SEScope][1], constrain_result=True);
- %1 = device_copy(%l, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ let %l = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
+ let %r = on_device(add(%c, %d), virtual_device=meta[VirtualDevice][1], constrain_result=True);
+ %1 = device_copy(%l, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
subtract(%1, %r)
}
""",
@@ -389,7 +389,7 @@ def test_let_on_cpu():
def test_func_param_on_cpu():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
# Devices for function parameters flow to call sites.
def input():
@@ -400,7 +400,7 @@ def test_func_param_on_cpu():
%c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) {
let %f = fn (%x, %y) {
%0 = add(%x, %y);
- on_device(%0, se_scope=meta[SEScope][0])
+ on_device(%0, virtual_device=meta[VirtualDevice][0])
};
%1 = %f(%a, %b);
%2 = add(%c, %d);
@@ -418,10 +418,10 @@ def test_func_param_on_cpu():
#[version = "0.0.5"]
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
%c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]],
- result_se_scope=meta[SEScope][0]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0]],
+ result_virtual_device=meta[VirtualDevice][0]) {
let %f = fn (%x, %y,
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) {
add(%x, %y)
};
%0 = %f(%a, %b);
@@ -441,7 +441,7 @@ def test_func_param_on_cpu():
def test_func_result_on_cpu():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
# Devices for call sites flow to function results.
def input():
@@ -454,7 +454,7 @@ def test_func_result_on_cpu():
add(%x, %y)
};
%0 = %f(%a, %b);
- %1 = on_device(%0, se_scope=meta[SEScope][0]);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
%2 = add(%c, %d);
subtract(%1, %2)
}
@@ -470,15 +470,15 @@ def test_func_result_on_cpu():
#[version = "0.0.5"]
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
%c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]],
- result_se_scope=meta[SEScope][1]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][1]],
+ result_virtual_device=meta[VirtualDevice][1]) {
let %f = fn (%x, %y,
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) {
add(%x, %y)
};
%1 = %f(%a, %b);
- %2 = on_device(%1, se_scope=meta[SEScope][0], constrain_result=True);
- %3 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ %2 = on_device(%1, virtual_device=meta[VirtualDevice][0], constrain_result=True);
+ %3 = device_copy(%2, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
%4 = add(%c, %d);
subtract(%3, %4)
}
@@ -495,7 +495,7 @@ def test_func_result_on_cpu():
def test_higher_order():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
# The constraint on %a flows back to %y via %f and %h
def input():
@@ -505,7 +505,7 @@ def test_higher_order():
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) {
let %f = fn (%g) {
fn (%a) {
- %0 = on_device(%a, se_scope=meta[SEScope][0]);
+ %0 = on_device(%a, virtual_device=meta[VirtualDevice][0]);
%1 = %g(%0);
add(%1, %x)
}
@@ -528,15 +528,15 @@ def test_higher_order():
"""
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][1], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) {
- let %f = fn (%g, param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) {
- fn (%a, param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][1]) {
- %0 = device_copy(%a, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) {
+ let %f = fn (%g, param_virtual_devices=[meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][1]) {
+ fn (%a, param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) {
+ %0 = device_copy(%a, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
%1 = %g(%0);
add(%1, %x)
}
};
- let %h = fn (%b, param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) {
+ let %h = fn (%b, param_virtual_devices=[meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][1]) {
negative(%b)
};
%2 = %f(%h);
@@ -562,7 +562,7 @@ def test_higher_order():
def test_function_in_tuple():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
# Since %f ends up in a tuple its argument and result is forced to be on the CPU
def input():
@@ -571,7 +571,7 @@ def test_function_in_tuple():
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) {
let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) {
- %0 = on_device(%b, se_scope=meta[SEScope][0]);
+ %0 = on_device(%b, virtual_device=meta[VirtualDevice][0]);
add(%a, %0)
};
let %t = (%f, %x);
@@ -590,12 +590,12 @@ def test_function_in_tuple():
"""
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) {
let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) {
add(%a, %b)
};
- let %t = on_device((%f, %x), se_scope=meta[SEScope][0], constrain_result=True);
+ let %t = on_device((%f, %x), virtual_device=meta[VirtualDevice][0], constrain_result=True);
%0 = %t.1;
%1 = %t.0;
%1(%0, %y)
@@ -614,14 +614,14 @@ def test_function_in_tuple():
def test_device_copy():
const = rand((5, 7))
- metatable = {"SEScope": [CPU, GPU], "relay.Constant": [relay.const(const)]}
+ metatable = {"VirtualDevice": [CPU, GPU], "relay.Constant": [relay.const(const)]}
def input():
return tvm.parser.parse(
"""
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32]) {
- %0 = device_copy(%x, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ %0 = device_copy(%x, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
add(%0, meta[relay.Constant][0])
}
""",
@@ -635,8 +635,8 @@ def test_device_copy():
"""
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][1]) {
- %0 = device_copy(%x, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) {
+ %0 = device_copy(%x, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
add(%0, meta[relay.Constant][0])
}
""",
@@ -652,7 +652,7 @@ def test_device_copy():
def test_shape_of():
- metatable = {"SEScope": [HOST, GPU]}
+ metatable = {"VirtualDevice": [HOST, GPU]}
# We need to use constrain_result=True in the on_device call so that the tensor will be on the GPU. Otherwise the
# result defaults to the result device for @main which is the CPU, thus forcing a copy.
@@ -662,7 +662,7 @@ def test_shape_of():
"""
#[version = "0.0.5"]
def @main(%x: Tensor[(?, ?), float32]) {
- %0 = on_device(%x, se_scope=meta[SEScope][1], constrain_result=True);
+ %0 = on_device(%x, virtual_device=meta[VirtualDevice][1], constrain_result=True);
vm.shape_of(%0, dtype="int64")
}
""",
@@ -676,7 +676,7 @@ def test_shape_of():
"""
#[version = "0.0.5"]
def @main(%x: Tensor[(?, ?), float32],
- param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][0]) {
+ param_virtual_devices=[meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][0]) {
vm.shape_of(%x, dtype="int64")
}
""",
@@ -692,14 +692,14 @@ def test_shape_of():
def test_alloc_storage():
- metatable = {"SEScope": [HOST, GPU]}
+ metatable = {"VirtualDevice": [HOST, GPU]}
def input():
return tvm.parser.parse(
"""
#[version = "0.0.5"]
def @main(%size: int64, %alignment: int64) {
- memory.alloc_storage(%size, %alignment, se_scope=meta[SEScope][1])
+ memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][1])
}
""",
"from_string",
@@ -712,8 +712,8 @@ def test_alloc_storage():
"""
#[version = "0.0.5"]
def @main(%size: int64, %alignment: int64,
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) {
- memory.alloc_storage(%size, %alignment, se_scope=meta[SEScope][1])
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) {
+ memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][1])
}
""",
"from_string",
@@ -727,7 +727,10 @@ def test_alloc_storage():
def test_alloc_tensor():
shape = np.array([3, 2])
- metatable = {"SEScope": [HOST, GPU], "relay.Constant": [relay.const(shape, dtype="int64")]}
+ metatable = {
+ "VirtualDevice": [HOST, GPU],
+ "relay.Constant": [relay.const(shape, dtype="int64")],
+ }
def input():
return tvm.parser.parse(
@@ -747,9 +750,9 @@ def test_alloc_tensor():
return tvm.parser.parse(
"""
#[version = "0.0.5"]
- def @main(%sto: Storage[], param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) {
- %0 = on_device(0, se_scope=meta[SEScope][0], constrain_result=True);
- %1 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], constrain_result=True);
+ def @main(%sto: Storage[], param_virtual_devices=[meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][1]) {
+ %0 = on_device(0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
+ %1 = on_device(meta[relay.Constant][0], virtual_device=meta[VirtualDevice][0], constrain_result=True);
memory.alloc_tensor(%sto, %0, %1, const_shape=meta[relay.Constant][0], assert_shape=[])
}
""",
@@ -764,7 +767,10 @@ def test_alloc_tensor():
def test_reshape_tensor():
newshape = [2, 4, 2]
- metatable = {"SEScope": [HOST, GPU], "relay.Constant": [relay.const(newshape, dtype="int64")]}
+ metatable = {
+ "VirtualDevice": [HOST, GPU],
+ "relay.Constant": [relay.const(newshape, dtype="int64")],
+ }
def input():
return tvm.parser.parse(
@@ -784,8 +790,8 @@ def test_reshape_tensor():
"""
#[version = "0.0.5"]
def @main(%x: Tensor[(2, 8), float32],
- param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) {
- %0 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], constrain_result=True);
+ param_virtual_devices=[meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][1]) {
+ %0 = on_device(meta[relay.Constant][0], virtual_device=meta[VirtualDevice][0], constrain_result=True);
vm.reshape_tensor(%x, %0, newshape=[2, 4, 2])
}
""",
@@ -801,7 +807,7 @@ def test_reshape_tensor():
def test_dynamic_input():
- metatable = {"SEScope": [GPU]}
+ metatable = {"VirtualDevice": [GPU]}
# There's nothing special about inferring devices for partially unknown types.
def input():
@@ -822,7 +828,7 @@ def test_dynamic_input():
"""
#[version = "0.0.5"]
def @main(%x0: Tensor[(?, ?), float32], %x1: Tensor[(?, ?), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) {
add(%x0, %x1)
}
""",
@@ -838,7 +844,7 @@ def test_dynamic_input():
def test_redundant_annotation():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
def input():
return tvm.parser.parse(
@@ -846,9 +852,9 @@ def test_redundant_annotation():
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) {
%0 = add(%x, %y);
- %1 = on_device(%0, se_scope=meta[SEScope][0]);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
%2 = subtract(%1, %z);
- %3 = on_device(%0, se_scope=meta[SEScope][0]);
+ %3 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
add(%2, %3)
}
""",
@@ -862,14 +868,14 @@ def test_redundant_annotation():
"""
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1]],
- result_se_scope=meta[SEScope][1]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1]],
+ result_virtual_device=meta[VirtualDevice][1]) {
%0 = add(%x, %y);
- %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True);
- %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
- %3 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
+ %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
+ %3 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
%4 = subtract(%2, %z);
- %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ %5 = device_copy(%3, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
add(%4, %5)
}
""",
@@ -886,7 +892,7 @@ def test_redundant_annotation():
def test_annotate_expr():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
def input():
return tvm.parser.parse(
@@ -894,9 +900,9 @@ def test_annotate_expr():
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) {
%0 = add(%x, %y);
- %1 = on_device(%0, se_scope=meta[SEScope][1]);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][1]);
%2 = subtract(%1, %z);
- on_device(%2, se_scope=meta[SEScope][0])
+ on_device(%2, virtual_device=meta[VirtualDevice][0])
}
""",
"from_string",
@@ -909,11 +915,11 @@ def test_annotate_expr():
"""
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][1], meta[SEScope][1], meta[SEScope][0]],
- result_se_scope=meta[SEScope][0]) {
+ param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][1], meta[VirtualDevice][0]],
+ result_virtual_device=meta[VirtualDevice][0]) {
%0 = add(%x, %y);
- %1 = on_device(%0, se_scope=meta[SEScope][1], constrain_result=True);
- %2 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][1], constrain_result=True);
+ %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]);
subtract(%2, %z)
}
""",
@@ -929,7 +935,7 @@ def test_annotate_expr():
def test_annotate_all():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
def input():
return tvm.parser.parse(
@@ -937,9 +943,9 @@ def test_annotate_all():
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) {
%0 = add(%x, %y);
- %1 = on_device(%0, se_scope=meta[SEScope][0]);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
%2 = subtract(%1, %z);
- on_device(%2, se_scope=meta[SEScope][0])
+ on_device(%2, virtual_device=meta[VirtualDevice][0])
}
""",
"from_string",
@@ -952,8 +958,8 @@ def test_annotate_all():
"""
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]],
- result_se_scope=meta[SEScope][0]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0]],
+ result_virtual_device=meta[VirtualDevice][0]) {
%0 = add(%x, %y);
subtract(%0, %z)
}
@@ -982,7 +988,7 @@ def test_conv_network():
|
<result> <--- CPU
"""
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
def input():
return tvm.parser.parse(
@@ -992,12 +998,12 @@ def test_conv_network():
%weight: Tensor[(64, 64, 3, 3), float32]) {
%0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]);
%1 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]);
- %2 = on_device(%0, se_scope=meta[SEScope][0]);
- %3 = on_device(%1, se_scope=meta[SEScope][0]);
+ %2 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
+ %3 = on_device(%1, virtual_device=meta[VirtualDevice][0]);
%4 = add(%2, %3);
- %5 = on_device(%4, se_scope=meta[SEScope][1]);
+ %5 = on_device(%4, virtual_device=meta[VirtualDevice][1]);
%6 = nn.conv2d(%5, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]);
- on_device(%6, se_scope=meta[SEScope][0])
+ on_device(%6, virtual_device=meta[VirtualDevice][0])
}
""",
"from_string",
@@ -1011,17 +1017,17 @@ def test_conv_network():
#[version = "0.0.5"]
def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 56), float32],
%weight: Tensor[(64, 64, 3, 3), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]],
- result_se_scope=meta[SEScope][0]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0]],
+ result_virtual_device=meta[VirtualDevice][0]) {
%0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]);
- %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
%2 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]);
- %3 = on_device(%2, se_scope=meta[SEScope][0], constrain_result=True);
- %4 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
- %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ %3 = on_device(%2, virtual_device=meta[VirtualDevice][0], constrain_result=True);
+ %4 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
+ %5 = device_copy(%3, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
%6 = add(%4, %5);
- %7 = on_device(%6, se_scope=meta[SEScope][1], constrain_result=True);
- %8 = device_copy(%7, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]);
+ %7 = on_device(%6, virtual_device=meta[VirtualDevice][1], constrain_result=True);
+ %8 = device_copy(%7, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]);
nn.conv2d(%8, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3])
}
""",
@@ -1035,7 +1041,7 @@ def test_conv_network():
def test_tuple_get_item():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
# Note that the device copy should be placed after projection rather than before. This is handled by
# a heuristic in the pass.
@@ -1045,12 +1051,12 @@ def test_tuple_get_item():
#[version = "0.0.5"]
def @main(%x: Tensor[(3, 3, 4), float32]) {
let %t = split(%x, indices_or_sections=3);
- %0 = on_device(%t, se_scope=meta[SEScope][0]);
- %1 = on_device(%t, se_scope=meta[SEScope][0]);
+ %0 = on_device(%t, virtual_device=meta[VirtualDevice][0]);
+ %1 = on_device(%t, virtual_device=meta[VirtualDevice][0]);
%2 = %0.0;
%3 = %1.1;
%4 = subtract(%2, %3);
- on_device(%4, se_scope=meta[SEScope][1])
+ on_device(%4, virtual_device=meta[VirtualDevice][1])
}
""",
"from_string",
@@ -1063,15 +1069,15 @@ def test_tuple_get_item():
"""
#[version = "0.0.5"]
def @main(%x: Tensor[(3, 3, 4), float32],
- param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][1]) {
+ param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) {
%0 = split(%x, indices_or_sections=3);
- let %t = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True);
+ let %t = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
%1 = %t.0;
- %2 = on_device(%1, se_scope=meta[SEScope][0], constrain_result=True);
+ %2 = on_device(%1, virtual_device=meta[VirtualDevice][0], constrain_result=True);
%3 = %t.1;
- %4 = on_device(%3, se_scope=meta[SEScope][0], constrain_result=True);
- %5 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
- %6 = device_copy(%4, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ %4 = on_device(%3, virtual_device=meta[VirtualDevice][0], constrain_result=True);
+ %5 = device_copy(%2, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
+ %6 = device_copy(%4, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
subtract(%5, %6)
}
""",
@@ -1101,7 +1107,7 @@ def test_propogation():
|
<result> <--- CPU
"""
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
def input():
return tvm.parser.parse(
@@ -1109,16 +1115,16 @@ def test_propogation():
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32]) {
%0 = negative(%x);
- %1 = on_device(%0, se_scope=meta[SEScope][0]);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
%2 = negative(%1);
- %3 = on_device(%0, se_scope=meta[SEScope][0]);
+ %3 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
%4 = negative(%3);
- %5 = on_device(%2, se_scope=meta[SEScope][1]);
- %6 = on_device(%4, se_scope=meta[SEScope][1]);
+ %5 = on_device(%2, virtual_device=meta[VirtualDevice][1]);
+ %6 = on_device(%4, virtual_device=meta[VirtualDevice][1]);
%7 = add(%5, %6);
- %8 = on_device(%7, se_scope=meta[SEScope][1]);
+ %8 = on_device(%7, virtual_device=meta[VirtualDevice][1]);
%9 = negative(%8);
- on_device(%9, se_scope=meta[SEScope][0])
+ on_device(%9, virtual_device=meta[VirtualDevice][0])
}
""",
"from_string",
@@ -1131,17 +1137,17 @@ def test_propogation():
"""
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) {
+ param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) {
%0 = negative(%x);
- %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True);
- %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
- %3 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True);
- %4 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
+ %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
+ %3 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
+ %4 = device_copy(%3, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
%5 = negative(%2);
%6 = negative(%4);
%7 = add(%5, %6);
- %8 = on_device(%7, se_scope=meta[SEScope][1], constrain_result=True);
- %9 = device_copy(%8, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]);
+ %8 = on_device(%7, virtual_device=meta[VirtualDevice][1], constrain_result=True);
+ %9 = device_copy(%8, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]);
negative(%9)
}
""",
@@ -1173,7 +1179,7 @@ def test_fusible_network():
|
<result> <--- CPU
"""
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
def input():
return tvm.parser.parse(
@@ -1181,14 +1187,14 @@ def test_fusible_network():
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) {
%0 = add(%x, %y);
- %1 = on_device(%0, se_scope=meta[SEScope][1]);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][1]);
%2 = negative(%1);
- %3 = on_device(%2, se_scope=meta[SEScope][0]);
+ %3 = on_device(%2, virtual_device=meta[VirtualDevice][0]);
%4 = negative(%0);
%5 = add(%3, %4);
- %6 = on_device(%5, se_scope=meta[SEScope][1]);
+ %6 = on_device(%5, virtual_device=meta[VirtualDevice][1]);
%7 = negative(%6);
- on_device(%7, se_scope=meta[SEScope][0])
+ on_device(%7, virtual_device=meta[VirtualDevice][0])
}
""",
"from_string",
@@ -1201,17 +1207,17 @@ def test_fusible_network():
"""
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][1], meta[SEScope][1]], result_se_scope=meta[SEScope][0]) {
+ param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][0]) {
%0 = add(%x, %y);
- %1 = on_device(%0, se_scope=meta[SEScope][1], constrain_result=True);
- %2 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][1], constrain_result=True);
+ %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]);
%3 = negative(%2);
- %4 = on_device(%3, se_scope=meta[SEScope][0], constrain_result=True);
- %5 = device_copy(%4, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ %4 = on_device(%3, virtual_device=meta[VirtualDevice][0], constrain_result=True);
+ %5 = device_copy(%4, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
%6 = negative(%0);
%7 = add(%5, %6);
- %8 = on_device(%7, se_scope=meta[SEScope][1], constrain_result=True);
- %9 = device_copy(%8, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]);
+ %8 = on_device(%7, virtual_device=meta[VirtualDevice][1], constrain_result=True);
+ %9 = device_copy(%8, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]);
negative(%9)
}
""",
@@ -1241,7 +1247,7 @@ def test_unpropagatable_graph():
|
<result> <--- CPU
"""
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
def input():
return tvm.parser.parse(
@@ -1251,10 +1257,10 @@ def test_unpropagatable_graph():
%c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) {
%0 = add(%a, %b);
%1 = multiply(%c, %d);
- %2 = on_device(%0, se_scope=meta[SEScope][0]);
- %3 = on_device(%1, se_scope=meta[SEScope][1]);
+ %2 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
+ %3 = on_device(%1, virtual_device=meta[VirtualDevice][1]);
%4 = subtract(%2, %3);
- on_device(%4, se_scope=meta[SEScope][0])
+ on_device(%4, virtual_device=meta[VirtualDevice][0])
}
""",
"from_string",
@@ -1268,12 +1274,12 @@ def test_unpropagatable_graph():
#[version = "0.0.5"]
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
%c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]],
- result_se_scope=meta[SEScope][0]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][1]],
+ result_virtual_device=meta[VirtualDevice][0]) {
%0 = multiply(%c, %d);
- %1 = on_device(%0, se_scope=meta[SEScope][1], constrain_result=True);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][1], constrain_result=True);
%2 = add(%a, %b);
- %3 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]);
+ %3 = device_copy(%1, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]);
subtract(%2, %3)
}
""",
@@ -1289,7 +1295,7 @@ def test_unpropagatable_graph():
def test_conditional():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
# The conditional is over a function type, thus exercising the first-order/higher-order domain handling.
def input():
@@ -1298,7 +1304,7 @@ def test_conditional():
#[version = "0.0.5"]
def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) {
let %f = fn (%a) {
- %0 = on_device(%y, se_scope=meta[SEScope][0], constrain_result=True);
+ %0 = on_device(%y, virtual_device=meta[VirtualDevice][0], constrain_result=True);
add(%a, %0)
};
let %g = fn (%a1) {
@@ -1322,19 +1328,19 @@ def test_conditional():
"""
#[version = "0.0.5"]
def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]],
- result_se_scope=meta[SEScope][0]) {
- let %f = fn (%a, param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0]],
+ result_virtual_device=meta[VirtualDevice][0]) {
+ let %f = fn (%a, param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) {
add(%a, %y)
};
- let %g = fn (%a1, param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) {
+ let %g = fn (%a1, param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) {
subtract(%a1, %y)
};
let %h = on_device(if (%x) {
%f
} else {
%g
- }, se_scope=meta[SEScope][0], constrain_result=True);
+ }, virtual_device=meta[VirtualDevice][0], constrain_result=True);
%h(%z)
}
""",
@@ -1357,14 +1363,14 @@ def test_conditional():
def test_global():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
def input():
return tvm.parser.parse(
"""
#[version = "0.0.5"]
def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
- %0 = on_device(%b, se_scope=meta[SEScope][0]);
+ %0 = on_device(%b, virtual_device=meta[VirtualDevice][0]);
add(%a, %0)
}
@@ -1382,15 +1388,15 @@ def test_global():
"""
#[version = "0.0.5"]
def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][1], meta[SEScope][0]],
- result_se_scope=meta[SEScope][1]) -> Tensor[(5, 7), float32] {
- %0 = device_copy(%b, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
+ param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][0]],
+ result_virtual_device=meta[VirtualDevice][1]) -> Tensor[(5, 7), float32] {
+ %0 = device_copy(%b, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
add(%a, %0)
}
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][1]],
- result_se_scope=meta[SEScope][1]) -> Tensor[(5, 7), float32] {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][1]],
+ result_virtual_device=meta[VirtualDevice][1]) -> Tensor[(5, 7), float32] {
@f(%y, %x)
}
""",
@@ -1409,7 +1415,7 @@ def test_global():
def test_ref():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
def input():
return tvm.parser.parse(
@@ -1417,7 +1423,7 @@ def test_ref():
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) {
let %r = ref(%x);
- %0 = on_device(%y, se_scope=meta[SEScope][0]);
+ %0 = on_device(%y, virtual_device=meta[VirtualDevice][0]);
ref_write(%r, %0);
%1 = ref_read(%r);
add(%x, %1)
@@ -1433,10 +1439,10 @@ def test_ref():
"""
#[version = "0.0.5"]
def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][1], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) {
- let %r = on_device(ref(%x), se_scope=meta[SEScope][1], constrain_result=True);
- %0 = device_copy(%y, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]);
- on_device(ref_write(%r, %0), se_scope=meta[SEScope][1], constrain_result=True);
+ param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) {
+ let %r = on_device(ref(%x), virtual_device=meta[VirtualDevice][1], constrain_result=True);
+ %0 = device_copy(%y, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
+ on_device(ref_write(%r, %0), virtual_device=meta[VirtualDevice][1], constrain_result=True);
%1 = ref_read(%r);
add(%x, %1)
}
@@ -1456,7 +1462,7 @@ def test_ref():
def test_adt():
- metatable = {"SEScope": [CPU, GPU]}
+ metatable = {"VirtualDevice": [CPU, GPU]}
def input():
return tvm.parser.parse(
@@ -1467,7 +1473,7 @@ def test_adt():
Nil,
}
def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32]) {
- %0 = on_device(%y, se_scope=meta[SEScope][0], constrain_result=True);
+ %0 = on_device(%y, virtual_device=meta[VirtualDevice][0], constrain_result=True);
%1 = Nil;
%2 = Cons(%0, %1);
let %l = Cons(%x, %2);
@@ -1490,10 +1496,10 @@ def test_adt():
Nil,
}
def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) {
%0 = Nil;
%1 = Cons(%y, %0);
- let %l = on_device(Cons(%x, %1), se_scope=meta[SEScope][0], constrain_result=True);
+ let %l = on_device(Cons(%x, %1), virtual_device=meta[VirtualDevice][0], constrain_result=True);
match? (%l) {
Cons(%z, _) => %z
}
@@ -1516,7 +1522,7 @@ def test_free_on_device():
a device_copy to be inserted if necessary, but otherwise does not prevent the flow of
device information."""
metatable = {
- "SEScope": [
+ "VirtualDevice": [
CPU, # no memory scope constraint
CPU_SCOPE_A, # constrain to scopeA
CPU_SCOPE_B,
@@ -1529,22 +1535,22 @@ def test_free_on_device():
"""
#[version = "0.0.5"]
def @on_scope_b(%x: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][2]],
- result_se_scope=meta[SEScope][2]) -> Tensor[(5, 7), float32] {
+ param_virtual_devices=[meta[VirtualDevice][2]],
+ result_virtual_device=meta[VirtualDevice][2]) -> Tensor[(5, 7), float32] {
%x
}
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][1], meta[SEScope][2]],
- result_se_scope=meta[SEScope][1]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][2]],
+ result_virtual_device=meta[VirtualDevice][1]) {
// %a's memory scope is unconstrained, so will take on "scopeB" and on_device has no effect
- %0 = @on_scope_b(on_device(%a, se_scope=meta[SEScope][0], constrain_body=False));
+ %0 = @on_scope_b(on_device(%a, virtual_device=meta[VirtualDevice][0], constrain_body=False));
// %b's memory scope is "scopeA", so will require a "scopeA"->"scopeB" copy.
- %1 = @on_scope_b(on_device(%b, se_scope=meta[SEScope][0], constrain_body=False));
+ %1 = @on_scope_b(on_device(%b, virtual_device=meta[VirtualDevice][0], constrain_body=False));
// %c's memory scope is "scopeB", so no copy required.
- %2 = @on_scope_b(on_device(%c, se_scope=meta[SEScope][0], constrain_body=False));
+ %2 = @on_scope_b(on_device(%c, virtual_device=meta[VirtualDevice][0], constrain_body=False));
                // result's memory scope is on "scopeA", so will require a "scopeB"->"scopeA" copy.
%3 = add(add(%0, %1), %2);
- on_device(%3, se_scope=meta[SEScope][0], constrain_body=False)
+ on_device(%3, virtual_device=meta[VirtualDevice][0], constrain_body=False)
}
""",
"from_string",
@@ -1557,20 +1563,20 @@ def test_free_on_device():
"""
#[version = "0.0.5"]
def @on_scope_b(%x: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][2]],
- result_se_scope=meta[SEScope][2]) -> Tensor[(5, 7), float32] {
+ param_virtual_devices=[meta[VirtualDevice][2]],
+ result_virtual_device=meta[VirtualDevice][2]) -> Tensor[(5, 7), float32] {
%x
}
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32],
- param_se_scopes=[meta[SEScope][2], meta[SEScope][1], meta[SEScope][2]],
- result_se_scope=meta[SEScope][1]) {
+ param_virtual_devices=[meta[VirtualDevice][2], meta[VirtualDevice][1], meta[VirtualDevice][2]],
+ result_virtual_device=meta[VirtualDevice][1]) {
%0 = @on_scope_b(%a);
- %1 = device_copy(%b, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][2]);
+ %1 = device_copy(%b, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][2]);
%2 = @on_scope_b(%1);
%3 = @on_scope_b(%c);
%4 = add(add(%0, %2), %3);
- %5 = on_device(%4, se_scope=meta[SEScope][2], constrain_result=True);
- device_copy(%5, src_se_scope=meta[SEScope][2], dst_se_scope=meta[SEScope][1])
+ %5 = on_device(%4, virtual_device=meta[VirtualDevice][2], constrain_result=True);
+ device_copy(%5, src_virtual_device=meta[VirtualDevice][2], dst_virtual_device=meta[VirtualDevice][1])
}
""",
"from_string",
@@ -1616,12 +1622,12 @@ def test_lowered():
D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk]
metatable = {
- "SEScope": [
- CPU, # meta[SEScope][0], no memory scope
- CPU_SCOPE_A, # meta[SEScope][1], "scopeA"
+ "VirtualDevice": [
+ CPU, # meta[VirtualDevice][0], no memory scope
+ CPU_SCOPE_A, # meta[VirtualDevice][1], "scopeA"
CPU_SCOPE_B,
]
- } # meta[SEScope][2], "scopeB"
+ } # meta[VirtualDevice][2], "scopeB"
gem_ty = relay.FuncType(
[
relay.TensorType((128, 128), "float32"),
@@ -1645,8 +1651,8 @@ def test_lowered():
def @main(%x : Tensor[(128, 128), float32],
%y : Tensor[(128, 128), float32],
%z : Tensor[(128, 128), float32],
- param_se_scopes=[meta[SEScope][0], meta[SEScope][2], meta[SEScope][1]],
- result_se_scope=meta[SEScope][2]) {
+ param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][2], meta[VirtualDevice][1]],
+ result_virtual_device=meta[VirtualDevice][2]) {
call_lowered(@gem, (%x, %y, %z))
}
""",
@@ -1668,13 +1674,13 @@ def test_lowered():
def @main(%x : Tensor[(128, 128), float32],
%y : Tensor[(128, 128), float32],
%z : Tensor[(128, 128), float32],
- param_se_scopes=[meta[SEScope][1], meta[SEScope][2], meta[SEScope][1]],
- result_se_scope=meta[SEScope][2]) {
- %0 = device_copy(%z, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][2]);
- %1 = on_device(%0, se_scope=meta[SEScope][2], constrain_result=True);
+ param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][2], meta[VirtualDevice][1]],
+ result_virtual_device=meta[VirtualDevice][2]) {
+ %0 = device_copy(%z, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][2]);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][2], constrain_result=True);
%2 = call_lowered(@gem, (%x, %y, %1));
- %3 = on_device(%2, se_scope=meta[SEScope][1], constrain_result=True);
- device_copy(%3, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][2])
+ %3 = on_device(%2, virtual_device=meta[VirtualDevice][1], constrain_result=True);
+ device_copy(%3, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][2])
}
""",
"from_string",
diff --git a/tests/python/target/test_se_scope.py b/tests/python/target/test_virtual_device.py
similarity index 54%
rename from tests/python/target/test_se_scope.py
rename to tests/python/target/test_virtual_device.py
index 0a9384f..eec77bc 100644
--- a/tests/python/target/test_se_scope.py
+++ b/tests/python/target/test_virtual_device.py
@@ -19,30 +19,30 @@ import pytest
import tvm
-def test_make_se_scope_for_device():
- se_scope = tvm.target.make_se_scope(tvm.device("cuda"))
- assert se_scope.device_type == 2
+def test_make_virtual_device_for_device():
+ virtual_device = tvm.target.make_virtual_device(tvm.device("cuda"))
+ assert virtual_device.device_type == 2
# ie kDLCUDA
- assert se_scope.virtual_device_id == 0
- assert se_scope.target is None
- assert se_scope.memory_scope == ""
+ assert virtual_device.virtual_device_id == 0
+ assert virtual_device.target is None
+ assert virtual_device.memory_scope == ""
-def test_make_se_scope_for_device_and_target():
+def test_make_virtual_device_for_device_and_target():
target = tvm.target.Target("cuda")
- se_scope = tvm.target.make_se_scope(tvm.device("cuda"), target)
- assert se_scope.device_type == 2 # ie kDLCUDA
- assert se_scope.target == target
- assert se_scope.memory_scope == ""
+ virtual_device = tvm.target.make_virtual_device(tvm.device("cuda"), target)
+ assert virtual_device.device_type == 2 # ie kDLCUDA
+ assert virtual_device.target == target
+ assert virtual_device.memory_scope == ""
-def test_make_se_scope_for_device_target_and_memory_scope():
+def test_make_virtual_device_for_device_target_and_memory_scope():
target = tvm.target.Target("cuda")
scope = "local"
- se_scope = tvm.target.make_se_scope(tvm.device("cuda"), target, scope)
- assert se_scope.device_type == 2 # ie kDLCUDA
- assert se_scope.target == target
- assert se_scope.memory_scope == scope
+ virtual_device = tvm.target.make_virtual_device(tvm.device("cuda"), target, scope)
+ assert virtual_device.device_type == 2 # ie kDLCUDA
+ assert virtual_device.target == target
+ assert virtual_device.memory_scope == scope
if __name__ == "__main__":