You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/05/25 22:47:33 UTC

[GitHub] [tvm] giuseros commented on a change in pull request #8072: Add "operator" style to Model Library Format

giuseros commented on a change in pull request #8072:
URL: https://github.com/apache/tvm/pull/8072#discussion_r639251955



##########
File path: python/tvm/micro/model_library_format.py
##########
@@ -203,60 +209,200 @@ def _build_function_memory_map(function_metadata):
     return ret
 
 
-def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, file_name):
-    """Export the build artifact in Model Library Format.
+def _make_tar(source_dir, tar_file_path):
+    """Build a tar file from source_dir."""
+    with tarfile.open(tar_file_path, "w") as tar_f:
 
-    This function creates a .tar archive containing the build artifacts in a standardized
-    layout. It's intended to allow downstream automation to build TVM artifacts against the C
-    runtime.
+        def reset(tarinfo):
+            tarinfo.uid = tarinfo.gid = 0
+            tarinfo.uname = tarinfo.gname = "root"
+            return tarinfo
+
+        tar_f.add(str(source_dir), arcname=".", filter=reset)
+
+
+_GENERATED_VERSION = 2
+
+
+def _export_graph_model_library_format(
+    mod: executor_factory.GraphExecutorFactoryModule, tempdir: pathlib.Path

Review comment:
       Shouldn't this parameter be annotated as ExecutorFactoryModule, so that it is compatible with AOT as well?

##########
File path: src/printer/model_library_format_printer.cc
##########
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/var.h>
+
+#include "text_printer.h"
+
+namespace tvm {
+namespace printer {
+
+class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode {
+ public:
+  ModelLibraryFormatPrinter(bool show_meta_data,
+                            const runtime::TypedPackedFunc<std::string(ObjectRef)>& annotate,
+                            bool show_warning)
+      : text_printer_{show_meta_data, annotate, show_warning} {}
+
+  const char* type_key() const override { return "model_library_format_printer"; }
+
+  std::string Print(const ObjectRef& node) {
+    Doc doc;
+    doc << text_printer_.PrintFinal(node);
+    return doc.str();
+  }
+
+  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) override {
+    if (name == "print") {
+      return TypedPackedFunc<std::string(ObjectRef)>(
+          [sptr_to_self, this](ObjectRef node) { return Print(node); });
+    } else if (name == "get_var_name") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        ICHECK_EQ(args.size(), 1) << "usage: get_var_name(Var v)";
+
+        std::string var_name;
+        if (text_printer_.GetVarName(args[0], &var_name)) {
+          *rv = var_name;
+        }

Review comment:
       Should this ICHECK if GetVarName returns false?

##########
File path: python/tvm/micro/model_library_format.py
##########
@@ -203,60 +209,200 @@ def _build_function_memory_map(function_metadata):
     return ret
 
 
-def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, file_name):
-    """Export the build artifact in Model Library Format.
+def _make_tar(source_dir, tar_file_path):
+    """Build a tar file from source_dir."""
+    with tarfile.open(tar_file_path, "w") as tar_f:
 
-    This function creates a .tar archive containing the build artifacts in a standardized
-    layout. It's intended to allow downstream automation to build TVM artifacts against the C
-    runtime.
+        def reset(tarinfo):
+            tarinfo.uid = tarinfo.gid = 0
+            tarinfo.uname = tarinfo.gname = "root"
+            return tarinfo
+
+        tar_f.add(str(source_dir), arcname=".", filter=reset)
+
+
+_GENERATED_VERSION = 2
+
+
+def _export_graph_model_library_format(
+    mod: executor_factory.GraphExecutorFactoryModule, tempdir: pathlib.Path
+):
+    """Export a tvm.relay.build artifact in Model Library Format.
 
     Parameters
     ----------
     mod : tvm.relay.backend.executor_factory.ExecutorFactoryModule
         The return value of tvm.relay.build, which will be exported into Model Library Format.
-    file_name : str
-        Path to the .tar archive to generate.
+    tempdir : pathlib.Path
+        Temporary directory to populate with Model Library Format contents.
     """
