You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/11/12 01:11:07 UTC

[GitHub] [tvm] huajsj commented on a change in pull request #9494: [Runtime] Pipeline Executor Add Set and Get Input/Output interfaces.

huajsj commented on a change in pull request #9494:
URL: https://github.com/apache/tvm/pull/9494#discussion_r747914365



##########
File path: src/runtime/pipeline/pipeline_struct.h
##########
@@ -134,37 +220,316 @@ struct OutputMap {
     }
   }
 };
+
+/*!
+ * \brief A map of the global module input interfaces and the graph modudles input interfaces.
+ */
+struct InputConnectionConfig {
+  /*!\brief The key is the name of global module input interfaces. the value is the pair of
+   * the index of a graph module and the name of a graph module input interface.
+   */
+  std::unordered_map<std::string, std::pair<int, std::string>> input_connection;
+  bool Empty() { return input_connection.empty(); }
+  std::pair<int, std::string> operator[](const std::string key) {
+    if (input_connection.find(key) == input_connection.end()) {
+      LOG(FATAL) << "Not find the key " << key;
+    }
+    return input_connection[key];
+  }
+
+  size_t size() const { return input_connection.size(); }
+  /*!
+   * \brief Create a input connection config from JSONReader.
+   * \param reader Json reader.
+   */
+  void Load(dmlc::JSONReader* reader) {
+    reader->BeginArray();
+    while (reader->NextArrayItem()) {
+      reader->BeginObject();
+      std::string key;
+      std::string global_interface_name;
+      std::string module_interface_name;
+      int mod_idx = -1;
+      while (reader->NextObjectItem(&key)) {
+        if (key == "global_interface_name") {
+          reader->Read(&global_interface_name);
+        } else if (key == "mod_idx") {
+          reader->Read(&mod_idx);
+        } else if (key == "module_interface_name") {
+          reader->Read(&module_interface_name);
+        } else {
+          LOG(FATAL) << "do not support key " << key;
+        }
+      }
+      ICHECK(mod_idx >= 0) << "Invalid mod_idx value " << mod_idx;
+      ICHECK(!global_interface_name.empty()) << "Invalid global interface name value";
+      ICHECK(!module_interface_name.empty()) << "Invalid module interface name value";
+      input_connection[global_interface_name] = make_pair(mod_idx, module_interface_name);
+    }
+  }
+};
+/*!
+ * \brief A map of the global module param interfaces and the graph modudles param.
+ */
+struct ParamConnectionConfig {
+  /*!\brief The key is the name of global module param interfaces. the value is the
+   * index of a graph module.
+   */
+  std::unordered_map<std::string, int> param_connection;
+  bool Empty() { return param_connection.empty(); }
+  int operator[](const std::string key) {
+    if (param_connection.find(key) == param_connection.end()) {
+      LOG(FATAL) << "do not support key " << key;
+    }
+    return param_connection[key];
+  }
+  /*!
+   * \brief Create a param connection config from JSONReader.
+   * \param reader Json reader.
+   */
+  void Load(dmlc::JSONReader* reader) {
+    reader->BeginArray();
+    while (reader->NextArrayItem()) {
+      reader->BeginObject();
+      std::string key;
+      std::string global_param_name;
+      int mod_idx = -1;
+      while (reader->NextObjectItem(&key)) {
+        if (key == "global_param_name") {
+          reader->Read(&global_param_name);
+        } else if (key == "mod_idx") {
+          reader->Read(&mod_idx);
+        } else {
+          LOG(FATAL) << "do not support key " << key;
+        }
+      }
+      ICHECK(mod_idx >= 0) << "Invalid mod_idx value " << mod_idx;
+      ICHECK(!global_param_name.empty()) << "Invalid global param name value";
+      param_connection[global_param_name] = mod_idx;
+    }
+  }
+};
 /*!
  * \brief The binding or dependency information of each module output interface.
  */
