Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2023/01/03 13:07:05 UTC

[GitHub] [tvm] ashutosh-arm commented on a diff in pull request #13655: [AOT] Added a test for detecting output size post MLF export

ashutosh-arm commented on code in PR #13655:
URL: https://github.com/apache/tvm/pull/13655#discussion_r1060568532


##########
tests/python/relay/aot/test_crt_aot.py:
##########
@@ -225,6 +225,64 @@ def test_packed_global_variables():
             assert f"{func}_packed" not in tvmgen_names
 
 
+def test_io_size_definition():
+    """Check network IO size definitions in the codegen output."""
+    dtype = "float32"
+    ishape = (1, 32, 14, 14)
+    wshape = (32, 32, 3, 3)
+    interface_api = "c"
+    use_unpacked_api = True
+
+    data0 = relay.var("data", shape=ishape, dtype=dtype)
+    weight0 = relay.var("weight", shape=wshape, dtype=dtype)
+    out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1), groups=1)
+    main_f = relay.Function([data0, weight0], out)
+    mod = tvm.IRModule()
+    mod["main"] = main_f
+    mod = transform.InferType()(mod)
+
+    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+    w1_data = np.random.uniform(0, 1, wshape).astype(dtype)
+
+    inputs = OrderedDict([("data", i_data), ("weight", w1_data)])
+
+    output_list = generate_ref_data(mod, inputs)
+    compiled_models_list = compile_models(
+        models=AOTTestModel(module=mod, inputs=inputs, outputs=output_list),
+        interface_api=interface_api,
+        use_unpacked_api=use_unpacked_api,
+        workspace_byte_alignment=8,
+        enable_op_fusion=True,
+        pass_config=AOT_DEFAULT_RUNNER.pass_config,
+        use_runtime_executor=True,
+        target=tvm.target.Target("c"),
+    )
+    ref_output_size = output_list["output"].size * np.dtype(dtype).itemsize
+    compiled_model = compiled_models_list[0]
+
+    tmp_path = utils.tempdir()
+    base_path = tmp_path.temp_dir
+
+    model = compiled_model.model
+    tar_file = os.path.join(base_path, f"{model.name}.tar")
+    export_model_library_format(compiled_model.executor_factory, tar_file)
+    t = tarfile.open(tar_file)
+    t.extractall(base_path)
+
+    file_list = []
+    for path in (pathlib.Path(base_path) / "codegen" / "host" / "include").iterdir():
+        if path.is_file():
+            file_list.append(path)
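
The diff excerpt ends at the review anchor, before the assertion itself. For reference, a minimal sketch of how the collected headers could be checked against the reference size; the `_MAX_OUTPUT_SIZE` macro suffix and the `#define <PREFIX>_MAX_OUTPUT_SIZE <bytes>` layout are assumptions for illustration, not taken from the PR:

    # Hypothetical continuation: scan the extracted headers for an output
    # size define and compare it with ref_output_size computed above. The
    # "_MAX_OUTPUT_SIZE" macro suffix is an assumed name, not the PR's.
    size_define_found = False
    for header in file_list:
        for line in header.read_text().splitlines():
            if "_MAX_OUTPUT_SIZE" in line:
                # Expected shape: #define <PREFIX>_MAX_OUTPUT_SIZE <bytes>
                assert int(line.split()[-1]) == ref_output_size
                size_define_found = True
    assert size_define_found, "no output size definition found in headers"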

Review Comment:
   I will extend the check to cover inputs. We could look for the file directly, but I thought that check may not work for multiple models. It does, so I will update that too.
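
   A hedged sketch of the direct lookup mentioned above, generalized over all compiled models; the `tvmgen_<name>.h` header naming under codegen/host/include is an assumption for illustration:

    # Hypothetical direct lookup: assume each compiled model emits an
    # interface header named tvmgen_<model name>.h, so iterating over the
    # compiled models list covers the multiple-model case as well.
    include_dir = pathlib.Path(base_path) / "codegen" / "host" / "include"
    for compiled_model in compiled_models_list:
        header = include_dir / f"tvmgen_{compiled_model.model.name}.h"
        assert header.is_file(), f"missing interface header: {header}"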


