You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/08/30 08:37:29 UTC

[GitHub] [tvm] ekalda commented on a diff in pull request #12550: [AOT] Add AOTLowerMain pass to lower a Relay main into TIR

ekalda commented on code in PR #12550:
URL: https://github.com/apache/tvm/pull/12550#discussion_r958169995


##########
src/relay/backend/aot/aot_lower_main.cc:
##########
@@ -0,0 +1,868 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/aot_lower_main.cc
+ * \brief Lower the Relay main func into an AOT TIR main func.
+ */
+#include "./aot_lower_main.h"
+
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/transform.h>
+
+#include "../../op/call/call.h"
+#include "../../op/memory/device_copy.h"
+#include "../../op/memory/memory.h"
+#include "../../transforms/device_aware_visitors.h"
+#include "../name_transforms.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+using StorageMap =
+    std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
+
+/*!
+ * \brief Assigns one or more StorageInfos to expressions requiring storage in a
+ * function to produce an Expr to StorageInfo map.
+ *
+ * This pass is leveraged by AOTMainLowerer to perform an initial naive allocation
+ * for tensors in the Relay main function. The resulting storage map is then lowered
+ * into TIR allocations by AOTMainLowerer where the allocation can be subsequently
+ * optimized by later passes (e.g. USMP).
+ */
+class ExprAllocator : public transform::DeviceAwareExprVisitor {
+ public:
+  ExprAllocator() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}
+
+  // run the visitor on a global function.
+  void Run(const Function& func) { VisitExpr(func); }
+
+  std::vector<int> GetReturnSIDs() const { return return_sids_; }
+
+  StorageMap GetStorageMap() const { return expr_storage_map_; }
+
+  using ExprVisitor::VisitExpr_;
+
+  void DeviceAwareVisitExpr_(const CallNode* call_node) final {
+    Expr func;
+    Array<Expr> args;
+
+    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
+    if (call_lowered_props.lowered_func.defined()) {
+      func = call_lowered_props.lowered_func;
+      args = call_lowered_props.arguments;
+    } else {  // Relay functions that have not been lowered and lowered extern functions
+      func = call_node->op;
+      args = call_node->args;
+      if (call_node->op.as<GlobalVarNode>()) {  // Lowered extern function
+        ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes.";
+      } else {  // Relay function which has not been lowered yet
+        ICHECK(call_node->op.as<FunctionNode>())
+            << "Expected the call to be to a lowered primfunc, a lowered extern function or a "
+               "unlowered Relay function.";
+      }
+    }
+    VisitExpr(func);
+    CreateStorage(call_node);
+    for (const Expr& arg : args) {
+      VisitExpr(arg);
+    }
+    AssignReturnSID(GetRef<Expr>(call_node));
+  }
+
+  void DeviceAwareVisitExpr_(const FunctionNode* func_node) final {
+    if (function_nesting() > 1) {
+      // Do not recurse into sub functions.
+      return;
+    }
+    if (func_node->HasNonzeroAttr(attr::kPrimitive)) {
+      // No storage needed for primitive functions
+      return;
+    }
+    for (const auto& param : func_node->params) {
+      CreateStorage(param.get());
+    }
+    VisitExpr(func_node->body);
+  }
+
+  void PreVisitLetBinding_(const Var& var, const Expr& value) final {
+    VisitExpr(value);
+    StorageInfo si = GetStorage(value);
+    expr_storage_map_[var] = si;
+  }
+
+  void VisitExpr_(const ConstantNode* op) final {
+    CreateStorage(op);
+    AssignReturnSID(GetRef<Expr>(op));
+  }
+
+  void VisitExpr_(const VarNode* op) final { AssignReturnSID(GetRef<Expr>(op)); }
+
+  void VisitExpr_(const TupleNode* op) final {
+    std::vector<int64_t> storage_ids;
+    std::vector<VirtualDevice> virtual_devices;
+    std::vector<int64_t> storage_sizes_in_bytes;
+    Expr expr = GetRef<Expr>(op);
+    for (Expr field : op->fields) {
+      auto sid = GetStorage(field);
+      storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end());
+      virtual_devices.insert(virtual_devices.end(), sid->virtual_devices.begin(),
+                             sid->virtual_devices.end());
+      storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(),
+                                    sid->storage_sizes_in_bytes.begin(),
+                                    sid->storage_sizes_in_bytes.end());
+    }
+    expr_storage_map_[expr] = StorageInfo(storage_ids, virtual_devices, storage_sizes_in_bytes);
+    AssignReturnSID(expr);
+  }
+
+  void VisitExpr_(const TupleGetItemNode* op) final {
+    Expr expr = GetRef<Expr>(op);
+    auto sids = GetStorage(op->tuple);
+    ICHECK_LT(static_cast<size_t>(op->index), sids->storage_ids.size());
+    expr_storage_map_[expr] =
+        StorageInfo({sids->storage_ids[op->index]}, {sids->virtual_devices[op->index]},
+                    {sids->storage_sizes_in_bytes[op->index]});
+    AssignReturnSID(expr);
+  }
+
+  void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "'If' is not supported."; }
+
+ private:
+  /*!
+   * \brief Assign the expression's storage IDs as the return storage IDs.
+   * \note This is called when visiting every expression on the understanding
+   * that the returned expression will be visited last.
+   */
+  void AssignReturnSID(const Expr& e) {
+    if (expr_storage_map_.find(e) != expr_storage_map_.end()) {
+      StorageInfo& sinfo = expr_storage_map_[e];
+      return_sids_.clear();
+      for (auto sid : sinfo->storage_ids) {
+        return_sids_.push_back(sid);
+      }
+    }
+  }
+
+  /*!
+   * \brief Get the necessary storage for the expression.
+   * \param expr The expression.
+   * \return The corresponding token.
+   */
+  StorageInfo GetStorage(const Expr& expr) {
+    // See through "on_device" calls.
+    Expr true_expr = IgnoreOnDevice(expr);
+    VisitExpr(true_expr);
+    auto it = expr_storage_map_.find(true_expr);
+    ICHECK(it != expr_storage_map_.end()) << "Could not find " << true_expr->GetTypeKey() << " "
+                                          << PrettyPrint(true_expr) << " in storage device map";
+    return it->second;
+  }
+
+  /*!
+   * \brief Create storage for the expression.
+   */
+  void CreateStorage(const ExprNode* op) {
+    Expr expr = GetRef<Expr>(op);
+    return CreateStorage(expr, GetVirtualDevice(expr));
+  }
+
+  /*!
+   * \brief Create storage to hold the result of evaluating \p expr in \p virtual_device.
+   */
+  void CreateStorage(const Expr& expr, const VirtualDevice& virtual_device) {
+    ICHECK(!virtual_device->IsFullyUnconstrained())
+        << "invalid virtual device for expr:" << std::endl
+        << PrettyPrint(expr);
+    std::vector<int64_t> storage_ids;
+    std::vector<VirtualDevice> virtual_devices;
+    std::vector<int64_t> storage_sizes_in_bytes;
+    for (const auto& ttype : FlattenTupleType(expr->checked_type())) {
+      storage_ids.push_back(next_available_sid_++);
+      virtual_devices.push_back(virtual_device);
+      storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype->shape, ttype->dtype));
+    }
+    expr_storage_map_[expr] = StorageInfo(std::move(storage_ids), std::move(virtual_devices),
+                                          std::move(storage_sizes_in_bytes));
+  }
+
+  /*! \brief Map between Exprs and StorageInfos */
+  StorageMap expr_storage_map_;
+  /*! \brief The next available storage ID to be used */
+  int next_available_sid_{0};
+  /*! \brief The storage IDs that correspond to return values */
+  std::vector<int> return_sids_;
+};
+
+class AOTMainLowerer : public MixedModeVisitor {
+ public:
+  AOTMainLowerer(tvm::CompilationConfig config, CallType call_type)
+      : config_(config), call_type_(call_type) {}
+
+  IRModule Lower(IRModule mod, String mod_name) {
+    VLOG_CONTEXT << "AOT";
+    IRModule lowered_mod = GetRef<IRModule>(mod.CopyOnWrite());
+
+    auto lowered_main = lowered_mod->Lookup("main");
+    auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
+
+    // Assign StorageInfo to all the Relay exprs
+    ExprAllocator expr_allocator;
+    expr_allocator.Run(lowered_main_func);
+    expr_storage_map_ = expr_allocator.GetStorageMap();
+
+    for (auto input : lowered_main_func->params) {
+      input_vars_.push_back(input);
+      std::string input_name = SanitizeName(input->name_hint());
+      // We dont want the compiler changing input names in the

Review Comment:
   Nit: I know this is copy-pasted code, but
   ```suggestion
         // We don't want the compiler changing input names in the
   ```



##########
src/relay/backend/aot/aot_lower_main.cc:
##########
@@ -0,0 +1,868 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/aot_lower_main.cc
+ * \brief Lower the Relay main func into an AOT TIR main func.
+ */
+#include "./aot_lower_main.h"
+
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/transform.h>
+
+#include "../../op/call/call.h"
+#include "../../op/memory/device_copy.h"
+#include "../../op/memory/memory.h"
+#include "../../transforms/device_aware_visitors.h"
+#include "../name_transforms.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+using StorageMap =
+    std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
+
+/*!
+ * \brief Assigns one or more StorageInfos to expressions requiring storage in a
+ * function to produce an Expr to StorageInfo map.

Review Comment:
   Would you mind making this sentence a bit clearer? It only started to make sense to me after I had looked at the code.



##########
src/relay/backend/aot/aot_lower_main.cc:
##########
@@ -0,0 +1,868 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/aot_lower_main.cc
+ * \brief Lower the Relay main func into an AOT TIR main func.
+ */
+#include "./aot_lower_main.h"
+
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/transform.h>
+
+#include "../../op/call/call.h"
+#include "../../op/memory/device_copy.h"
+#include "../../op/memory/memory.h"
+#include "../../transforms/device_aware_visitors.h"
+#include "../name_transforms.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+using StorageMap =
+    std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
+
+/*!
+ * \brief Assigns one or more StorageInfos to expressions requiring storage in a
+ * function to produce an Expr to StorageInfo map.
+ *
+ * This pass is leveraged by AOTMainLowerer to perform an initial naive allocation
+ * for tensors in the Relay main function. The resulting storage map is then lowered
+ * into TIR allocations by AOTMainLowerer where the allocation can be subsequently
+ * optimized by later passes (e.g. USMP).
+ */
+class ExprAllocator : public transform::DeviceAwareExprVisitor {
+ public:
+  ExprAllocator() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}
+
+  // run the visitor on a global function.
+  void Run(const Function& func) { VisitExpr(func); }
+
+  std::vector<int> GetReturnSIDs() const { return return_sids_; }
+
+  StorageMap GetStorageMap() const { return expr_storage_map_; }
+
+  using ExprVisitor::VisitExpr_;
+
+  void DeviceAwareVisitExpr_(const CallNode* call_node) final {
+    Expr func;
+    Array<Expr> args;
+
+    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
+    if (call_lowered_props.lowered_func.defined()) {
+      func = call_lowered_props.lowered_func;
+      args = call_lowered_props.arguments;
+    } else {  // Relay functions that have not been lowered and lowered extern functions
+      func = call_node->op;
+      args = call_node->args;
+      if (call_node->op.as<GlobalVarNode>()) {  // Lowered extern function
+        ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes.";
+      } else {  // Relay function which has not been lowered yet
+        ICHECK(call_node->op.as<FunctionNode>())
+            << "Expected the call to be to a lowered primfunc, a lowered extern function or a "
+               "unlowered Relay function.";
+      }
+    }
+    VisitExpr(func);
+    CreateStorage(call_node);
+    for (const Expr& arg : args) {
+      VisitExpr(arg);
+    }
+    AssignReturnSID(GetRef<Expr>(call_node));
+  }
+
+  void DeviceAwareVisitExpr_(const FunctionNode* func_node) final {
+    if (function_nesting() > 1) {
+      // Do not recurse into sub functions.
+      return;
+    }
+    if (func_node->HasNonzeroAttr(attr::kPrimitive)) {
+      // No storage needed for primitive functions
+      return;
+    }
+    for (const auto& param : func_node->params) {
+      CreateStorage(param.get());
+    }
+    VisitExpr(func_node->body);
+  }
+
+  void PreVisitLetBinding_(const Var& var, const Expr& value) final {
+    VisitExpr(value);
+    StorageInfo si = GetStorage(value);
+    expr_storage_map_[var] = si;
+  }
+
+  void VisitExpr_(const ConstantNode* op) final {
+    CreateStorage(op);
+    AssignReturnSID(GetRef<Expr>(op));
+  }
+
+  void VisitExpr_(const VarNode* op) final { AssignReturnSID(GetRef<Expr>(op)); }
+
+  void VisitExpr_(const TupleNode* op) final {
+    std::vector<int64_t> storage_ids;
+    std::vector<VirtualDevice> virtual_devices;
+    std::vector<int64_t> storage_sizes_in_bytes;
+    Expr expr = GetRef<Expr>(op);
+    for (Expr field : op->fields) {
+      auto sid = GetStorage(field);
+      storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end());
+      virtual_devices.insert(virtual_devices.end(), sid->virtual_devices.begin(),
+                             sid->virtual_devices.end());
+      storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(),
+                                    sid->storage_sizes_in_bytes.begin(),
+                                    sid->storage_sizes_in_bytes.end());
+    }
+    expr_storage_map_[expr] = StorageInfo(storage_ids, virtual_devices, storage_sizes_in_bytes);
+    AssignReturnSID(expr);
+  }
+
+  void VisitExpr_(const TupleGetItemNode* op) final {
+    Expr expr = GetRef<Expr>(op);
+    auto sids = GetStorage(op->tuple);
+    ICHECK_LT(static_cast<size_t>(op->index), sids->storage_ids.size());
+    expr_storage_map_[expr] =
+        StorageInfo({sids->storage_ids[op->index]}, {sids->virtual_devices[op->index]},
+                    {sids->storage_sizes_in_bytes[op->index]});
+    AssignReturnSID(expr);
+  }
+
+  void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "'If' is not supported."; }
+
+ private:
+  /*!
+   * \brief Assign the expression's storage IDs as the return storage IDs.
+   * \note This is called when visiting every expression on the understanding
+   * that the returned expression will be visited last.
+   */
+  void AssignReturnSID(const Expr& e) {
+    if (expr_storage_map_.find(e) != expr_storage_map_.end()) {
+      StorageInfo& sinfo = expr_storage_map_[e];
+      return_sids_.clear();
+      for (auto sid : sinfo->storage_ids) {
+        return_sids_.push_back(sid);
+      }
+    }
+  }
+
+  /*!
+   * \brief Get the necessary storage for the expression.
+   * \param expr The expression.
+   * \return The corresponding token.
+   */
+  StorageInfo GetStorage(const Expr& expr) {
+    // See through "on_device" calls.
+    Expr true_expr = IgnoreOnDevice(expr);
+    VisitExpr(true_expr);
+    auto it = expr_storage_map_.find(true_expr);
+    ICHECK(it != expr_storage_map_.end()) << "Could not find " << true_expr->GetTypeKey() << " "
+                                          << PrettyPrint(true_expr) << " in storage device map";
+    return it->second;
+  }
+
+  /*!
+   * \brief Create storage for the expression.
+   */
+  void CreateStorage(const ExprNode* op) {
+    Expr expr = GetRef<Expr>(op);
+    return CreateStorage(expr, GetVirtualDevice(expr));
+  }
+
+  /*!
+   * \brief Create storage to hold the result of evaluating \p expr in \p virtual_device.
+   */
+  void CreateStorage(const Expr& expr, const VirtualDevice& virtual_device) {
+    ICHECK(!virtual_device->IsFullyUnconstrained())
+        << "invalid virtual device for expr:" << std::endl
+        << PrettyPrint(expr);
+    std::vector<int64_t> storage_ids;
+    std::vector<VirtualDevice> virtual_devices;
+    std::vector<int64_t> storage_sizes_in_bytes;
+    for (const auto& ttype : FlattenTupleType(expr->checked_type())) {
+      storage_ids.push_back(next_available_sid_++);
+      virtual_devices.push_back(virtual_device);
+      storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype->shape, ttype->dtype));
+    }
+    expr_storage_map_[expr] = StorageInfo(std::move(storage_ids), std::move(virtual_devices),
+                                          std::move(storage_sizes_in_bytes));
+  }
+
+  /*! \brief Map between Exprs and StorageInfos */
+  StorageMap expr_storage_map_;
+  /*! \brief The next available storage ID to be used */
+  int next_available_sid_{0};
+  /*! \brief The storage IDs that correspond to return values */
+  std::vector<int> return_sids_;
+};
+
+class AOTMainLowerer : public MixedModeVisitor {
+ public:
+  AOTMainLowerer(tvm::CompilationConfig config, CallType call_type)
+      : config_(config), call_type_(call_type) {}
+
+  IRModule Lower(IRModule mod, String mod_name) {
+    VLOG_CONTEXT << "AOT";
+    IRModule lowered_mod = GetRef<IRModule>(mod.CopyOnWrite());
+
+    auto lowered_main = lowered_mod->Lookup("main");
+    auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
+
+    // Assign StorageInfo to all the Relay exprs
+    ExprAllocator expr_allocator;
+    expr_allocator.Run(lowered_main_func);
+    expr_storage_map_ = expr_allocator.GetStorageMap();
+
+    for (auto input : lowered_main_func->params) {
+      input_vars_.push_back(input);
+      std::string input_name = SanitizeName(input->name_hint());
+      // We dont want the compiler changing input names in the
+      // event of a sanitization collision. Therefore, enforcing
+      // the var created to use the input_name strictly.
+      CreateIOVar(input, input_name, /*use_unique_name = */ false);
+    }
+
+    // Define the storage allocator ids
+    for (auto kv : expr_storage_map_) {
+      for (auto sid : kv.second->storage_ids) {
+        // The buffer_var is created with storage_scope to be global.workspace to be serviced by
+        // TVMBackendAllocWorkspace(TVMBAW) calls, explicitly. The reasoning being the executor
+        // allocates should be serviced by TVMBAWs as the data could be accessed by many devices and
+        // should not be lowered to the stack. For more details please refer to the discussion here:
+        // https://github.com/apache/tvm/issues/9022
+        tir::Var buffer_var(MakeString("sid_", sid),
+                            PointerType(PrimType(DataType::Int(8)), "global.workspace"));
+        sids_table_[sid] = buffer_var;
+      }
+    }
+
+    // Retrieve the return sids
+    return_sid_ = expr_allocator.GetReturnSIDs();
+    // Create output vars for the TIR main func
+    // If output tensor names were provided use them
+    if (auto opt = lowered_main->GetAttr<Array<String>>("output_tensor_names")) {
+      Array<String> output_tensor_names = opt.value();
+      Expr output_expr = lowered_main_func->body;
+      if (output_expr->checked_type()->IsInstance<TupleTypeNode>()) {
+        TupleType output_tuple_type = Downcast<TupleType>(output_expr->checked_type());
+        for (unsigned i = 0; i < output_tuple_type->fields.size(); i++) {
+          // AoT Executor Codegen does not create these names,
+          // thus should be used as they are provided.
+          CreateIOVar(output_tuple_type->fields[i], output_tensor_names[i],
+                      /*use_unique_name = */ false);
+        }
+      } else {
+        // AoT Executor Codegen does not create these names,
+        // thus should be used as they are provided.
+        CreateIOVar(lowered_main_func->body, output_tensor_names[0], /*use_unique_name = */ false);
+      }
+    } else {
+      // If output tensor names are not provided we will generate output(x)
+      // where x is a counter to create unique names.
+      if (lowered_main_func->body->checked_type()->IsInstance<TupleTypeNode>()) {
+        CreateIOVar(lowered_main_func->body, "output");
+      } else {
+        CreateIOVar(lowered_main_func->body, "output", /*use_unique_name = */ false);
+      }
+    }
+
+    CollectDeviceVariables(lowered_mod->GetAttr<Map<GlobalVar, String>>("device_contexts")
+                               .value_or(Map<GlobalVar, String>()));
+    VisitExpr(lowered_main_func->body);
+
+    lowered_mod->Remove(lowered_mod->GetGlobalVar("main"));
+    auto tir_main_func = CreateMainFunc(mod_name);
+    lowered_mod->Update(GlobalVar(runtime::symbol::tvm_module_main), tir_main_func);
+    lowered_mod = tir::transform::RemoveNoOp()(lowered_mod);

Review Comment:
   What's the reason for running this pass here instead of in the "AoT pipeline"?



##########
src/relay/backend/aot/aot_lower_main.cc:
##########
@@ -0,0 +1,868 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/aot_lower_main.cc
+ * \brief Lower the Relay main func into an AOT TIR main func.
+ */
+#include "./aot_lower_main.h"
+
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/transform.h>
+
+#include "../../op/call/call.h"
+#include "../../op/memory/device_copy.h"
+#include "../../op/memory/memory.h"
+#include "../../transforms/device_aware_visitors.h"
+#include "../name_transforms.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+using StorageMap =
+    std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
+
+/*!
+ * \brief Assigns one or more StorageInfos to expressions requiring storage in a
+ * function to produce an Expr to StorageInfo map.
+ *
+ * This pass is leveraged by AOTMainLowerer to perform an initial naive allocation
+ * for tensors in the Relay main function. The resulting storage map is then lowered
+ * into TIR allocations by AOTMainLowerer where the allocation can be subsequently
+ * optimized by later passes (e.g. USMP).
+ */
+class ExprAllocator : public transform::DeviceAwareExprVisitor {
+ public:
+  ExprAllocator() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}
+
+  // run the visitor on a global function.
+  void Run(const Function& func) { VisitExpr(func); }
+
+  std::vector<int> GetReturnSIDs() const { return return_sids_; }
+
+  StorageMap GetStorageMap() const { return expr_storage_map_; }
+
+  using ExprVisitor::VisitExpr_;
+
+  void DeviceAwareVisitExpr_(const CallNode* call_node) final {
+    Expr func;
+    Array<Expr> args;
+
+    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
+    if (call_lowered_props.lowered_func.defined()) {
+      func = call_lowered_props.lowered_func;
+      args = call_lowered_props.arguments;
+    } else {  // Relay functions that have not been lowered and lowered extern functions
+      func = call_node->op;
+      args = call_node->args;
+      if (call_node->op.as<GlobalVarNode>()) {  // Lowered extern function
+        ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes.";
+      } else {  // Relay function which has not been lowered yet
+        ICHECK(call_node->op.as<FunctionNode>())
+            << "Expected the call to be to a lowered primfunc, a lowered extern function or a "
+               "unlowered Relay function.";
+      }
+    }
+    VisitExpr(func);
+    CreateStorage(call_node);
+    for (const Expr& arg : args) {
+      VisitExpr(arg);
+    }
+    AssignReturnSID(GetRef<Expr>(call_node));
+  }
+
+  void DeviceAwareVisitExpr_(const FunctionNode* func_node) final {
+    if (function_nesting() > 1) {
+      // Do not recurse into sub functions.
+      return;
+    }
+    if (func_node->HasNonzeroAttr(attr::kPrimitive)) {
+      // No storage needed for primitive functions
+      return;
+    }
+    for (const auto& param : func_node->params) {
+      CreateStorage(param.get());
+    }
+    VisitExpr(func_node->body);
+  }
+
+  void PreVisitLetBinding_(const Var& var, const Expr& value) final {
+    VisitExpr(value);
+    StorageInfo si = GetStorage(value);
+    expr_storage_map_[var] = si;
+  }
+
+  void VisitExpr_(const ConstantNode* op) final {
+    CreateStorage(op);
+    AssignReturnSID(GetRef<Expr>(op));
+  }
+
+  void VisitExpr_(const VarNode* op) final { AssignReturnSID(GetRef<Expr>(op)); }
+
+  void VisitExpr_(const TupleNode* op) final {
+    std::vector<int64_t> storage_ids;
+    std::vector<VirtualDevice> virtual_devices;
+    std::vector<int64_t> storage_sizes_in_bytes;
+    Expr expr = GetRef<Expr>(op);
+    for (Expr field : op->fields) {
+      auto sid = GetStorage(field);
+      storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end());
+      virtual_devices.insert(virtual_devices.end(), sid->virtual_devices.begin(),
+                             sid->virtual_devices.end());
+      storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(),
+                                    sid->storage_sizes_in_bytes.begin(),
+                                    sid->storage_sizes_in_bytes.end());
+    }
+    expr_storage_map_[expr] = StorageInfo(storage_ids, virtual_devices, storage_sizes_in_bytes);
+    AssignReturnSID(expr);
+  }
+
+  void VisitExpr_(const TupleGetItemNode* op) final {
+    Expr expr = GetRef<Expr>(op);
+    auto sids = GetStorage(op->tuple);
+    ICHECK_LT(static_cast<size_t>(op->index), sids->storage_ids.size());
+    expr_storage_map_[expr] =
+        StorageInfo({sids->storage_ids[op->index]}, {sids->virtual_devices[op->index]},
+                    {sids->storage_sizes_in_bytes[op->index]});
+    AssignReturnSID(expr);
+  }
+
+  void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "'If' is not supported."; }
+
+ private:
+  /*!
+   * \brief Assign the expression's storage IDs as the return storage IDs.
+   * \note This is called when visiting every expression on the understanding
+   * that the returned expression will be visited last.
+   */
+  void AssignReturnSID(const Expr& e) {
+    if (expr_storage_map_.find(e) != expr_storage_map_.end()) {
+      StorageInfo& sinfo = expr_storage_map_[e];
+      return_sids_.clear();
+      for (auto sid : sinfo->storage_ids) {
+        return_sids_.push_back(sid);
+      }
+    }
+  }
+
+  /*!
+   * \brief Get the necessary storage for the expression.
+   * \param expr The expression.
+   * \return The corresponding token.
+   */
+  StorageInfo GetStorage(const Expr& expr) {
+    // See through "on_device" calls.
+    Expr true_expr = IgnoreOnDevice(expr);
+    VisitExpr(true_expr);
+    auto it = expr_storage_map_.find(true_expr);
+    ICHECK(it != expr_storage_map_.end()) << "Could not find " << true_expr->GetTypeKey() << " "
+                                          << PrettyPrint(true_expr) << " in storage device map";
+    return it->second;
+  }
+
+  /*!
+   * \brief Create storage for the expression.
+   */
+  void CreateStorage(const ExprNode* op) {
+    Expr expr = GetRef<Expr>(op);
+    return CreateStorage(expr, GetVirtualDevice(expr));
+  }
+
+  /*!
+   * \brief Create storage to hold the result of evaluating \p expr in \p virtual_device.
+   */
+  void CreateStorage(const Expr& expr, const VirtualDevice& virtual_device) {
+    ICHECK(!virtual_device->IsFullyUnconstrained())
+        << "invalid virtual device for expr:" << std::endl
+        << PrettyPrint(expr);
+    std::vector<int64_t> storage_ids;
+    std::vector<VirtualDevice> virtual_devices;
+    std::vector<int64_t> storage_sizes_in_bytes;
+    for (const auto& ttype : FlattenTupleType(expr->checked_type())) {
+      storage_ids.push_back(next_available_sid_++);
+      virtual_devices.push_back(virtual_device);
+      storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype->shape, ttype->dtype));
+    }
+    expr_storage_map_[expr] = StorageInfo(std::move(storage_ids), std::move(virtual_devices),
+                                          std::move(storage_sizes_in_bytes));
+  }
+
+  /*! \brief Map between Exprs and StorageInfos */
+  StorageMap expr_storage_map_;
+  /*! \brief The next available storage ID to be used */
+  int next_available_sid_{0};
+  /*! \brief The storage IDs that correspond to return values */
+  std::vector<int> return_sids_;
+};
+
+class AOTMainLowerer : public MixedModeVisitor {
+ public:
+  AOTMainLowerer(tvm::CompilationConfig config, CallType call_type)
+      : config_(config), call_type_(call_type) {}
+
+  IRModule Lower(IRModule mod, String mod_name) {
+    VLOG_CONTEXT << "AOT";
+    IRModule lowered_mod = GetRef<IRModule>(mod.CopyOnWrite());
+
+    auto lowered_main = lowered_mod->Lookup("main");
+    auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
+
+    // Assign StorageInfo to all the Relay exprs
+    ExprAllocator expr_allocator;
+    expr_allocator.Run(lowered_main_func);
+    expr_storage_map_ = expr_allocator.GetStorageMap();
+
+    for (auto input : lowered_main_func->params) {
+      input_vars_.push_back(input);
+      std::string input_name = SanitizeName(input->name_hint());
+      // We dont want the compiler changing input names in the
+      // event of a sanitization collision. Therefore, enforcing
+      // the var created to use the input_name strictly.
+      CreateIOVar(input, input_name, /*use_unique_name = */ false);
+    }
+
+    // Define the storage allocator ids
+    for (auto kv : expr_storage_map_) {
+      for (auto sid : kv.second->storage_ids) {
+        // The buffer_var is created with storage_scope to be global.workspace to be serviced by
+        // TVMBackendAllocWorkspace(TVMBAW) calls, explicitly. The reasoning being the executor
+        // allocates should be serviced by TVMBAWs as the data could be accessed by many devices and
+        // should not be lowered to the stack. For more details please refer to the discussion here:
+        // https://github.com/apache/tvm/issues/9022
+        tir::Var buffer_var(MakeString("sid_", sid),
+                            PointerType(PrimType(DataType::Int(8)), "global.workspace"));
+        sids_table_[sid] = buffer_var;
+      }
+    }
+
+    // Retrieve the return sids
+    return_sid_ = expr_allocator.GetReturnSIDs();
+    // Create output vars for the TIR main func
+    // If output tensor names were provided use them
+    if (auto opt = lowered_main->GetAttr<Array<String>>("output_tensor_names")) {
+      Array<String> output_tensor_names = opt.value();
+      Expr output_expr = lowered_main_func->body;
+      if (output_expr->checked_type()->IsInstance<TupleTypeNode>()) {
+        TupleType output_tuple_type = Downcast<TupleType>(output_expr->checked_type());
+        for (unsigned i = 0; i < output_tuple_type->fields.size(); i++) {
+          // AoT Executor Codegen does not create these names,
+          // thus should be used as they are provided.
+          CreateIOVar(output_tuple_type->fields[i], output_tensor_names[i],
+                      /*use_unique_name = */ false);
+        }
+      } else {
+        // AoT Executor Codegen does not create these names,
+        // thus should be used as they are provided.
+        CreateIOVar(lowered_main_func->body, output_tensor_names[0], /*use_unique_name = */ false);
+      }
+    } else {
+      // If output tensor names are not provided we will generate output(x)
+      // where x is a counter to create unique names.
+      if (lowered_main_func->body->checked_type()->IsInstance<TupleTypeNode>()) {
+        CreateIOVar(lowered_main_func->body, "output");
+      } else {
+        CreateIOVar(lowered_main_func->body, "output", /*use_unique_name = */ false);
+      }
+    }
+
+    CollectDeviceVariables(lowered_mod->GetAttr<Map<GlobalVar, String>>("device_contexts")
+                               .value_or(Map<GlobalVar, String>()));
+    VisitExpr(lowered_main_func->body);
+
+    lowered_mod->Remove(lowered_mod->GetGlobalVar("main"));

Review Comment:
   Nit: There used to be a comment here about removing the Relay main and replacing it with the TIR one, which I think was quite useful.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org