You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/01/29 15:17:00 UTC

[tvm] branch main updated: [Refactor][VM] Port memory_alloc to c++ (#7369)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 44a071a  [Refactor][VM] Port memory_alloc to c++ (#7369)
44a071a is described below

commit 44a071aa1e9ad11c20fbfcf725ddb6dd8a2823c4
Author: Zhi <51...@users.noreply.github.com>
AuthorDate: Fri Jan 29 07:16:45 2021 -0800

    [Refactor][VM] Port memory_alloc to c++ (#7369)
    
    * Port memory_alloc to c++
    
    * remove memory python pass
---
 include/tvm/relay/transform.h              |  12 +
 python/tvm/relay/__init__.py               |   1 -
 python/tvm/relay/transform/__init__.py     |   1 -
 python/tvm/relay/transform/memory_alloc.py | 389 -----------------------
 src/relay/backend/vm/compiler.cc           |   6 -
 src/relay/transforms/memory_alloc.cc       | 494 +++++++++++++++++++++++++++++
 tests/python/relay/test_any.py             |   1 -
 tests/python/relay/test_memory_passes.py   |   1 -
 8 files changed, 506 insertions(+), 399 deletions(-)

diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index e4b39da..123b7e3 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -31,6 +31,7 @@
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/runtime/container.h>
+#include <tvm/target/target.h>
 
 #include <string>
 
@@ -419,6 +420,17 @@ TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
  */
 TVM_DLL Pass SimplifyExpr();
 
+/*!
+ * \brief A pass for manifesting explicit memory allocations and rewriting
+ * specific dialects.
+ *
+ * \param target_host The target used by the host for compilation.
+ * \param targets The device type and target pairs for compilation.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass ManifestAlloc(Target target_host, Map<tvm::Integer, tvm::Target> targets);
+
 }  // namespace transform
 
 /*!
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 97f6d1c..89c8fcb 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -61,7 +61,6 @@ from . import qnn
 from .scope_builder import ScopeBuilder
 
 # Load Memory Passes
-from .transform import memory_alloc
 from .transform import memory_plan
 
 # Required to traverse large programs
diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py
index 1d0ea17..ca9996a 100644
--- a/python/tvm/relay/transform/__init__.py
+++ b/python/tvm/relay/transform/__init__.py
@@ -19,4 +19,3 @@
 # transformation passes
 from .transform import *
 from .recast import recast
-from . import memory_alloc
diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py
deleted file mode 100644
index 66528c8..0000000
--- a/python/tvm/relay/transform/memory_alloc.py
+++ /dev/null
@@ -1,389 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
-"""
-A pass for manifesting explicit memory allocations.
-"""
-import numpy as np
-
-from tvm.ir.transform import PassContext, module_pass
-from tvm.relay.transform import InferType
-from tvm import nd, container
-from ..function import Function
-from ..expr_functor import ExprVisitor, ExprMutator
-from ..scope_builder import ScopeBuilder
-from .. import op
-from ... import DataType, register_func
-from .. import ty, expr
-from ..backend import compile_engine
-from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type
-from ... import cpu
-from ..op.memory import alloc_storage
-from ..analysis import context_analysis
-from ..._ffi.runtime_ctypes import TVMContext
-
-
-def alloc_tensor(storage, shape, dtype="float32", assert_shape=None):
-    offset = expr.const(0, dtype="int64")
-    return op.memory.alloc_tensor(storage, offset, shape, dtype, assert_shape)
-
-
-def is_primitive(call):
-    return (
-        hasattr(call, "op")
-        and hasattr(call.op, "attrs")
-        and hasattr(call.op.attrs, "Primitive")
-        and int(call.op.attrs.Primitive) == 1
-    )
-
-
-def is_device_copy(func):
-    """
-    Check if the current relay expression is a device copy call. We can simply check
-    the body of it if it is a function becase the device_copy op is opaque.
-    """
-    if isinstance(func, Function):
-        body = func.body
-        return isinstance(body, expr.Call) and body.op == op.get("device_copy")
-    if isinstance(func, expr.Call):
-        return func.op == op.get("device_copy")
-    return False
-
-
-class CheckReshapeOnly(ExprVisitor):
-    """A pass to check if the fused op contains only reshape ops."""
-
-    def __init__(self):
-        super().__init__()
-        self._reshape_ops = [
-            op.get("reshape"),
-            op.get("contrib_reverse_reshape"),
-            op.get("dyn.reshape"),
-        ]
-        self.reshape_only = True
-
-    def visit_call(self, call):
-        if not self.reshape_only:
-            return
-        if call.op not in self._reshape_ops:
-            self.reshape_only = False
-        for arg in call.args:
-            self.visit(arg)
-
-    def visit_var(self, var):
-        var_type = var.checked_type
-        if not isinstance(var_type, ty.TensorType):
-            self.reshape_only = False
-
-
-def is_reshape_only(func):
-    """Check if the primitive function contains only reshape ops."""
-    check = CheckReshapeOnly()
-    check.visit(func)
-    return check.reshape_only
-
-
-class ManifestAllocPass(ExprMutator):
-    """A pass for explicitly manifesting all memory allocations in Relay."""
-
-    def __init__(self, target_host, context_analysis_map):
-        self.invoke_tvm = op.vm.invoke_tvm_op
-        self.shape_func = op.vm.shape_func
-        self.shape_of = op.vm.shape_of
-        self.reshape_tensor = op.vm.reshape_tensor
-        self.scopes = [ScopeBuilder()]
-        self.target_host = target_host
-        self.default_context = cpu(0)
-        self.compute_dtype = "int64"
-        self.context_analysis_map = context_analysis_map
-        super().__init__()
-
-    def get_context(self, exp):
-        """Get the context of a given expression"""
-        assert exp in self.context_analysis_map, exp.astext(False)
-        val = self.context_analysis_map[exp]
-        # val[0], val[1] are device_type and device_id, respectively.
-        # We don't need to unpack after porting this pass to C++.
-        assert len(val) == 2
-        return TVMContext(val[0].value, val[1].value)
-
-    def device_copy(self, inp, src_ctx, dst_ctx):
-        """Insert a device copy node."""
-        return self.visit(op.tensor.device_copy(inp, src_ctx, dst_ctx))
-
-    def current_scope(self):
-        return self.scopes[-1]
-
-    def visit_tuple(self, tup):
-        scope = self.current_scope()
-        new_fields = []
-        for field in tup.fields:
-            field = self.visit(field)
-            if isinstance(field, expr.Constant):
-                field = scope.let("const", field)
-            new_fields.append(field)
-        return expr.Tuple(new_fields)
-
-    def compute_alignment(self, dtype):
-        dtype = DataType(dtype)
-        align = (dtype.bits // 8) * dtype.lanes
-        # MAGIC CONSTANT FROM device_api.h
-        if align < 64:
-            align = 64
-
-        return expr.const(align, dtype="int64")
-
-    def compute_storage_in_relay(self, shape, dtype):
-        dtype = DataType(dtype)
-        els = op.prod(shape)
-        num = expr.const(dtype.bits * dtype.lanes, self.compute_dtype)
-        num = num + expr.const(7, self.compute_dtype)
-        div = expr.const(8, self.compute_dtype)
-        return els * (num / div)
-
-    def compute_storage(self, tensor_type):
-        dtype = DataType(tensor_type.dtype)
-        shape = [int(sh) for sh in tensor_type.shape]
-        size = 1
-        for sh in shape:
-            size *= sh
-        size *= (dtype.bits * dtype.lanes + 7) // 8
-        return expr.const(size, dtype=self.compute_dtype)
-
-    def make_static_allocation(self, scope, tensor_type, ctx, name_hint):
-        """Allocate a tensor with a statically known shape."""
-        shape = [int(sh) for sh in tensor_type.shape]
-        if len(shape) == 0:
-            shape = expr.const(np.empty((), dtype=self.compute_dtype), dtype=self.compute_dtype)
-        else:
-            shape = expr.const(np.array(shape), dtype=self.compute_dtype)
-        size = self.compute_storage(tensor_type)
-        alignment = self.compute_alignment(tensor_type.dtype)
-        dtype = tensor_type.dtype
-        sto = scope.let("storage_{0}".format(name_hint), alloc_storage(size, alignment, ctx, dtype))
-        # TODO(@jroesch): There is a bug with typing based on the constant shape.
-        tensor = alloc_tensor(sto, shape, dtype, tensor_type.shape)
-        return scope.let("tensor_{0}".format(name_hint), tensor)
-
-    def visit_let(self, let):
-        scope = ScopeBuilder()
-
-        self.scopes.append(scope)
-        while isinstance(let, expr.Let):
-            new_val = self.visit(let.value)
-            scope.let(let.var, new_val)
-            let = let.body
-
-        new_body = self.visit(let)
-        scope.ret(new_body)
-        self.scopes.pop()
-
-        return scope.get()
-
-    def emit_shape_func(self, scope, func, new_args):
-        """Insert the shape function given a primitive function."""
-        shape_func_ins = []
-        engine = compile_engine.get()
-        cfunc = engine.lower_shape_func(func, self.target_host)
-        input_states = cfunc.shape_func_param_states
-
-        is_inputs = []
-        input_pos = 0
-        cpu_ctx = nd.cpu(0)
-        for i, (arg, state) in enumerate(zip(new_args, input_states)):
-            state = int(state)
-            # Pass Shapes
-            if state == 2:
-                for j, subexp in enumerate(from_tuple_type(arg.type_annotation, arg)):
-                    sh_of = self.visit(self.shape_of(subexp))
-                    shape_func_ins.append(scope.let("in_shape_{0}".format(input_pos + j), sh_of))
-                    input_pos += 1
-                is_inputs.append(0)
-            # Pass Inputs
-            elif state == 1:
-                new_arg = self.visit(arg)
-                ctx = self.get_context(arg)
-                if ctx.device_type != cpu_ctx.device_type:
-                    new_arg = self.device_copy(new_arg, ctx, cpu_ctx)
-                shape_func_ins.append(scope.let("in_shape_{0}".format(input_pos), new_arg))
-                input_pos += 1
-                is_inputs.append(1)
-            else:
-                # TODO(@jroesch): handle 3rd case
-                raise Exception("unsupported shape function input state")
-
-        out_shapes = []
-        for i, out in enumerate(cfunc.outputs):
-            tt = ty.TensorType(out.shape, out.dtype)
-            # Put shape func on CPU. This also ensures that everything between
-            # shape_of and shape_func are on CPU.
-            alloc = self.make_static_allocation(scope, tt, cpu_ctx, i)
-            alloc = scope.let("shape_func_out_{0}".format(i), alloc)
-            out_shapes.append(alloc)
-
-        shape_call = self.shape_func(
-            func, expr.Tuple(shape_func_ins), expr.Tuple(out_shapes), is_inputs
-        )
-
-        scope.let("shape_func", shape_call)
-        return out_shapes
-
-    def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type):
-        """Generate the code for invoking a TVM op with a dynamic shape."""
-        out_shapes = self.emit_shape_func(scope, func, new_args)
-
-        storages = []
-        func_ctx = self.get_context(func)
-        for i, (out_shape, out_type) in enumerate(zip(out_shapes, out_types)):
-            size = self.compute_storage_in_relay(out_shape, out_type.dtype)
-            alignment = self.compute_alignment(out_type.dtype)
-            sto = scope.let(
-                "storage_{i}".format(i=i), alloc_storage(size, alignment, func_ctx, out_type.dtype)
-            )
-            storages.append(sto)
-
-        outs = []
-        sh_ty_storage = zip(out_shapes, out_types, storages)
-        for i, (out_shape, out_type, storage) in enumerate(sh_ty_storage):
-            alloc = alloc_tensor(storage, out_shape, out_type.dtype, out_type.shape)
-            alloc = scope.let("out_{i}".format(i=i), alloc)
-            outs.append(alloc)
-
-        tuple_outs = expr.Tuple(outs)
-        invoke = self.invoke_tvm(func, ins, tuple_outs)
-        scope.let("", invoke)
-        return to_tuple_type(ret_type, tuple_outs.fields)
-
-    def emit_reshape_tensor(self, scope, func, new_args, ret_type):
-        if self.is_dynamic(ret_type):
-            out_shapes = self.emit_shape_func(scope, func, new_args)
-            shape_expr = out_shapes[0]
-        else:
-            # constant output shape
-            shape = [int(dim) for dim in ret_type.shape]
-            shape_expr = expr.const(shape, dtype=self.compute_dtype)
-        return self.reshape_tensor(new_args[0], shape_expr, ret_type.shape)
-
-    def is_dynamic(self, ret_type):
-        is_dynamic = ty.is_dynamic(ret_type)
-        # TODO(@jroesch): restore this code, more complex then it seems
-        # for arg in call.args:
-        #     is_dynamic = is_dynamic or arg.checked_type.is_dynamic()
-        return is_dynamic
-
-    def visit_call(self, call):
-        if is_primitive(call):
-            # Because we are in ANF we do not need to visit the arguments.
-            scope = self.current_scope()
-            new_args = [self.visit(arg) for arg in call.args]
-
-            ins = expr.Tuple(new_args)
-            ret_type = call.checked_type
-            out_types = flatten_tuple_type(ret_type)
-
-            if is_reshape_only(call.op):
-                # Handle fused op that only contains reshape op
-                return self.emit_reshape_tensor(scope, call.op, new_args, ret_type)
-
-            if is_device_copy(call.op):
-                # Handle device copy op
-                if isinstance(call.op, Function):
-                    attr = call.op.body.attrs
-                else:
-                    attr = call.attr
-                return self.device_copy(
-                    new_args[0], TVMContext(attr.src_dev_type, 0), TVMContext(attr.dst_dev_type, 0)
-                )
-
-            if self.is_dynamic(ret_type):
-                # Handle dynamic case.
-                return self.dynamic_invoke(scope, call.op, ins, new_args, out_types, ret_type)
-
-            # Handle static case.
-            outs = []
-            for i, out_ty in enumerate(out_types):
-                ctx = self.get_context(call)
-                assert isinstance(ctx, TVMContext)
-                out = self.make_static_allocation(scope, out_ty, ctx, i)
-                outs.append(out)
-
-            output = expr.Tuple(outs)
-            invoke = self.invoke_tvm(call.op, ins, output)
-            scope.let("", invoke)
-            return to_tuple_type(ret_type, output.fields)
-        return super().visit_call(call)
-
-
-def mk_analysis_annotator(results):
-    """Pretty print the annotated relay program with device info"""
-
-    def _annotator(exp):
-        if exp in results:
-            val = results[exp]
-            assert len(val) == 2
-            ctx = TVMContext(val[0].value, val[1].value)
-            return f"<{ctx}>"
-        else:
-            return ""
-
-    return _annotator
-
-
-@module_pass(opt_level=0)
-class ManifestAlloc:
-    """The explicit pass wrapper around ManifestAlloc."""
-
-    # TODO(zhiics, jroesch) Port this pass to C++.
-    def __init__(self, target_host, targets):
-        self.target_host = target_host
-        self.targets = targets
-
-    def transform_module(self, mod, _):
-        """Invokes the pass"""
-        # TODO(@jroesch): Is there a way to do one shot initialization?
-        # can we have def pass_init?
-        mod.import_from_std("core.rly")
-        mod = InferType()(mod)
-
-        assert isinstance(self.targets, (dict, container.Map))
-        if len(self.targets) > 1:
-            pass_ctx = PassContext.current()
-            if "relay.fallback_device_type" in pass_ctx.config:
-                fallback_ctx = nd.context(pass_ctx.config["relay.fallback_device_type"])
-            else:
-                fallback_ctx = cpu(0)
-            ca = context_analysis(mod, TVMContext(fallback_ctx.device_type, 0))
-        else:
-            if isinstance(self.targets, dict):
-                dev = list(self.targets.keys())[0]
-            else:
-                dev, _ = self.targets.items()[0]
-            ca = context_analysis(mod, nd.context(dev.value))
-
-        # The following code can be used for debugging the module after
-        # annotation.
-        # print(mod.astext(show_meta_data=False, annotate=mk_analysis_annotator(ca)))
-
-        gv_funcs = mod.functions
-        for gv, f in gv_funcs.items():
-            ea = ManifestAllocPass(self.target_host, ca)
-            f = ea.visit(f)
-            mod.update_func(gv, f)
-        return mod
-
-
-register_func("relay.transform.ManifestAlloc", ManifestAlloc)
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index d908153..7861502 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -58,12 +58,6 @@ namespace transform {
 Pass LambdaLift();
 Pass InlinePrimitives();
 
-Pass ManifestAlloc(Target target_host, vm::TargetsMap targets) {
-  auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
-  ICHECK(f != nullptr) << "unable to load allocation manifestation pass";
-  return (*f)(target_host, targets);
-}
-
 Pass MemoryPlan() {
   auto f = tvm::runtime::Registry::Get("relay.transform.MemoryPlan");
   ICHECK(f != nullptr) << "unable to load the memory planning pass";
diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc
new file mode 100644
index 0000000..360778e
--- /dev/null
+++ b/src/relay/transforms/memory_alloc.cc
@@ -0,0 +1,494 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/memory_alloc.cc
+ * \brief A pass for manifesting explicit memory allocations.
+ */
+
+#include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/transform.h>
+#include <tvm/support/logging.h>
+#include <tvm/target/target.h>
+
+#include <cstdint>
+#include <cstdio>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "../backend/compile_engine.h"
+#include "let_list.h"
+#include "pattern_utils.h"
+
+using namespace tvm::runtime;
+
+namespace tvm {
+namespace relay {
+
+extern Expr ToTupleType(const Type& ty, const std::vector<Expr>& exprs);
+extern std::vector<Expr> FromTupleType(const Type& type, const Expr& expr);
+extern std::vector<TensorType> FlattenTupleType(const Type& type);
+
+using AnalysisResultMap =
+    std::unordered_map<Expr, TVMContext, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
+
+inline Constant MakeConstant(const std::vector<int64_t>& value) {
+  return MakeConstantTensor(DataType::Int(64), {static_cast<int64_t>(value.size())}, value);
+}
+
+inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype,
+                        Array<IndexExpr> assert_shape) {
+  auto f = runtime::Registry::Get("relay.op.memory._make.alloc_tensor");
+  CHECK(f != nullptr) << "unable to find alloc_tensor op";
+  auto offset = MakeConstantScalar(DataType::Int(64), 0);
+  return (*f)(storage, offset, shape, dtype, assert_shape);
+}
+
+// A pass to check if the fused op contains only reshape ops.
+class CheckReshapeOnly : public ExprVisitor {
+ public:
+  CheckReshapeOnly()
+      : reshape_(Op::Get("reshape")),
+        contr_reshape_(Op::Get("contrib_reverse_reshape")),
+        dyn_reshape_(Op::Get("dyn.reshape")) {}
+
+  void VisitExpr_(const CallNode* cn) final {
+    if (!reshape_only) return;
+    if (cn->op != reshape_ && cn->op != contr_reshape_ && cn->op != dyn_reshape_) {
+      reshape_only = false;
+    }
+    for (auto arg : cn->args) ExprVisitor::VisitExpr(arg);
+  }
+
+  void VisitExpr_(const VarNode* vn) final {
+    if (!vn->checked_type_->IsInstance<TensorTypeNode>()) {
+      reshape_only = false;
+    }
+  }
+
+  const Op& reshape_;
+  const Op& contr_reshape_;
+  const Op& dyn_reshape_;
+  bool reshape_only{true};
+};
+
+// Check if the primitive function contains only reshape ops.
+bool IsReshapeOnly(const Expr& expr) {
+  auto check = CheckReshapeOnly();
+  check.VisitExpr(expr);
+  return check.reshape_only;
+}
+
+class DialectRewriter : public ExprMutator {
+ public:
+  DialectRewriter(const Target& target_host, const AnalysisResultMap& context_analysis_map)
+      : target_host_(target_host),
+        context_analysis_map_(context_analysis_map),
+        device_copy_(runtime::Registry::Get("relay.op._make.device_copy")),
+        invoke_tvm_(runtime::Registry::Get("relay.op.vm.invoke_tvm_op")),
+        alloc_storage_(runtime::Registry::Get("relay.op.memory._make.alloc_storage")),
+        shape_func_(runtime::Registry::Get("relay.op.vm.shape_func")),
+        shape_of_(runtime::Registry::Get("relay.op.vm.shape_of")),
+        reshape_tensor_(runtime::Registry::Get("relay.op.vm.reshape_tensor")),
+        prod_(runtime::Registry::Get("relay.op._make.prod")),
+        divide_(runtime::Registry::Get("relay.op._make.divide")),
+        add_(runtime::Registry::Get("relay.op._make.add")),
+        multiply_(runtime::Registry::Get("relay.op._make.multiply")) {}
+
+  // Get the context of an expression.
+  TVMContext GetContext(const Expr& expr) const {
+    auto it = context_analysis_map_.find(expr);
+    CHECK(it != context_analysis_map_.end()) << "Cannot find expr in the context analysis map:\n"
+                                             << AsText(expr, false);
+    return it->second;
+  }
+
+  Function Rewrite(const Function& expr) {
+    auto ret = ExprMutator::Mutate(expr);
+    return Downcast<Function>(ret);
+  }
+
+  Expr VisitExpr_(const TupleNode* tn) final {
+    LetList& scope = scopes_.back();
+    Array<Expr> new_fields;
+    for (auto field : tn->fields) {
+      auto new_field = ExprMutator::Mutate(field);
+      if (new_field->IsInstance<ConstantNode>()) {
+        Var const_var("const", Type(nullptr));
+        new_field = scope.Push(const_var, new_field);
+      }
+      new_fields.push_back(new_field);
+    }
+    return Tuple(new_fields);
+  }
+
+  Expr VisitExpr_(const LetNode* ln) final {
+    scopes_.emplace_back();
+
+    const LetNode* let = ln;
+    Expr body;
+    while (let) {
+      auto new_value = ExprMutator::Mutate(let->value);
+      scopes_.back().Push(let->var, new_value);
+      body = let->body;
+      let = body.as<LetNode>();
+    }
+
+    CHECK(body.defined());
+    auto new_body = ExprMutator::Mutate(body);
+    auto ret = scopes_.back().Get(new_body);
+    scopes_.pop_back();
+    return ret;
+  }
+
+  Expr VisitExpr_(const CallNode* cn) final {
+    if (IsPrimitive(cn)) {
+      // Because we are in ANF we do not need to visit the arguments.
+      LetList& scope = scopes_.back();
+      std::vector<Expr> new_args;
+      for (const auto& it : cn->args) {
+        new_args.push_back(ExprMutator::Mutate(it));
+      }
+
+      Tuple ins(new_args);
+      Type ret_type = cn->checked_type_;
+      std::vector<TensorType> out_types = FlattenTupleType(ret_type);
+
+      // Handle fused op that only contains reshape op
+      if (IsReshapeOnly(cn->op)) {
+        Function func = Downcast<Function>(cn->op);
+        return EmitReshapeTensor(&scope, func, new_args, ret_type);
+      }
+
+      // Handle device copy op
+      if (IsDeviceCopy(cn->op)) {
+        Attrs attr;
+        if (const auto* fn = cn->op.as<FunctionNode>()) {
+          const auto* copy_call = fn->body.as<CallNode>();
+          CHECK(copy_call);
+          attr = copy_call->attrs;
+        } else {
+          attr = cn->attrs;
+        }
+        const DeviceCopyAttrs* copy_attr = attr.as<DeviceCopyAttrs>();
+        CHECK(copy_attr);
+        return DeviceCopy(new_args[0], copy_attr->src_dev_type, copy_attr->dst_dev_type);
+      } else if (IsDynamic(ret_type)) {
+        Function func = Downcast<Function>(cn->op);
+        return DynamicInvoke(&scope, func, ins, new_args, out_types, ret_type);
+      } else {
+        // Handle the static case
+        Array<Expr> outs;
+        for (size_t i = 0; i < out_types.size(); ++i) {
+          TVMContext ctx = GetContext(GetRef<Call>(cn));
+          auto out = MakeStaticAllocation(&scope, out_types[i], ctx, std::to_string(i));
+          outs.push_back(out);
+        }
+        Tuple output(outs);
+        Expr invoke = (*invoke_tvm_)(cn->op, ins, output);
+        scope.Push(invoke);
+        return ToTupleType(ret_type,
+                           std::vector<Expr>(output->fields.begin(), output->fields.end()));
+      }
+    } else {
+      return ExprMutator::VisitExpr_(cn);
+    }
+  }
+
+ private:
+  // Insert a device copy node.
+  Expr DeviceCopy(const Expr& inp, int src_ctx, int dst_ctx) {
+    return ExprMutator::Mutate((*device_copy_)(inp, src_ctx, dst_ctx));
+  }
+
+  // Check if a call invokes a primitive function.
+  bool IsPrimitive(const CallNode* call) const {
+    if (const auto* fn = call->op.as<FunctionNode>()) {
+      return fn->HasNonzeroAttr(attr::kPrimitive);
+    }
+    return false;
+  }
+
+  // Check if the current relay expression is a device copy call. We can simply
+  // check the body of it if it is a function because the device_copy op is opaque.
+  bool IsDeviceCopy(const Expr& expr) const {
+    if (const auto* fn = expr.as<FunctionNode>()) {
+      auto body = fn->body;
+      const CallNode* call = body.as<CallNode>();
+      return call && call->op == Op::Get("device_copy");
+    } else if (const CallNode* cn = expr.as<CallNode>()) {
+      return cn->op == Op::Get("device_copy");
+    } else {
+      return false;
+    }
+  }
+
+  Expr ComputeAlignment(const DataType& dtype) const {
+    int64_t align = dtype.bits() / 8 * dtype.lanes();
+    if (align < 64) {
+      align = 64;
+    }
+    return MakeConstantScalar(DataType::Int(64), align);
+  }
+
+  Expr ComputeStorageInRelay(const Expr& shape, const TensorType& type) const {
+    auto dtype = DataType(type->dtype);
+    Expr els = (*prod_)(shape, Array<Expr>(nullptr), false, false);
+    Expr num = MakeConstantScalar(DataType::Int(64), dtype.bits() * dtype.lanes());
+    Expr add = (*add_)(num, MakeConstantScalar(DataType::Int(64), 7));
+    Expr div = MakeConstantScalar(DataType::Int(64), 8);
+    Expr ret = (*multiply_)(els, (*divide_)(add, div));
+    return std::move(ret);
+  }
+
+  Expr ComputeStorage(const TensorType& type) {
+    int64_t size = 1;
+    for (auto it : type->shape) {
+      auto val = it.as<IntImmNode>();
+      CHECK(val);
+      size *= val->value;
+    }
+    size *= (type->dtype.bits() * type->dtype.lanes() + 7) / 8;
+    return std::move(MakeConstantScalar(DataType::Int(64), size));
+  }
+
+  // Allocate a tensor with a statically known shape.
+  Var MakeStaticAllocation(LetList* scope, const TensorType& type, TVMContext ctx,
+                           String name_hint) {
+    std::vector<int64_t> int_shape;
+    for (auto it : type->shape) {
+      const auto* imm = it.as<IntImmNode>();
+      CHECK(imm) << "expect static int shape";
+      int_shape.push_back(imm->value);
+    }
+    Expr shape = MakeConstant(int_shape);
+    Expr size = ComputeStorage(type);
+    Expr alignment = ComputeAlignment(type->dtype);
+    // Run type inference later to get the correct type.
+    Var var("storage_" + name_hint, Type(nullptr));
+    Expr value = (*alloc_storage_)(size, alignment, ctx, type->dtype);
+    auto sto = scope->Push(var, value);
+
+    // TODO(@jroesch): There is a bug with typing based on the constant shape.
+    auto tensor = AllocTensor(sto, shape, type->dtype, type->shape);
+    Var tensor_var("tensor_" + name_hint, Type(nullptr));
+    return scope->Push(tensor_var, tensor);
+  }
+
+  // Emit the shape function for the primitive function `func` into `scope`.
+  //
+  // The shape function is lowered for `target_host_`, and its per-parameter
+  // states decide what each entry of `new_args` contributes:
+  //   * state 2: the shape of every expression FromTupleType yields for the
+  //     argument (one shape input per field);
+  //   * state 1: the argument data itself, device-copied to CPU when its
+  //     context's device type differs from CPU;
+  //   * any other state is rejected with a fatal error.
+  // Output buffers are allocated statically on CPU, bound in `scope`, and
+  // returned in shape-function output order.
+  Array<Expr> EmitShapeFunc(LetList* scope, const Function& func,
+                            const std::vector<Expr>& new_args) {
+    Array<Expr> shape_func_ins;
+    auto engine = CompileEngine::Global();
+    CCacheKey key(func, target_host_);
+    auto cfunc = engine->LowerShapeFunc(key);
+    auto input_states = cfunc->shape_func_param_states;
+
+    // Per-argument flag handed to shape_func_: 1 = data input, 0 = shape input.
+    Array<Integer> is_inputs;
+    // Running index into shape_func_ins; also used for the in_shape_ name hints.
+    int input_pos = 0;
+    TVMContext cpu_ctx = default_context_;
+    CHECK_EQ(new_args.size(), input_states.size());
+    for (size_t i = 0; i < new_args.size(); ++i) {
+      Expr arg = new_args[i];
+      // Prefer the annotation on variables; otherwise use the inferred type.
+      Type ty;
+      if (const auto* vn = arg.as<VarNode>()) {
+        ty = vn->type_annotation;
+      } else {
+        ty = arg->checked_type();
+      }
+      int state = input_states[i]->value;
+      if (state == 2) {
+        // Pass shapes: one shape_of per field of a (possibly tuple) argument.
+        std::vector<Expr> exprs = FromTupleType(ty, arg);
+        for (const Expr& field : exprs) {
+          Expr sh_of = ExprMutator::Mutate((*shape_of_)(field));
+          // input_pos already advances once per field, so on its own it yields
+          // unique, sequential hints; adding a field index would skip numbers
+          // and repeat them across arguments.
+          Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr));
+          shape_func_ins.push_back(scope->Push(in_shape_var, sh_of));
+          input_pos++;
+        }
+        is_inputs.push_back(0);
+      } else if (state == 1) {
+        // Pass data: the shape function runs on CPU, so copy the argument there
+        // if it lives on another device type.
+        auto new_arg = ExprMutator::Mutate(arg);
+        auto ctx = GetContext(arg);
+        if (ctx.device_type != cpu_ctx.device_type) {
+          new_arg = DeviceCopy(new_arg, ctx.device_type, cpu_ctx.device_type);
+        }
+        Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr));
+        shape_func_ins.push_back(scope->Push(in_shape_var, new_arg));
+        input_pos++;
+        is_inputs.push_back(1);
+      } else {
+        // TODO(@jroesch): handle 3rd case
+        LOG(FATAL) << "unsupported shape function input state";
+      }
+    }
+
+    Array<Expr> out_shapes;
+    for (size_t i = 0; i < cfunc->outputs.size(); ++i) {
+      auto out = cfunc->outputs[i];
+      auto tt = TensorType(out->shape, out->dtype);
+      // Put shape func on CPU. This also ensures that everything between
+      // shape_of and shape_func are on CPU.
+      auto alloc = MakeStaticAllocation(scope, tt, cpu_ctx, std::to_string(i));
+      Var shape_func_out_var("shape_func_out_" + std::to_string(i), Type(nullptr));
+      alloc = scope->Push(shape_func_out_var, alloc);
+      out_shapes.push_back(alloc);
+    }
+    auto shape_call = (*shape_func_)(func, Tuple(shape_func_ins), Tuple(out_shapes), is_inputs);
+    Var shape_func_var("shape_func", Type(nullptr));
+    scope->Push(shape_func_var, shape_call);
+    return out_shapes;
+  }
+
+  // Emit an invocation of primitive `func` whose output shapes are only known
+  // at runtime. The shape function is emitted first. Then one storage buffer
+  // per output is bound, on the device `func` was assigned to, followed by one
+  // tensor per output inside that storage. Finally the kernel is invoked with
+  // `ins` and the output tensors. The outputs are repacked to match `ret_type`.
+  Expr DynamicInvoke(LetList* scope, const Function& func, const Tuple& ins,
+                     const std::vector<Expr>& new_args, const std::vector<TensorType>& out_types,
+                     const Type& ret_type) {
+    Array<Expr> shapes = EmitShapeFunc(scope, func, new_args);
+    auto target_ctx = GetContext(func);
+    const size_t num_outputs = shapes.size();
+    CHECK_EQ(num_outputs, out_types.size());
+
+    // Pass 1: bind all storage buffers, sized from the runtime shapes.
+    std::vector<Var> buffers;
+    buffers.reserve(num_outputs);
+    for (size_t k = 0; k < num_outputs; ++k) {
+      const Expr shape = shapes[k];
+      const TensorType ttype = out_types[k];
+      auto nbytes = ComputeStorageInRelay(shape, ttype);
+      auto align = ComputeAlignment(ttype->dtype);
+      Var buffer_var("storage_" + std::to_string(k), Type(nullptr));
+      auto storage_call = (*alloc_storage_)(nbytes, align, target_ctx, ttype->dtype);
+      buffers.push_back(scope->Push(buffer_var, storage_call));
+    }
+
+    // Pass 2: carve one output tensor out of each buffer.
+    Array<Expr> results;
+    for (size_t k = 0; k < buffers.size(); ++k) {
+      const Expr shape = shapes[k];
+      const TensorType ttype = out_types[k];
+      Var buffer = buffers[k];
+      auto tensor = AllocTensor(buffer, shape, ttype->dtype, ttype->shape);
+      Var result_var("out_" + std::to_string(k), Type(nullptr));
+      results.push_back(scope->Push(result_var, tensor));
+    }
+
+    Tuple result_tuple(results);
+    auto invoke_call = (*invoke_tvm_)(func, ins, result_tuple);
+    scope->Push(invoke_call);
+    std::vector<Expr> fields(result_tuple->fields.begin(), result_tuple->fields.end());
+    return ToTupleType(ret_type, fields);
+  }
+
+  // Lower a reshape-only primitive to a reshape_tensor call on its first
+  // argument. The new shape is computed by the shape function when `ret_type`
+  // is dynamic. Otherwise it is a constant built from the static dimensions.
+  // `ret_type` must be a TensorType.
+  Expr EmitReshapeTensor(LetList* scope, const Function& func, const std::vector<Expr>& new_args,
+                         const Type& ret_type) {
+    TensorType tensor_ty = Downcast<TensorType>(ret_type);
+    Expr new_shape;
+    if (!IsDynamic(ret_type)) {
+      // Static result: every dimension must be an integer immediate.
+      std::vector<int64_t> dims;
+      dims.reserve(tensor_ty->shape.size());
+      for (const auto& dim : tensor_ty->shape) {
+        const IntImmNode* dim_imm = dim.as<IntImmNode>();
+        CHECK(dim_imm) << "expect static int shape";
+        dims.push_back(dim_imm->value);
+      }
+      new_shape = MakeConstant(dims);
+    } else {
+      // Dynamic result: the first shape-function output is the target shape.
+      Array<Expr> computed = EmitShapeFunc(scope, func, new_args);
+      new_shape = computed[0];
+    }
+    return (*reshape_tensor_)(new_args[0], new_shape, tensor_ty->shape);
+  }
+
+ private:
+  // Host target; used in the cache key when lowering shape functions.
+  Target target_host_;
+  // Context-analysis result for the module. NOTE(review): not read in the
+  // visible code; presumably consulted by GetContext — confirm.
+  AnalysisResultMap context_analysis_map_;
+  // NOTE(review): presumably the stack of let-list scopes being built while
+  // rewriting; not referenced in the visible code — confirm.
+  std::vector<LetList> scopes_;
+
+  // Cache the following ops
+  // invoke_tvm_, alloc_storage_, shape_func_, shape_of_ and reshape_tensor_
+  // are called directly by the Emit*/DynamicInvoke helpers. device_copy_ and
+  // the arithmetic ops (prod_, divide_, add_, multiply_) are presumably used by
+  // helpers outside this view (e.g. DeviceCopy, ComputeStorageInRelay) — confirm.
+  const PackedFunc* device_copy_;
+  const PackedFunc* invoke_tvm_;
+  const PackedFunc* alloc_storage_;
+  const PackedFunc* shape_func_;
+  const PackedFunc* shape_of_;
+  const PackedFunc* reshape_tensor_;
+  const PackedFunc* prod_;
+  const PackedFunc* divide_;
+  const PackedFunc* add_;
+  const PackedFunc* multiply_;
+
+  // 64-bit integer dtype. NOTE(review): presumably the dtype for storage-size
+  // arithmetic — confirm.
+  runtime::DataType compute_dtype_ = runtime::DataType::Int(64);
+  // CPU device 0. EmitShapeFunc uses this as the context that shape functions
+  // and their outputs live on.
+  TVMContext default_context_{kDLCPU, 0};
+};
+
+namespace transform {
+
+// Module pass that makes memory allocation explicit in every Relay function.
+//
+// The pass copies the module, imports the core.rly prelude, and runs type
+// inference. It then runs context analysis and rewrites each function with
+// DialectRewriter, which uses `target_host` to lower shape functions. Type
+// inference runs again before the module is returned.
+//
+// `targets` maps device type to target. With exactly one entry, that device is
+// the fallback for context analysis. Otherwise the fallback comes from the
+// pass config option "relay.fallback_device_type", which defaults to CPU.
+Pass ManifestAlloc(Target target_host, Map<tvm::Integer, tvm::Target> targets) {
+  return tvm::transform::CreateModulePass(
+      [=](IRModule mod, const PassContext& pass_ctx) {
+        DLOG(INFO) << "tvm::relay::transform::ManifestAlloc";
+        // We need to mutate module, therefore making a copy of it.
+        mod.CopyOnWrite();
+        mod->ImportFromStd("core.rly");
+        mod = relay::transform::InferType()(mod);
+
+        TVMContext fallback_ctx;
+        fallback_ctx.device_id = 0;
+        if (targets.size() == 1) {
+          // Homogeneous compilation: fall back to the single target's device.
+          const auto& it = targets.begin();
+          fallback_ctx.device_type = static_cast<DLDeviceType>((*it).first->value);
+        } else {
+          // Heterogeneous (or empty) target map: honor the configured fallback.
+          // An empty map also lands here instead of dereferencing begin() == end().
+          // Use the context this pass was invoked with rather than re-fetching
+          // PassContext::Current().
+          Optional<Integer> opt_fallback_dev =
+              pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast<int>(kDLCPU)));
+          auto fallback_dev = opt_fallback_dev.value();
+          CHECK_GT(fallback_dev->value, 0);
+          fallback_ctx.device_type = static_cast<DLDeviceType>(fallback_dev->value);
+        }
+        auto ca = ContextAnalysis(mod, fallback_ctx);
+
+        // Iterate over a copy of the function map so Update() below does not
+        // modify the container being traversed.
+        auto glob_funcs = mod->functions;
+        for (const auto& it : glob_funcs) {
+          if (auto* func_node = it.second.as<FunctionNode>()) {
+            auto func = GetRef<Function>(func_node);
+            auto rewriter = DialectRewriter(target_host, ca);
+            auto updated_func = rewriter.Rewrite(func);
+
+            mod->Update(it.first, updated_func);
+          }
+        }
+
+        mod = relay::transform::InferType()(mod);
+        return mod;
+      },
+      0, "ManifestAlloc", {});
+}
+
+// Register ManifestAlloc in the global function registry under
+// "relay.transform.ManifestAlloc"; the typed signature is deduced from the
+// function itself.
+TVM_REGISTER_GLOBAL("relay.transform.ManifestAlloc").set_body_typed(ManifestAlloc);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 0b575d1..9d05631 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -54,7 +54,6 @@ def check_result(
     for kind in ["debug", "vm"]:
         targets = targets or tvm.testing.enabled_targets()
         for tgt, ctx in targets:
-            print(tgt)
             if disable_targets and tgt in disable_targets:
                 continue
             if kind == "debug" and (only_vm or ctx.device_type != tvm.cpu().device_type):
diff --git a/tests/python/relay/test_memory_passes.py b/tests/python/relay/test_memory_passes.py
index c960d1f..546aaf5 100644
--- a/tests/python/relay/test_memory_passes.py
+++ b/tests/python/relay/test_memory_passes.py
@@ -18,7 +18,6 @@ import tvm
 from tvm import te
 import numpy as np
 from tvm import relay
-from tvm.relay import memory_alloc
 
 
 def check_memory_plan(func, check_fn):