Posted to commits@tvm.apache.org by ma...@apache.org on 2021/01/29 15:17:00 UTC
[tvm] branch main updated: [Refactor][VM] Port memory_alloc to c++ (#7369)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 44a071a [Refactor][VM] Port memory_alloc to c++ (#7369)
44a071a is described below
commit 44a071aa1e9ad11c20fbfcf725ddb6dd8a2823c4
Author: Zhi <51...@users.noreply.github.com>
AuthorDate: Fri Jan 29 07:16:45 2021 -0800
[Refactor][VM] Port memory_alloc to c++ (#7369)
* Port memory_alloc to c++
* remove memory python pass
---
include/tvm/relay/transform.h | 12 +
python/tvm/relay/__init__.py | 1 -
python/tvm/relay/transform/__init__.py | 1 -
python/tvm/relay/transform/memory_alloc.py | 389 -----------------------
src/relay/backend/vm/compiler.cc | 6 -
src/relay/transforms/memory_alloc.cc | 494 +++++++++++++++++++++++++++++
tests/python/relay/test_any.py | 1 -
tests/python/relay/test_memory_passes.py | 1 -
8 files changed, 506 insertions(+), 399 deletions(-)
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index e4b39da..123b7e3 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -31,6 +31,7 @@
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/container.h>
+#include <tvm/target/target.h>
#include <string>
@@ -419,6 +420,17 @@ TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
*/
TVM_DLL Pass SimplifyExpr();
+/*!
+ * \brief A pass for manifesting explicit memory allocations and rewriting
+ * specific dialects.
+ *
+ * \param target_host The target used by the host for compilation.
+ * \param targets The device type and target pairs for compilation.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass ManifestAlloc(Target target_host, Map<tvm::Integer, tvm::Target> targets);
+
} // namespace transform
/*!
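The declaration above is the new first-class C++ entry point for the pass. As a minimal usage sketch (not part of this commit: the wrapper name, target strings, and single-device target map are assumptions for illustration), it can be driven like any other Relay pass:

#include <dlpack/dlpack.h>
#include <tvm/ir/module.h>
#include <tvm/relay/transform.h>
#include <tvm/target/target.h>

tvm::IRModule ApplyManifestAlloc(tvm::IRModule mod) {
  // Assumed single-device setup: host and device both target LLVM on CPU.
  tvm::Target target_host("llvm");
  tvm::Map<tvm::Integer, tvm::Target> targets;
  targets.Set(tvm::Integer(static_cast<int>(kDLCPU)), tvm::Target("llvm"));
  // Construct and apply the pass; the returned module carries explicit
  // alloc_storage/alloc_tensor/invoke_tvm_op calls.
  tvm::relay::transform::Pass pass =
      tvm::relay::transform::ManifestAlloc(target_host, targets);
  return pass(mod);
}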
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 97f6d1c..89c8fcb 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -61,7 +61,6 @@ from . import qnn
from .scope_builder import ScopeBuilder
# Load Memory Passes
-from .transform import memory_alloc
from .transform import memory_plan
# Required to traverse large programs
diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py
index 1d0ea17..ca9996a 100644
--- a/python/tvm/relay/transform/__init__.py
+++ b/python/tvm/relay/transform/__init__.py
@@ -19,4 +19,3 @@
# transformation passes
from .transform import *
from .recast import recast
-from . import memory_alloc
diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py
deleted file mode 100644
index 66528c8..0000000
--- a/python/tvm/relay/transform/memory_alloc.py
+++ /dev/null
@@ -1,389 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
-"""
-A pass for manifesting explicit memory allocations.
-"""
-import numpy as np
-
-from tvm.ir.transform import PassContext, module_pass
-from tvm.relay.transform import InferType
-from tvm import nd, container
-from ..function import Function
-from ..expr_functor import ExprVisitor, ExprMutator
-from ..scope_builder import ScopeBuilder
-from .. import op
-from ... import DataType, register_func
-from .. import ty, expr
-from ..backend import compile_engine
-from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type
-from ... import cpu
-from ..op.memory import alloc_storage
-from ..analysis import context_analysis
-from ..._ffi.runtime_ctypes import TVMContext
-
-
-def alloc_tensor(storage, shape, dtype="float32", assert_shape=None):
- offset = expr.const(0, dtype="int64")
- return op.memory.alloc_tensor(storage, offset, shape, dtype, assert_shape)
-
-
-def is_primitive(call):
- return (
- hasattr(call, "op")
- and hasattr(call.op, "attrs")
- and hasattr(call.op.attrs, "Primitive")
- and int(call.op.attrs.Primitive) == 1
- )
-
-
-def is_device_copy(func):
- """
- Check if the current relay expression is a device copy call. We can simply check
- the body of it if it is a function because the device_copy op is opaque.
- """
- if isinstance(func, Function):
- body = func.body
- return isinstance(body, expr.Call) and body.op == op.get("device_copy")
- if isinstance(func, expr.Call):
- return func.op == op.get("device_copy")
- return False
-
-
-class CheckReshapeOnly(ExprVisitor):
- """A pass to check if the fused op contains only reshape ops."""
-
- def __init__(self):
- super().__init__()
- self._reshape_ops = [
- op.get("reshape"),
- op.get("contrib_reverse_reshape"),
- op.get("dyn.reshape"),
- ]
- self.reshape_only = True
-
- def visit_call(self, call):
- if not self.reshape_only:
- return
- if call.op not in self._reshape_ops:
- self.reshape_only = False
- for arg in call.args:
- self.visit(arg)
-
- def visit_var(self, var):
- var_type = var.checked_type
- if not isinstance(var_type, ty.TensorType):
- self.reshape_only = False
-
-
-def is_reshape_only(func):
- """Check if the primitive function contains only reshape ops."""
- check = CheckReshapeOnly()
- check.visit(func)
- return check.reshape_only
-
-
-class ManifestAllocPass(ExprMutator):
- """A pass for explicitly manifesting all memory allocations in Relay."""
-
- def __init__(self, target_host, context_analysis_map):
- self.invoke_tvm = op.vm.invoke_tvm_op
- self.shape_func = op.vm.shape_func
- self.shape_of = op.vm.shape_of
- self.reshape_tensor = op.vm.reshape_tensor
- self.scopes = [ScopeBuilder()]
- self.target_host = target_host
- self.default_context = cpu(0)
- self.compute_dtype = "int64"
- self.context_analysis_map = context_analysis_map
- super().__init__()
-
- def get_context(self, exp):
- """Get the context of a given expression"""
- assert exp in self.context_analysis_map, exp.astext(False)
- val = self.context_analysis_map[exp]
- # val[0], val[1] are device_type and device_id, respectively.
- # We don't need to unpack after porting this pass to C++.
- assert len(val) == 2
- return TVMContext(val[0].value, val[1].value)
-
- def device_copy(self, inp, src_ctx, dst_ctx):
- """Insert a device copy node."""
- return self.visit(op.tensor.device_copy(inp, src_ctx, dst_ctx))
-
- def current_scope(self):
- return self.scopes[-1]
-
- def visit_tuple(self, tup):
- scope = self.current_scope()
- new_fields = []
- for field in tup.fields:
- field = self.visit(field)
- if isinstance(field, expr.Constant):
- field = scope.let("const", field)
- new_fields.append(field)
- return expr.Tuple(new_fields)
-
- def compute_alignment(self, dtype):
- dtype = DataType(dtype)
- align = (dtype.bits // 8) * dtype.lanes
- # MAGIC CONSTANT FROM device_api.h
- if align < 64:
- align = 64
-
- return expr.const(align, dtype="int64")
-
- def compute_storage_in_relay(self, shape, dtype):
- dtype = DataType(dtype)
- els = op.prod(shape)
- num = expr.const(dtype.bits * dtype.lanes, self.compute_dtype)
- num = num + expr.const(7, self.compute_dtype)
- div = expr.const(8, self.compute_dtype)
- return els * (num / div)
-
- def compute_storage(self, tensor_type):
- dtype = DataType(tensor_type.dtype)
- shape = [int(sh) for sh in tensor_type.shape]
- size = 1
- for sh in shape:
- size *= sh
- size *= (dtype.bits * dtype.lanes + 7) // 8
- return expr.const(size, dtype=self.compute_dtype)
-
- def make_static_allocation(self, scope, tensor_type, ctx, name_hint):
- """Allocate a tensor with a statically known shape."""
- shape = [int(sh) for sh in tensor_type.shape]
- if len(shape) == 0:
- shape = expr.const(np.empty((), dtype=self.compute_dtype), dtype=self.compute_dtype)
- else:
- shape = expr.const(np.array(shape), dtype=self.compute_dtype)
- size = self.compute_storage(tensor_type)
- alignment = self.compute_alignment(tensor_type.dtype)
- dtype = tensor_type.dtype
- sto = scope.let("storage_{0}".format(name_hint), alloc_storage(size, alignment, ctx, dtype))
- # TODO(@jroesch): There is a bug with typing based on the constant shape.
- tensor = alloc_tensor(sto, shape, dtype, tensor_type.shape)
- return scope.let("tensor_{0}".format(name_hint), tensor)
-
- def visit_let(self, let):
- scope = ScopeBuilder()
-
- self.scopes.append(scope)
- while isinstance(let, expr.Let):
- new_val = self.visit(let.value)
- scope.let(let.var, new_val)
- let = let.body
-
- new_body = self.visit(let)
- scope.ret(new_body)
- self.scopes.pop()
-
- return scope.get()
-
- def emit_shape_func(self, scope, func, new_args):
- """Insert the shape function given a primitive function."""
- shape_func_ins = []
- engine = compile_engine.get()
- cfunc = engine.lower_shape_func(func, self.target_host)
- input_states = cfunc.shape_func_param_states
-
- is_inputs = []
- input_pos = 0
- cpu_ctx = nd.cpu(0)
- for i, (arg, state) in enumerate(zip(new_args, input_states)):
- state = int(state)
- # Pass Shapes
- if state == 2:
- for j, subexp in enumerate(from_tuple_type(arg.type_annotation, arg)):
- sh_of = self.visit(self.shape_of(subexp))
- shape_func_ins.append(scope.let("in_shape_{0}".format(input_pos + j), sh_of))
- input_pos += 1
- is_inputs.append(0)
- # Pass Inputs
- elif state == 1:
- new_arg = self.visit(arg)
- ctx = self.get_context(arg)
- if ctx.device_type != cpu_ctx.device_type:
- new_arg = self.device_copy(new_arg, ctx, cpu_ctx)
- shape_func_ins.append(scope.let("in_shape_{0}".format(input_pos), new_arg))
- input_pos += 1
- is_inputs.append(1)
- else:
- # TODO(@jroesch): handle 3rd case
- raise Exception("unsupported shape function input state")
-
- out_shapes = []
- for i, out in enumerate(cfunc.outputs):
- tt = ty.TensorType(out.shape, out.dtype)
- # Put shape func on CPU. This also ensures that everything between
- # shape_of and shape_func is on CPU.
- alloc = self.make_static_allocation(scope, tt, cpu_ctx, i)
- alloc = scope.let("shape_func_out_{0}".format(i), alloc)
- out_shapes.append(alloc)
-
- shape_call = self.shape_func(
- func, expr.Tuple(shape_func_ins), expr.Tuple(out_shapes), is_inputs
- )
-
- scope.let("shape_func", shape_call)
- return out_shapes
-
- def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type):
- """Generate the code for invoking a TVM op with a dynamic shape."""
- out_shapes = self.emit_shape_func(scope, func, new_args)
-
- storages = []
- func_ctx = self.get_context(func)
- for i, (out_shape, out_type) in enumerate(zip(out_shapes, out_types)):
- size = self.compute_storage_in_relay(out_shape, out_type.dtype)
- alignment = self.compute_alignment(out_type.dtype)
- sto = scope.let(
- "storage_{i}".format(i=i), alloc_storage(size, alignment, func_ctx, out_type.dtype)
- )
- storages.append(sto)
-
- outs = []
- sh_ty_storage = zip(out_shapes, out_types, storages)
- for i, (out_shape, out_type, storage) in enumerate(sh_ty_storage):
- alloc = alloc_tensor(storage, out_shape, out_type.dtype, out_type.shape)
- alloc = scope.let("out_{i}".format(i=i), alloc)
- outs.append(alloc)
-
- tuple_outs = expr.Tuple(outs)
- invoke = self.invoke_tvm(func, ins, tuple_outs)
- scope.let("", invoke)
- return to_tuple_type(ret_type, tuple_outs.fields)
-
- def emit_reshape_tensor(self, scope, func, new_args, ret_type):
- if self.is_dynamic(ret_type):
- out_shapes = self.emit_shape_func(scope, func, new_args)
- shape_expr = out_shapes[0]
- else:
- # constant output shape
- shape = [int(dim) for dim in ret_type.shape]
- shape_expr = expr.const(shape, dtype=self.compute_dtype)
- return self.reshape_tensor(new_args[0], shape_expr, ret_type.shape)
-
- def is_dynamic(self, ret_type):
- is_dynamic = ty.is_dynamic(ret_type)
- # TODO(@jroesch): restore this code, more complex than it seems
- # for arg in call.args:
- # is_dynamic = is_dynamic or arg.checked_type.is_dynamic()
- return is_dynamic
-
- def visit_call(self, call):
- if is_primitive(call):
- # Because we are in ANF we do not need to visit the arguments.
- scope = self.current_scope()
- new_args = [self.visit(arg) for arg in call.args]
-
- ins = expr.Tuple(new_args)
- ret_type = call.checked_type
- out_types = flatten_tuple_type(ret_type)
-
- if is_reshape_only(call.op):
- # Handle fused op that only contains reshape op
- return self.emit_reshape_tensor(scope, call.op, new_args, ret_type)
-
- if is_device_copy(call.op):
- # Handle device copy op
- if isinstance(call.op, Function):
- attr = call.op.body.attrs
- else:
- attr = call.attr
- return self.device_copy(
- new_args[0], TVMContext(attr.src_dev_type, 0), TVMContext(attr.dst_dev_type, 0)
- )
-
- if self.is_dynamic(ret_type):
- # Handle dynamic case.
- return self.dynamic_invoke(scope, call.op, ins, new_args, out_types, ret_type)
-
- # Handle static case.
- outs = []
- for i, out_ty in enumerate(out_types):
- ctx = self.get_context(call)
- assert isinstance(ctx, TVMContext)
- out = self.make_static_allocation(scope, out_ty, ctx, i)
- outs.append(out)
-
- output = expr.Tuple(outs)
- invoke = self.invoke_tvm(call.op, ins, output)
- scope.let("", invoke)
- return to_tuple_type(ret_type, output.fields)
- return super().visit_call(call)
-
-
-def mk_analysis_annotator(results):
- """Pretty print the annotated relay program with device info"""
-
- def _annotator(exp):
- if exp in results:
- val = results[exp]
- assert len(val) == 2
- ctx = TVMContext(val[0].value, val[1].value)
- return f"<{ctx}>"
- else:
- return ""
-
- return _annotator
-
-
-@module_pass(opt_level=0)
-class ManifestAlloc:
- """The explicit pass wrapper around ManifestAlloc."""
-
- # TODO(zhiics, jroesch) Port this pass to C++.
- def __init__(self, target_host, targets):
- self.target_host = target_host
- self.targets = targets
-
- def transform_module(self, mod, _):
- """Invokes the pass"""
- # TODO(@jroesch): Is there a way to do one shot initialization?
- # can we have def pass_init?
- mod.import_from_std("core.rly")
- mod = InferType()(mod)
-
- assert isinstance(self.targets, (dict, container.Map))
- if len(self.targets) > 1:
- pass_ctx = PassContext.current()
- if "relay.fallback_device_type" in pass_ctx.config:
- fallback_ctx = nd.context(pass_ctx.config["relay.fallback_device_type"])
- else:
- fallback_ctx = cpu(0)
- ca = context_analysis(mod, TVMContext(fallback_ctx.device_type, 0))
- else:
- if isinstance(self.targets, dict):
- dev = list(self.targets.keys())[0]
- else:
- dev, _ = self.targets.items()[0]
- ca = context_analysis(mod, nd.context(dev.value))
-
- # The following code can be used for debugging the module after
- # annotation.
- # print(mod.astext(show_meta_data=False, annotate=mk_analysis_annotator(ca)))
-
- gv_funcs = mod.functions
- for gv, f in gv_funcs.items():
- ea = ManifestAllocPass(self.target_host, ca)
- f = ea.visit(f)
- mod.update_func(gv, f)
- return mod
-
-
-register_func("relay.transform.ManifestAlloc", ManifestAlloc)
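The byte-size and alignment arithmetic in the deleted compute_storage and compute_alignment helpers carries over unchanged to the C++ port below. As a standalone illustration (plain C++ with no TVM types; the function names are ours), the computation rounds bits up to whole bytes and clamps alignment to the 64-byte minimum noted in device_api.h:

#include <cstdint>
#include <vector>

// Total bytes for a statically shaped tensor: element count times the
// per-element size, with bits rounded up to whole bytes.
int64_t ComputeStorageBytes(const std::vector<int64_t>& shape, int bits, int lanes) {
  int64_t size = 1;
  for (int64_t dim : shape) size *= dim;
  size *= (bits * lanes + 7) / 8;
  return size;
}

// Alignment is the element size in bytes, clamped to a 64-byte minimum
// (the "MAGIC CONSTANT FROM device_api.h" in the pass).
int64_t ComputeAlignment(int bits, int lanes) {
  int64_t align = bits / 8 * lanes;
  return align < 64 ? 64 : align;
}

// Example: a [2, 3] float32 tensor needs 6 * 4 = 24 bytes, 64-byte aligned.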
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index d908153..7861502 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -58,12 +58,6 @@ namespace transform {
Pass LambdaLift();
Pass InlinePrimitives();
-Pass ManifestAlloc(Target target_host, vm::TargetsMap targets) {
- auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
- ICHECK(f != nullptr) << "unable to load allocation manifestation pass";
- return (*f)(target_host, targets);
-}
-
Pass MemoryPlan() {
auto f = tvm::runtime::Registry::Get("relay.transform.MemoryPlan");
ICHECK(f != nullptr) << "unable to load the memory planning pass";
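Before this commit the VM compiler reached the Python pass indirectly through the runtime registry, as the wrapper deleted above shows. The same lookup still works against the TVM_REGISTER_GLOBAL registration added in memory_alloc.cc below; a minimal sketch of that indirection (the wrapper name is ours):

#include <tvm/relay/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/support/logging.h>

tvm::relay::transform::Pass GetManifestAllocViaRegistry(
    tvm::Target target_host, tvm::Map<tvm::Integer, tvm::Target> targets) {
  // The global registered in memory_alloc.cc makes the pass constructor
  // reachable by name at runtime, e.g. from Python or other frontends.
  const auto* f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
  ICHECK(f != nullptr) << "relay.transform.ManifestAlloc is not registered";
  return (*f)(target_host, targets);
}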
diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc
new file mode 100644
index 0000000..360778e
--- /dev/null
+++ b/src/relay/transforms/memory_alloc.cc
@@ -0,0 +1,494 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/memory_alloc.cc
+ * \brief A pass for manifesting explicit memory allocations.
+ */
+
+#include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/transform.h>
+#include <tvm/support/logging.h>
+#include <tvm/target/target.h>
+
+#include <cstdint>
+#include <cstdio>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "../backend/compile_engine.h"
+#include "let_list.h"
+#include "pattern_utils.h"
+
+using namespace tvm::runtime;
+
+namespace tvm {
+namespace relay {
+
+extern Expr ToTupleType(const Type& ty, const std::vector<Expr>& exprs);
+extern std::vector<Expr> FromTupleType(const Type& type, const Expr& expr);
+extern std::vector<TensorType> FlattenTupleType(const Type& type);
+
+using AnalysisResultMap =
+ std::unordered_map<Expr, TVMContext, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
+
+inline Constant MakeConstant(const std::vector<int64_t>& value) {
+ return MakeConstantTensor(DataType::Int(64), {static_cast<int64_t>(value.size())}, value);
+}
+
+inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype,
+ Array<IndexExpr> assert_shape) {
+ auto f = runtime::Registry::Get("relay.op.memory._make.alloc_tensor");
+ CHECK(f != nullptr) << "unable to find alloc_tensor op";
+ auto offset = MakeConstantScalar(DataType::Int(64), 0);
+ return (*f)(storage, offset, shape, dtype, assert_shape);
+}
+
+// A pass to check if the fused op contains only reshape ops.
+class CheckReshapeOnly : public ExprVisitor {
+ public:
+ CheckReshapeOnly()
+ : reshape_(Op::Get("reshape")),
+ contr_reshape_(Op::Get("contrib_reverse_reshape")),
+ dyn_reshape_(Op::Get("dyn.reshape")) {}
+
+ void VisitExpr_(const CallNode* cn) final {
+ if (!reshape_only) return;
+ if (cn->op != reshape_ && cn->op != contr_reshape_ && cn->op != dyn_reshape_) {
+ reshape_only = false;
+ }
+ for (auto arg : cn->args) ExprVisitor::VisitExpr(arg);
+ }
+
+ void VisitExpr_(const VarNode* vn) final {
+ if (!vn->checked_type_->IsInstance<TensorTypeNode>()) {
+ reshape_only = false;
+ }
+ }
+
+ const Op& reshape_;
+ const Op& contr_reshape_;
+ const Op& dyn_reshape_;
+ bool reshape_only{true};
+};
+
+// Check if the primitive function contains only reshape ops.
+bool IsReshapeOnly(const Expr& expr) {
+ auto check = CheckReshapeOnly();
+ check.VisitExpr(expr);
+ return check.reshape_only;
+}
+
+class DialectRewriter : public ExprMutator {
+ public:
+ DialectRewriter(const Target& target_host, const AnalysisResultMap& context_analysis_map)
+ : target_host_(target_host),
+ context_analysis_map_(context_analysis_map),
+ device_copy_(runtime::Registry::Get("relay.op._make.device_copy")),
+ invoke_tvm_(runtime::Registry::Get("relay.op.vm.invoke_tvm_op")),
+ alloc_storage_(runtime::Registry::Get("relay.op.memory._make.alloc_storage")),
+ shape_func_(runtime::Registry::Get("relay.op.vm.shape_func")),
+ shape_of_(runtime::Registry::Get("relay.op.vm.shape_of")),
+ reshape_tensor_(runtime::Registry::Get("relay.op.vm.reshape_tensor")),
+ prod_(runtime::Registry::Get("relay.op._make.prod")),
+ divide_(runtime::Registry::Get("relay.op._make.divide")),
+ add_(runtime::Registry::Get("relay.op._make.add")),
+ multiply_(runtime::Registry::Get("relay.op._make.multiply")) {}
+
+ // Get the context of an expression.
+ TVMContext GetContext(const Expr& expr) const {
+ auto it = context_analysis_map_.find(expr);
+ CHECK(it != context_analysis_map_.end()) << "Cannot find expr in the context analysis map:\n"
+ << AsText(expr, false);
+ return it->second;
+ }
+
+ Function Rewrite(const Function& expr) {
+ auto ret = ExprMutator::Mutate(expr);
+ return Downcast<Function>(ret);
+ }
+
+ Expr VisitExpr_(const TupleNode* tn) final {
+ LetList& scope = scopes_.back();
+ Array<Expr> new_fields;
+ for (auto field : tn->fields) {
+ auto new_field = ExprMutator::Mutate(field);
+ if (new_field->IsInstance<ConstantNode>()) {
+ Var const_var("const", Type(nullptr));
+ new_field = scope.Push(const_var, new_field);
+ }
+ new_fields.push_back(new_field);
+ }
+ return Tuple(new_fields);
+ }
+
+ Expr VisitExpr_(const LetNode* ln) final {
+ scopes_.emplace_back();
+
+ const LetNode* let = ln;
+ Expr body;
+ while (let) {
+ auto new_value = ExprMutator::Mutate(let->value);
+ scopes_.back().Push(let->var, new_value);
+ body = let->body;
+ let = body.as<LetNode>();
+ }
+
+ CHECK(body.defined());
+ auto new_body = ExprMutator::Mutate(body);
+ auto ret = scopes_.back().Get(new_body);
+ scopes_.pop_back();
+ return ret;
+ }
+
+ Expr VisitExpr_(const CallNode* cn) final {
+ if (IsPrimitive(cn)) {
+ // Because we are in ANF we do not need to visit the arguments.
+ LetList& scope = scopes_.back();
+ std::vector<Expr> new_args;
+ for (const auto& it : cn->args) {
+ new_args.push_back(ExprMutator::Mutate(it));
+ }
+
+ Tuple ins(new_args);
+ Type ret_type = cn->checked_type_;
+ std::vector<TensorType> out_types = FlattenTupleType(ret_type);
+
+ // Handle fused op that only contains reshape op
+ if (IsReshapeOnly(cn->op)) {
+ Function func = Downcast<Function>(cn->op);
+ return EmitReshapeTensor(&scope, func, new_args, ret_type);
+ }
+
+ // Handle device copy op
+ if (IsDeviceCopy(cn->op)) {
+ Attrs attr;
+ if (const auto* fn = cn->op.as<FunctionNode>()) {
+ const auto* copy_call = fn->body.as<CallNode>();
+ CHECK(copy_call);
+ attr = copy_call->attrs;
+ } else {
+ attr = cn->attrs;
+ }
+ const DeviceCopyAttrs* copy_attr = attr.as<DeviceCopyAttrs>();
+ CHECK(copy_attr);
+ return DeviceCopy(new_args[0], copy_attr->src_dev_type, copy_attr->dst_dev_type);
+ } else if (IsDynamic(ret_type)) {
+ Function func = Downcast<Function>(cn->op);
+ return DynamicInvoke(&scope, func, ins, new_args, out_types, ret_type);
+ } else {
+ // Handle the static case
+ Array<Expr> outs;
+ for (size_t i = 0; i < out_types.size(); ++i) {
+ TVMContext ctx = GetContext(GetRef<Call>(cn));
+ auto out = MakeStaticAllocation(&scope, out_types[i], ctx, std::to_string(i));
+ outs.push_back(out);
+ }
+ Tuple output(outs);
+ Expr invoke = (*invoke_tvm_)(cn->op, ins, output);
+ scope.Push(invoke);
+ return ToTupleType(ret_type,
+ std::vector<Expr>(output->fields.begin(), output->fields.end()));
+ }
+ } else {
+ return ExprMutator::VisitExpr_(cn);
+ }
+ }
+
+ private:
+ // Insert a device copy node.
+ Expr DeviceCopy(const Expr& inp, int src_ctx, int dst_ctx) {
+ return ExprMutator::Mutate((*device_copy_)(inp, src_ctx, dst_ctx));
+ }
+
+ // Check if a call invokes a primitive function.
+ bool IsPrimitive(const CallNode* call) const {
+ if (const auto* fn = call->op.as<FunctionNode>()) {
+ return fn->HasNonzeroAttr(attr::kPrimitive);
+ }
+ return false;
+ }
+
+ // Check if the current relay expression is a device copy call. We can simply
+ // check the body of it if it is a function because the device_copy op is opaque.
+ bool IsDeviceCopy(const Expr& expr) const {
+ if (const auto* fn = expr.as<FunctionNode>()) {
+ auto body = fn->body;
+ const CallNode* call = body.as<CallNode>();
+ return call && call->op == Op::Get("device_copy");
+ } else if (const CallNode* cn = expr.as<CallNode>()) {
+ return cn->op == Op::Get("device_copy");
+ } else {
+ return false;
+ }
+ }
+
+ Expr ComputeAlignment(const DataType& dtype) const {
+ int64_t align = dtype.bits() / 8 * dtype.lanes();
+ if (align < 64) {
+ align = 64;
+ }
+ return MakeConstantScalar(DataType::Int(64), align);
+ }
+
+ Expr ComputeStorageInRelay(const Expr& shape, const TensorType& type) const {
+ auto dtype = DataType(type->dtype);
+ Expr els = (*prod_)(shape, Array<Expr>(nullptr), false, false);
+ Expr num = MakeConstantScalar(DataType::Int(64), dtype.bits() * dtype.lanes());
+ Expr add = (*add_)(num, MakeConstantScalar(DataType::Int(64), 7));
+ Expr div = MakeConstantScalar(DataType::Int(64), 8);
+ Expr ret = (*multiply_)(els, (*divide_)(add, div));
+ return std::move(ret);
+ }
+
+ Expr ComputeStorage(const TensorType& type) {
+ int64_t size = 1;
+ for (auto it : type->shape) {
+ auto val = it.as<IntImmNode>();
+ CHECK(val);
+ size *= val->value;
+ }
+ size *= (type->dtype.bits() * type->dtype.lanes() + 7) / 8;
+ return std::move(MakeConstantScalar(DataType::Int(64), size));
+ }
+
+ // Allocate a tensor with a statically known shape.
+ Var MakeStaticAllocation(LetList* scope, const TensorType& type, TVMContext ctx,
+ String name_hint) {
+ std::vector<int64_t> int_shape;
+ for (auto it : type->shape) {
+ const auto* imm = it.as<IntImmNode>();
+ CHECK(imm) << "expect static int shape";
+ int_shape.push_back(imm->value);
+ }
+ Expr shape = MakeConstant(int_shape);
+ Expr size = ComputeStorage(type);
+ Expr alignment = ComputeAlignment(type->dtype);
+ // Run type inference later to get the correct type.
+ Var var("storage_" + name_hint, Type(nullptr));
+ Expr value = (*alloc_storage_)(size, alignment, ctx, type->dtype);
+ auto sto = scope->Push(var, value);
+
+ // TODO(@jroesch): There is a bug with typing based on the constant shape.
+ auto tensor = AllocTensor(sto, shape, type->dtype, type->shape);
+ Var tensor_var("tensor_" + name_hint, Type(nullptr));
+ return scope->Push(tensor_var, tensor);
+ }
+
+ // Insert the shape function given a primitive function.
+ Array<Expr> EmitShapeFunc(LetList* scope, const Function& func,
+ const std::vector<Expr>& new_args) {
+ Array<Expr> shape_func_ins;
+ auto engine = CompileEngine::Global();
+ CCacheKey key(func, target_host_);
+ auto cfunc = engine->LowerShapeFunc(key);
+ auto input_states = cfunc->shape_func_param_states;
+
+ Array<Integer> is_inputs;
+ int input_pos = 0;
+ TVMContext cpu_ctx = default_context_;
+ CHECK_EQ(new_args.size(), input_states.size());
+ for (size_t i = 0; i < new_args.size(); ++i) {
+ Expr arg = new_args[i];
+ Type ty;
+ if (const auto* vn = arg.as<VarNode>()) {
+ ty = vn->type_annotation;
+ } else {
+ ty = arg->checked_type();
+ }
+ int state = input_states[i]->value;
+ // Pass Shapes
+ if (state == 2) {
+ std::vector<Expr> exprs = FromTupleType(ty, arg);
+ for (size_t j = 0; j < exprs.size(); ++j) {
+ Expr sh_of = ExprMutator::Mutate((*shape_of_)(exprs[j]));
+ Var in_shape_var("in_shape_" + std::to_string(input_pos + j), Type(nullptr));
+ shape_func_ins.push_back(scope->Push(in_shape_var, sh_of));
+ input_pos++;
+ }
+ is_inputs.push_back(0);
+ } else if (state == 1) {
+ auto new_arg = ExprMutator::Mutate(arg);
+ auto ctx = GetContext(arg);
+ if (ctx.device_type != cpu_ctx.device_type) {
+ new_arg = DeviceCopy(new_arg, ctx.device_type, cpu_ctx.device_type);
+ }
+ Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr));
+ shape_func_ins.push_back(scope->Push(in_shape_var, new_arg));
+ input_pos++;
+ is_inputs.push_back(1);
+ } else {
+ // TODO(@jroesch): handle 3rd case
+ LOG(FATAL) << "unsupported shape function input state";
+ }
+ }
+
+ Array<Expr> out_shapes;
+ for (size_t i = 0; i < cfunc->outputs.size(); ++i) {
+ auto out = cfunc->outputs[i];
+ auto tt = TensorType(out->shape, out->dtype);
+ // Put shape func on CPU. This also ensures that everything between
+ // shape_of and shape_func is on CPU.
+ auto alloc = MakeStaticAllocation(scope, tt, cpu_ctx, std::to_string(i));
+ Var shape_func_out_var("shape_func_out_" + std::to_string(i), Type(nullptr));
+ alloc = scope->Push(shape_func_out_var, alloc);
+ out_shapes.push_back(alloc);
+ }
+ auto shape_call = (*shape_func_)(func, Tuple(shape_func_ins), Tuple(out_shapes), is_inputs);
+ Var shape_func_var("shape_func", Type(nullptr));
+ scope->Push(shape_func_var, shape_call);
+ return out_shapes;
+ }
+
+ // Generate the code for invoking a TVM op with a dynamic shape.
+ Expr DynamicInvoke(LetList* scope, const Function& func, const Tuple& ins,
+ const std::vector<Expr>& new_args, const std::vector<TensorType>& out_types,
+ const Type& ret_type) {
+ auto out_shapes = EmitShapeFunc(scope, func, new_args);
+ std::vector<Var> storages;
+ auto func_ctx = GetContext(func);
+ CHECK_EQ(out_shapes.size(), out_types.size());
+ for (size_t i = 0; i < out_shapes.size(); ++i) {
+ auto out_shape = out_shapes[i];
+ auto out_type = out_types[i];
+ auto size = ComputeStorageInRelay(out_shape, out_type);
+ auto alignment = ComputeAlignment(out_type->dtype);
+ Var sto_var("storage_" + std::to_string(i), Type(nullptr));
+ auto val = (*alloc_storage_)(size, alignment, func_ctx, out_type->dtype);
+ storages.push_back(scope->Push(sto_var, val));
+ }
+
+ Array<Expr> outs;
+ for (size_t i = 0; i < storages.size(); ++i) {
+ auto out_shape = out_shapes[i];
+ auto out_type = out_types[i];
+ auto storage = storages[i];
+ auto alloc = AllocTensor(storage, out_shape, out_type->dtype, out_type->shape);
+ Var out_var("out_" + std::to_string(i), Type(nullptr));
+ outs.push_back(scope->Push(out_var, alloc));
+ }
+
+ Tuple tuple_outs(outs);
+ auto invoke = (*invoke_tvm_)(func, ins, tuple_outs);
+ scope->Push(invoke);
+ return ToTupleType(ret_type,
+ std::vector<Expr>(tuple_outs->fields.begin(), tuple_outs->fields.end()));
+ }
+
+ Expr EmitReshapeTensor(LetList* scope, const Function& func, const std::vector<Expr>& new_args,
+ const Type& ret_type) {
+ TensorType ret_ty = Downcast<TensorType>(ret_type);
+ Expr shape_expr;
+ if (IsDynamic(ret_type)) {
+ auto out_shapes = EmitShapeFunc(scope, func, new_args);
+ shape_expr = out_shapes[0];
+ } else {
+ std::vector<int64_t> shape;
+ for (const auto& it : ret_ty->shape) {
+ const auto* imm = it.as<IntImmNode>();
+ CHECK(imm) << "expect static int shape";
+ shape.push_back(imm->value);
+ }
+ shape_expr = MakeConstant(shape);
+ }
+ return (*reshape_tensor_)(new_args[0], shape_expr, ret_ty->shape);
+ }
+
+ private:
+ Target target_host_;
+ AnalysisResultMap context_analysis_map_;
+ std::vector<LetList> scopes_;
+
+ // Cache the following ops
+ const PackedFunc* device_copy_;
+ const PackedFunc* invoke_tvm_;
+ const PackedFunc* alloc_storage_;
+ const PackedFunc* shape_func_;
+ const PackedFunc* shape_of_;
+ const PackedFunc* reshape_tensor_;
+ const PackedFunc* prod_;
+ const PackedFunc* divide_;
+ const PackedFunc* add_;
+ const PackedFunc* multiply_;
+
+ runtime::DataType compute_dtype_ = runtime::DataType::Int(64);
+ TVMContext default_context_{kDLCPU, 0};
+};
+
+namespace transform {
+
+Pass ManifestAlloc(Target target_host, Map<tvm::Integer, tvm::Target> targets) {
+ return tvm::transform::CreateModulePass(
+ [=](IRModule mod, const PassContext& pass_ctx) {
+ DLOG(INFO) << "tvm::relay::transform::ManifestAlloc";
+ // We need to mutate the module, so we make a copy of it.
+ mod.CopyOnWrite();
+ mod->ImportFromStd("core.rly");
+ mod = relay::transform::InferType()(mod);
+
+ TVMContext fallback_ctx;
+ if (targets.size() > 1) {
+ auto pass_ctx = PassContext::Current();
+ Optional<Integer> opt_fallback_dev =
+ pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast<int>(kDLCPU)));
+ auto fallback_dev = opt_fallback_dev.value();
+ CHECK_GT(fallback_dev->value, 0U);
+ fallback_ctx.device_type = static_cast<DLDeviceType>(fallback_dev->value);
+ fallback_ctx.device_id = 0;
+ } else {
+ const auto& it = targets.begin();
+ fallback_ctx.device_type = static_cast<DLDeviceType>((*it).first->value);
+ fallback_ctx.device_id = 0;
+ }
+ auto ca = ContextAnalysis(mod, fallback_ctx);
+
+ auto glob_funcs = mod->functions;
+ for (const auto& it : glob_funcs) {
+ if (auto* func_node = it.second.as<FunctionNode>()) {
+ auto func = GetRef<Function>(func_node);
+ auto rewriter = DialectRewriter(target_host, ca);
+ auto updated_func = rewriter.Rewrite(func);
+
+ mod->Update(it.first, updated_func);
+ }
+ }
+
+ mod = relay::transform::InferType()(mod);
+ return mod;
+ },
+ 0, "ManifestAlloc", {});
+}
+
+TVM_REGISTER_GLOBAL("relay.transform.ManifestAlloc")
+ .set_body_typed([](Target target_host, Map<tvm::Integer, tvm::Target> targets) {
+ return ManifestAlloc(target_host, targets);
+ });
+
+} // namespace transform
+
+} // namespace relay
+} // namespace tvm
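One detail of the port worth noting: DialectRewriter::VisitExpr_(const LetNode*) walks a chain of nested lets with a while-loop, pushing each binding into the current LetList scope, instead of recursing once per binding. A toy stand-in for that spine walk (simplified types of our own, not the TVM API):

#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy ANF expression: a Let binds var to value in body; a leaf has no body.
struct ToyExpr {
  std::string var, value;
  std::unique_ptr<ToyExpr> body;  // null marks a leaf
};

// Collect bindings iteratively, mirroring the let-handling loop above, so
// deep ANF chains do not grow the call stack one frame per binding.
std::vector<std::pair<std::string, std::string>> FlattenLets(const ToyExpr* e) {
  std::vector<std::pair<std::string, std::string>> bindings;
  while (e->body) {
    bindings.emplace_back(e->var, e->value);
    e = e->body.get();
  }
  return bindings;
}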
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 0b575d1..9d05631 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -54,7 +54,6 @@ def check_result(
for kind in ["debug", "vm"]:
targets = targets or tvm.testing.enabled_targets()
for tgt, ctx in targets:
- print(tgt)
if disable_targets and tgt in disable_targets:
continue
if kind == "debug" and (only_vm or ctx.device_type != tvm.cpu().device_type):
diff --git a/tests/python/relay/test_memory_passes.py b/tests/python/relay/test_memory_passes.py
index c960d1f..546aaf5 100644
--- a/tests/python/relay/test_memory_passes.py
+++ b/tests/python/relay/test_memory_passes.py
@@ -18,7 +18,6 @@ import tvm
from tvm import te
import numpy as np
from tvm import relay
-from tvm.relay import memory_alloc
def check_memory_plan(func, check_fn):