-    tempdir = utils.tempdir()
     is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule)
     runtime = ["aot"] if is_aot else ["graph"]
 
     metadata = {
-        "version": 2,
+        "version": _GENERATED_VERSION,
         "model_name": mod.libmod_name,
         "export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"),
         "memory": _build_memory_map(mod),
         "target": {int(k): str(v) for k, v in mod.target.items()},
         "runtimes": runtime,
+        "style": "full-model",
     }
 
-    with open(tempdir.relpath("metadata.json"), "w") as json_f:
+    with open(tempdir / "metadata.json", "w") as json_f:
         json.dump(metadata, json_f, indent=2, sort_keys=True)
 
-    codegen_dir_path = tempdir.relpath("codegen")
-    os.mkdir(codegen_dir_path)
-    _populate_codegen_dir(mod.lib, codegen_dir_path)
+    codegen_dir = tempdir / "codegen"
+    codegen_dir.mkdir()
+    _populate_codegen_dir(mod.lib, codegen_dir)
 
-    parameters_dir_path = tempdir.relpath("parameters")
-    os.mkdir(parameters_dir_path)
-    param_filename = os.path.join(parameters_dir_path, f"{mod.libmod_name}.params")
+    parameters_dir = tempdir / "parameters"
+    parameters_dir.mkdir()
+    param_filename = parameters_dir / f"{mod.libmod_name}.params"
     with open(param_filename, "wb") as f:
         f.write(param_dict.save_param_dict(mod.params))
 
-    with open(tempdir.relpath("relay.txt"), "w") as f:
+    src_dir = tempdir / "src"
+    src_dir.mkdir()
+    with open(src_dir / "relay.txt", "w") as f:
         f.write(str(mod.ir_mod))
 
     if not is_aot:
-        graph_config_dir_path = tempdir.relpath(os.path.join("runtime-config", "graph"))
-        os.makedirs(graph_config_dir_path)
-        with open(os.path.join(graph_config_dir_path, "graph.json"), "w") as f:
+        graph_config_dir = tempdir / "runtime-config" / "graph"
+        graph_config_dir.mkdir(parents=True)
+        with open(graph_config_dir / "graph.json", "w") as f:
             f.write(mod.get_executor_config())
 
-    with tarfile.open(file_name, "w") as tar_f:
 
-        def reset(tarinfo):
-            tarinfo.uid = tarinfo.gid = 0
-            tarinfo.uname = tarinfo.gname = "root"
-            return tarinfo
+class NonStaticShapeError(Exception):
+    """Raised when a shape has elements other than IntImm."""
+
+
+def _shape_to_size(shape, dtype):
+    bits_per_item = int(
+        re.match(r"((float)|(int))(?P<width_bits>[0-9]+)", dtype).group("width_bits")
+    )
+    assert bits_per_item is not None, f"don't know how to compute size of type {dtype}"
+    total_bits = bits_per_item
+    for s in shape:
+        total_bits *= s
+
+    return (total_bits + 7) // 8
+
+
+def _write_tir_and_build_operator_memory_map(src_dir, targets, ir_module_by_target):
+    def _eval_shape(param_name, buffer_shape):
+        shape = []
+        for x in buffer_shape:
+            if not isinstance(x, expr.IntImm):
+                raise NonStaticShapeError(
+                    f"Parameter {param_name} has shape with non-IntImm elements: {buffer_shape}"
+                )
+            shape.append(x.value)
+        return shape
+
+    memory_map = {}
+    for target_device_type, target in targets.items():
+        ir_mod = ir_module_by_target[target]
+        printer = get_global_func("tir.ModelLibraryFormatPrinter")(False, None, False)
+        with open(src_dir / f"tir-{target_device_type}.txt", "w") as f:
+            f.write(printer["print"](ir_mod))

Review comment:
       I don't follow why the TIR is being added to the archive. Is this for test purposes?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org