-struct PipelineConfig {
-  /*!\brief The key is the module index, this variable records all module pipeline configuration
+class ConfigPipelineExecution {
+ private:
+  /*
+   *!\brief The key is the module index, this variable records all module pipeline configuration
    * information.
    */
-  std::unordered_map<int, OutputMap> config;
-  OutputMap& operator[](int key) {
+  std::unordered_map<int, ConfigOutputBindings> config;
+  /*
+   *\brief The key is the global output index, this variable records the mapping of global output
+   * and the module output.
+   */
+  std::unordered_map<int, ModuleOutputPair> global_output_map;
+  /*
+   *\brief The number of binding of module outputs and inputs.
+   */
+  size_t module_input_output_binding_total_num;
+
+ public:
+  ConfigOutputBindings& operator[](int key) {
     ICHECK(config.find(key) != config.end());
     return config[key];
   }
+  /*!
+   *\brief Check if the module index existing in the "config".
+   */
+  bool FindModuleInConfig(int mod_idx) { return config.find(mod_idx) != config.end(); }
+  /*!
+   *\brief Build the mapping of key and "ConfigOutputBindings", key is module index.
+   */
+  void Insert(int key, const ConfigOutputBindings& map) { config[key] = map; }
 
-  void Insert(int key, const OutputMap& map) { config[key] = map; }
-
-  /*!\brief This function is used to verify whether config is loaded successfully.
+  /*
+   *!\brief This function is used to verify whether config is loaded successfully.
    * \return Return true to indicate that this class has not been successfully loaded.
    */
   bool Empty() { return config.empty(); }
-
   /*!
    * \brief Get the number of global outputs.
    * \return The number of outputs the entire pipeline has.
    */
   size_t GetGlobalOutputNum() const {
-    size_t num_output = 0;
+    // The number of pipeline outputs is the size of "global_output_map";
+    return global_output_map.size();
+  }
+  /*
+   *!\brief Get the map of global outputs and module outputs.
+   */
+  std::unordered_map<int, ModuleOutputPair>& GetGlobalConfigOutputBindings(void) {
+    return global_output_map;
+  }
+  /*
+   *!\brief Get the number of module output and module input bindings.
+   */
+  size_t GetInputOutputBindingNum() const { return module_input_output_binding_total_num; }
+  /*
+   *!\brief Parse the config to construct data struct using in pipeline execution.
+   */
+  void ParseConfiguration(const std::unordered_map<int, ConfigOutputBindings>& config) {
+    if (config.empty()) {
+      LOG(FATAL) << "The Configuration loading not finish yet.";
+    }
+    module_input_output_binding_total_num = 0;
     for (auto mod_output : config) {
-      num_output += mod_output.second.GetGlobalOutputNum();
+      // Get the numbers of binding of input and output.
+      module_input_output_binding_total_num += mod_output.second.GetInputOutputBindingNum();
+      // Use global output index as key to create a mapping of global index and module output.
+      const std::vector<GlobalOutputPair>& global_output =
+          mod_output.second.GetGlobalConfigOutputBindings();
+
+      for (auto output : global_output) {
+        global_output_map[output.global_output_idx] =
+            ModuleOutputPair(mod_output.first, output.mod_output_idx);
+      }
     }
-    return num_output;
+    return;
+  }
+  /*!
+   * \brief Create a pipeline config from JSONReader.
+   * \param reader Json reader.
+   */
+  void Load(dmlc::JSONReader* reader) {
+    reader->BeginArray();
+    while (reader->NextArrayItem()) {
+      std::string key;
+      reader->BeginObject();
+      int mod_idx = -1;
+      ConfigOutputBindings output;
+      std::string dev;
+      while (reader->NextObjectItem(&key)) {
+        if (key == "mod_idx") {
+          reader->Read(&mod_idx);
+        } else if (key == "dev") {
+          reader->Read(&dev);
+        } else if (key == "output") {
+          reader->Read(&output);
+        } else {
+          LOG(FATAL) << "do not support key " << key;
+        }
+      }
+      ICHECK(mod_idx >= 0) << "Invalid mod_idx value " << mod_idx;
+      // Check if the output is successfully read.
+      ICHECK(!output.Empty()) << "Invalid output binding result.";
+      Insert(mod_idx, output);
+    }
+    // Call this function after "config" loading finished.
+    ParseConfiguration(config);
+  }
+};
+/*
+ *\brief Runtime of backend.
+ */
+class BackendRuntime {
+ private:
+  /*\brief The index of runtime indicate the position in the pipeline.*/
+  int runtime_idx;
+  /*\brief The Runtime module of a backedn graph executor.*/
+  Module module;
+  /*!
+   *\brief To transfer data between two different backends, we need a local
+   * tensor variable as a medium. This variable is a mapping of input data and local
+   * data.
+   */
+  std::unordered_map<DLTensor*, DLTensor*> input_tensor_local_copy;
+  /*!\brief The packed functions.*/
+  tvm::runtime::PackedFunc run;
+  tvm::runtime::PackedFunc set_input;
+  tvm::runtime::PackedFunc get_input;
+  tvm::runtime::PackedFunc get_output;
+  tvm::runtime::PackedFunc get_num_output;
+  tvm::runtime::PackedFunc get_num_inputs;
+  tvm::runtime::PackedFunc get_input_index;
+  /*!\brief The new DLTensor have same shape, data type with a existing DLTensor.*/
+  DLTensor* CreateFromDLTensor(const DLTensor* from) {
+    DLTensor* ret = NULL;
+    TVMArrayAlloc(from->shape, from->ndim, from->dtype.code, from->dtype.bits, from->dtype.lanes,
+                  kDLCPU, 0, &ret);

Review comment:
      This is to help with cross-device data copies. When copying between NDArrays on two different non-CPU devices, the current NDArray::CopyFromTo logic uses the DeviceAPI of the "from" device to perform the copy. However, the memory referenced by the "to" device's NDArray may only be accessible through the "to" device's DeviceAPI. A cross-device copy therefore needs to involve both the "from" and the "to" DeviceAPI. The solution here is to use a CPU NDArray as an intermediate buffer: first copy the data into the CPU NDArray, then copy that CPU NDArray to the target device.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org