You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2022/09/14 14:23:57 UTC
[tvm] branch main updated: [Relay][TE] Use Relay parameter name to generate TE tensor name (#10516)
This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a40849342d [Relay][TE] Use Relay parameter name to generate TE tensor name (#10516)
a40849342d is described below
commit a40849342d250bd585e19434e4a2473fcf978bcb
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Wed Sep 14 09:23:51 2022 -0500
[Relay][TE] Use Relay parameter name to generate TE tensor name (#10516)
* [Relay][TE] Use Relay parameter name to generate TE tensor name
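Previously, the TE placeholders representing Relay function parameters
were all named `"placeholder"`, which made larger functions difficult
to follow when debugging. Each placeholder is now named after the
corresponding Relay parameter's name hint. Shape-function placeholders
are prefixed with `data_` or `shape_`.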
---
.../ci_logs/resnet-18-NHWC-B1-cuda.json | 50 +++++++++++-----------
python/tvm/auto_scheduler/measure.py | 17 ++++++--
python/tvm/auto_scheduler/relay_integration.py | 5 ++-
.../backend/contrib/ethosu/tir_to_cs_translator.py | 2 +-
src/relay/backend/te_compiler_cache.cc | 9 ++--
5 files changed, 48 insertions(+), 35 deletions(-)
diff --git a/gallery/how_to/tune_with_autoscheduler/ci_logs/resnet-18-NHWC-B1-cuda.json b/gallery/how_to/tune_with_autoscheduler/ci_logs/resnet-18-NHWC-B1-cuda.json
index 7cb3a67067..c8b9f41a5c 100644
--- a/gallery/how_to/tune_with_autoscheduler/ci_logs/resnet-18-NHWC-B1-cuda.json
+++ b/gallery/how_to/tune_with_autoscheduler/ci_logs/resnet-18-NHWC-B1-cuda.json
@@ -1,26 +1,24 @@
-# Provide valid schedules for resnet-18 on GPU.
-# This is used to run the tutorial on the documentation web server.
-{"i": [["[\"d7b65649a4dd54becea0a52aabbc5af5\", 1, 1000, 1, 1000]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["SP", 4, 1, 1000, [40], 1], ["AN", 4, 2, 6], ["FSP", 3, 1, 0, 1], ["AN", 3, 2, 6], ["CA", 3, 4, 0], ["CI", 2], ["FSP", 1, 1, 0, 1], ["AN", 1, 2, 6], ["CA", 1, 4, 0], ["AN", 4, 0, 5], ["PR", 1, 0, "auto_unroll_max_step$512"], ["PR", 3, 0, "auto_unroll_max_step$512"]]]], "r": [[4.87396e-06], 0, 1.3 [...]
-{"i": [["[\"9847f8cc0b305137f49f2c5c0c8ab25d\", 1, 512, 1000, 512, 1000, 1, 1000]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["SP", 2, 0, 1, [1, 1, 1, 1], 1], ["SP", 2, 5, 1000, [1, 50, 1, 1], 1], ["SP", 2, 10, 512, [1, 16], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 4, 0, 0, 3], ["FSP", 4, 4, 1, 3], ["RE", 4, [0, 4, 1, 5, 2, 6, 3, 7]], ["CA", 2, 4, 5], ["CHR", 1, "shared", [2]], [...]
-{"i": [["[\"69115f188984ae34ede37c3b8ca40b43\", 1, 7, 7, 512, 1, 1, 1, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 512, [2], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["FU", 1, [0, 1, 2, 3]], ["SP", 1, 0, 512, [32], 1], ["AN", 1, 0, 5], ["AN", 1, 1, 6], ["PR", 1, 0, "auto_unroll_max_step$64"]]]], "r": [[3.91068e-06], 0, 1.63708, 1606984742], "v": "v0.5"}
-{"i": [["[\"ad6cecbf5d85cb1cda3c2bb7af170211\", 1, 7, 7, 512, 4, 4, 512, 512, 1, 7, 7, 512, 1, 1, 1, 512, 1, 1, 1, 512, 1, 7, 7, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 15], ["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [2], 1], ["SP", 8, 4, 512, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "l [...]
-{"i": [["[\"3a69f9fbc63760d99e36b4c17b3bfc57\", 1, 7, 7, 512, 4, 4, 512, 512, 1, 1, 1, 512, 1, 7, 7, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [4], 1], ["SP", 8, 4, 512, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 2], 1], ["SP", 6, 5 [...]
-{"i": [["[\"d730bcd28f0920f6b97245e2a11bd8d6\", 1, 7, 7, 512, 4, 4, 512, 512, 1, 7, 7, 512, 1, 7, 7, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [1], 1], ["SP", 8, 4, 512, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 2], 1], ["SP", 6, 5, 4, [1, 1, [...]
-{"i": [["[\"f3b6c10fcc6ce01ff01add933e4d21e9\", 1, 14, 14, 256, 4, 4, 256, 256, 1, 14, 14, 256, 1, 1, 1, 256, 1, 14, 14, 256]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, [...]
-{"i": [["[\"b8b52b9be9df6102466a22a014c44c1f\", 1, 14, 14, 256, 4, 4, 256, 256, 1, 1, 1, 256, 1, 14, 14, 256]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", [...]
-{"i": [["[\"d374e472bd9d8164892b9e28a0a8cb59\", 1, 14, 14, 256, 4, 4, 256, 256, 1, 14, 14, 256, 1, 14, 14, 256]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [7], 1], ["SP", 8, 4, 256, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 4, 1, 1], 1], ["SP", 6, 5, 4, [ [...]
-{"i": [["[\"12b88bedece6984af589a28b43e0f3c4\", 1, 28, 28, 128, 3, 3, 128, 256, 1, 1, 1, 256, 1, 14, 14, 256]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [2, 7, 1, 1], 1], ["SP", 3, 10, 14, [1, 7, 2, 1], 1], ["SP", 3, 15, 256, [2, 2, 1, 4], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [1, 3], 1], ["SP", 3, 26, 128, [4, 1], 1], ["RE", 3, [0 [...]
-{"i": [["[\"c4500b4e2fd04e695c32d2f31bbdc14a\", 1, 28, 28, 128, 4, 4, 128, 128, 1, 28, 28, 128, 1, 1, 1, 128, 1, 28, 28, 128]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [14], 1], ["SP", 8, 4, 128, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0 [...]
-{"i": [["[\"e4cdf917b876dbdd64488c3818d9c141\", 1, 28, 28, 128, 4, 4, 128, 128, 1, 1, 1, 128, 1, 28, 28, 128]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 4], 1], ["SP", [...]
-{"i": [["[\"dac19035dd5fe9424ee8617421b9c817\", 1, 28, 28, 128, 4, 4, 128, 128, 1, 28, 28, 128, 1, 28, 28, 128]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [...]
-{"i": [["[\"12b88bedece6984af589a28b43e0f3c4\", 1, 56, 56, 64, 3, 3, 64, 128, 1, 1, 1, 128, 1, 28, 28, 128]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 1, 2, 1], 1], ["SP", 3, 10, 28, [1, 7, 2, 2], 1], ["SP", 3, 15, 128, [1, 8, 8, 1], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [1, 1], 1], ["SP", 3, 26, 64, [4, 2], 1], ["RE", 3, [0, 5 [...]
-{"i": [["[\"1e3c4211ffd2f2db91078ae4d04b779d\", 1, 56, 56, 64, 6, 6, 64, 64, 1, 56, 56, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 64, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, [...]
-{"i": [["[\"b818b53148cd450f86569dfc3e04cb8a\", 1, 56, 56, 64, 6, 6, 64, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 64, [64], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 3, 2, 1], 1], ["SP", 6, 5 [...]
-{"i": [["[\"3ea73fb9b0364374730d09e068821f95\", 1, 56, 56, 64, 6, 6, 64, 64, 1, 56, 56, 64, 1, 56, 56, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [49], 1], ["SP", 8, 4, 64, [2], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 2, 1, 3], 1], ["SP", 6, 5, 6, [1, 3 [...]
-{"i": [["[\"a5612fdeb9db4d579a75ec225ea4c06a\", 1, 112, 112, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 4], ["CI", 1], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 200704, [64], 1], ["AN", 5, 0, 5], ["AN", 5, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 200704, [64], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["PR", 2, 0, "auto_unroll_max_step$16"]]]], "r": [[2.00968e-05], 0, 1 [...]
-{"i": [["[\"12b88bedece6984af589a28b43e0f3c4\", 1, 224, 224, 3, 7, 7, 3, 64, 1, 1, 1, 64, 1, 112, 112, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 112, [1, 2, 7, 1], 1], ["SP", 3, 10, 112, [1, 7, 1, 1], 1], ["SP", 3, 15, 64, [1, 8, 4, 1], 1], ["SP", 3, 20, 7, [7, 1], 1], ["SP", 3, 23, 7, [1, 7], 1], ["SP", 3, 26, 3, [3, 1], 1], ["RE", 3, [0, 5, [...]
-{"i": [["[\"7006235cfc29b73be524cf390ed5a977\", 1, 56, 56, 64, 1, 1, 64, 64, 1, 56, 56, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 56, [1, 2, 2, 2], 1], ["SP", 3, 10, 56, [1, 7, 1, 2], 1], ["SP", 3, 15, 64, [1, 16, 1, 4], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [2, 8], 1], ["RE", 3, [0, 5, 10, [...]
-{"i": [["[\"f4380bb1dc62422a69ad4a1a9771f927\", 1, 56, 56, 64, 1, 1, 64, 128, 1, 28, 28, 128]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 7, 1, 1], 1], ["SP", 3, 10, 28, [1, 2, 1, 7], 1], ["SP", 3, 15, 128, [8, 8, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [2, 2], 1], ["RE", 3, [0, 5, 10 [...]
-{"i": [["[\"f4380bb1dc62422a69ad4a1a9771f927\", 1, 28, 28, 128, 1, 1, 128, 256, 1, 14, 14, 256]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [1, 1, 1, 1], 1], ["SP", 3, 10, 14, [2, 1, 7, 1], 1], ["SP", 3, 15, 256, [2, 64, 1, 2], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 128, [1, 2], 1], ["RE", 3, [0, 5 [...]
-{"i": [["[\"f4380bb1dc62422a69ad4a1a9771f927\", 1, 14, 14, 256, 1, 1, 256, 512, 1, 7, 7, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [1, 1, 7, 1], 1], ["SP", 3, 10, 7, [1, 1, 1, 1], 1], ["SP", 3, 15, 512, [4, 128, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 256, [1, 16], 1], ["RE", 3, [0, 5, [...]
-{"i": [["[\"12b88bedece6984af589a28b43e0f3c4\", 1, 14, 14, 256, 3, 3, 256, 512, 1, 1, 1, 512, 1, 7, 7, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [7, 1, 1, 1], 1], ["SP", 3, 10, 7, [1, 7, 1, 1], 1], ["SP", 3, 15, 512, [1, 4, 1, 1], 1], ["SP", 3, 20, 3, [1, 3], 1], ["SP", 3, 23, 3, [1, 3], 1], ["SP", 3, 26, 256, [8, 2], 1], ["RE", 3, [0, 5, [...]
+{"i": [["[\"f19692ed81d032b1697c08adee62f9a5\", [1, 28, 28, 128], [4, 4, 128, 128], [1, 28, 28, 128], [1, 1, 1, 128], [1, 28, 28, 128]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [49], 1], ["SP", 8, 4, 128, [32], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW [...]
+{"i": [["[\"2d10de6646307f0e3e5cf4b31c20e69b\", [1, 56, 56, 64], [1, 1, 64, 64], [1, 56, 56, 64]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 56, [1, 2, 2, 1], 1], ["SP", 3, 10, 56, [1, 8, 1, 7], 1], ["SP", 3, 15, 64, [1, 16, 4, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [32, 1], 1] [...]
+{"i": [["[\"a3df19e5b88592ef5a9ce584a1ca3010\", [1, 7, 7, 512], [4, 4, 512, 512], [1, 7, 7, 512], [1, 1, 1, 512], [1, 1, 1, 512], [1, 7, 7, 512]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 15], ["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [1], 1], ["SP", 8, 4, 512, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7] [...]
+{"i": [["[\"0fad1b42d0d33418e0a8d15d3bbad3c9\", [1, 56, 56, 64], [1, 1, 64, 128], [1, 28, 28, 128]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 7, 2, 2], 1], ["SP", 3, 10, 28, [2, 7, 1, 2], 1], ["SP", 3, 15, 128, [2, 8, 4, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [1, 1], 1 [...]
+{"i": [["[\"0bcf718c0e6566bcd6c3b1437a3b6291\", [1, 28, 28, 128], [4, 4, 128, 128], [1, 1, 1, 128], [1, 28, 28, 128]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [14], 1], ["SP", 8, 4, 128, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [...]
+{"i": [["[\"1097323f3970e5c881ad3a0028ca79cb\", [1, 14, 14, 256], [4, 4, 256, 256], [1, 1, 1, 256], [1, 14, 14, 256]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [ [...]
+{"i": [["[\"d78e8eb6021c4cdda0ad7775d10f751a\", [1, 7, 7, 512], [4, 4, 512, 512], [1, 7, 7, 512], [1, 7, 7, 512]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [4], 1], ["SP", 8, 4, 512, [32], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [2, 2, 1, 1], 1] [...]
+{"i": [["[\"7c2a4f1f432f81c44985590780dfb52d\", [1, 56, 56, 64], [6, 6, 64, 64], [1, 1, 1, 64], [1, 56, 56, 64]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [49], 1], ["SP", 8, 4, 64, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, [...]
+{"i": [["[\"64b7ce5264a64cb340d78b444b0325e6\", [1, 14, 14, 256], [4, 4, 256, 256], [1, 14, 14, 256], [1, 1, 1, 256], [1, 14, 14, 256]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [64], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", [...]
+{"i": [["[\"be3babb9a46e32f66b717a3e2a2d522c\", [1, 7, 7, 512], [1, 1, 1, 512]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 512, [64], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["FU", 1, [0, 1, 2, 3]], ["SP", 1, 0, 512, [32], 1], ["AN", 1, 0, 5], ["AN", 1, 1, 6], ["PR", 1, 0, "auto_unroll_max_step$64"]]]], "r": [[3.49558e-06], 0, 0.880265, 1650980753], "v" [...]
+{"i": [["[\"7d79c516e212fe1d73f5dbb90eaca2cf\", [1, 1000], [1, 1000]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["SP", 4, 1, 1000, [20], 1], ["AN", 4, 2, 6], ["FSP", 3, 1, 0, 1], ["AN", 3, 2, 6], ["CA", 3, 4, 0], ["CI", 2], ["AN", 4, 0, 5], ["AN", 1, 0, 6], ["PR", 1, 0, "auto_unroll_max_step$0"], ["PR", 3, 0, "auto_unroll_max_step$16"]]]], "r": [[1.66218e-05], 0, 1.00389, 1650980756], "v [...]
+{"i": [["[\"40b1cf1fd37b0ef111b3cc0247302508\", [1, 7, 7, 512], [4, 4, 512, 512], [1, 1, 1, 512], [1, 7, 7, 512]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [8], 1], ["SP", 8, 4, 512, [64], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, [...]
+{"i": [["[\"0fad1b42d0d33418e0a8d15d3bbad3c9\", [1, 28, 28, 128], [1, 1, 128, 256], [1, 14, 14, 256]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [1, 2, 1, 1], 1], ["SP", 3, 10, 14, [1, 1, 2, 1], 1], ["SP", 3, 15, 256, [4, 8, 1, 4], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 128, [2, 4] [...]
+{"i": [["[\"25577781e50c611c2e45e73c1cb3a6ca\", [1, 28, 28, 128], [4, 4, 128, 128], [1, 28, 28, 128], [1, 28, 28, 128]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [7], 1], ["SP", 8, 4, 128, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, [...]
+{"i": [["[\"07f9fcad27bdd3233f86fe35a5185d33\", [1, 28, 28, 128], [3, 3, 128, 256], [1, 1, 1, 256], [1, 14, 14, 256]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [1, 1, 2, 1], 1], ["SP", 3, 10, 14, [1, 1, 2, 7], 1], ["SP", 3, 15, 256, [1, 16, 1, 1], 1], ["SP", 3, 20, 3, [1, 1], 1], ["SP", 3, 23, 3, [1, 1], 1], ["SP", 3, 26, 128, [...]
+{"i": [["[\"07f9fcad27bdd3233f86fe35a5185d33\", [1, 14, 14, 256], [3, 3, 256, 512], [1, 1, 1, 512], [1, 7, 7, 512]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [1, 7, 1, 1], 1], ["SP", 3, 10, 7, [7, 1, 1, 1], 1], ["SP", 3, 15, 512, [1, 16, 2, 8], 1], ["SP", 3, 20, 3, [1, 3], 1], ["SP", 3, 23, 3, [1, 1], 1], ["SP", 3, 26, 256, [1, [...]
+{"i": [["[\"6c4f6234946e16bcf9e48bdf289f9200\", [1, 56, 56, 64], [6, 6, 64, 64], [1, 56, 56, 64], [1, 1, 1, 64], [1, 56, 56, 64]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [14], 1], ["SP", 8, 4, 64, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, " [...]
+{"i": [["[\"07f9fcad27bdd3233f86fe35a5185d33\", [1, 224, 224, 3], [7, 7, 3, 64], [1, 1, 1, 64], [1, 112, 112, 64]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 112, [1, 2, 28, 1], 1], ["SP", 3, 10, 112, [7, 1, 1, 1], 1], ["SP", 3, 15, 64, [1, 32, 1, 1], 1], ["SP", 3, 20, 7, [1, 7], 1], ["SP", 3, 23, 7, [7, 1], 1], ["SP", 3, 26, 3, [1 [...]
+{"i": [["[\"10b8215aaf2e14d47d40b4093e6f41a0\", [1, 56, 56, 64], [6, 6, 64, 64], [1, 56, 56, 64], [1, 56, 56, 64]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [7], 1], ["SP", 8, 4, 64, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 1, 1, 1], 1] [...]
+{"i": [["[\"7f3fee61bc3c2604395f5d343b840b7c\", [1, 14, 14, 256], [4, 4, 256, 256], [1, 14, 14, 256], [1, 14, 14, 256]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [49], 1], ["SP", 8, 4, 256, [2], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, [...]
+{"i": [["[\"0fad1b42d0d33418e0a8d15d3bbad3c9\", [1, 14, 14, 256], [1, 1, 256, 512], [1, 7, 7, 512]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [1, 1, 1, 1], 1], ["SP", 3, 10, 7, [7, 1, 1, 1], 1], ["SP", 3, 15, 512, [2, 128, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 256, [8, 4], [...]
+{"i": [["[\"affd3c4a65f665e451a06d65bf32750d\", [1, 112, 112, 64], [1, 1, 1, 64], [1, 56, 56, 64]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 4], ["CI", 1], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 200704, [1], 1], ["AN", 5, 0, 5], ["AN", 5, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 200704, [2], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["PR", 2, 0, "auto_unroll_max_step$1024"]]]], "r" [...]
+{"i": [["[\"00a059b856ac30ac172b6252254479a6\", [1, 512], [1000, 512], [1, 1000], [1, 1000]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["SP", 2, 0, 1, [1, 1, 1, 1], 1], ["SP", 2, 5, 1000, [4, 50, 1, 1], 1], ["SP", 2, 10, 512, [2, 4], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 4, 0, 0, 3], ["FSP", 4, 4, 1, 3], ["RE", 4, [0, 4, 1, 5, 2, 6, 3, 7]], ["CA", 2, 4, 5], [ [...]
+{"i": [["[\"07f9fcad27bdd3233f86fe35a5185d33\", [1, 56, 56, 64], [3, 3, 64, 128], [1, 1, 1, 128], [1, 28, 28, 128]]", "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0, []], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 2, 7, 2], 1], ["SP", 3, 10, 28, [2, 7, 1, 2], 1], ["SP", 3, 15, 128, [1, 8, 8, 1], 1], ["SP", 3, 20, 3, [1, 3], 1], ["SP", 3, 23, 3, [3, 1], 1], ["SP", 3, 26, 64, [1, [...]
diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py
index 6f331499b0..e59e78f571 100644
--- a/python/tvm/auto_scheduler/measure.py
+++ b/python/tvm/auto_scheduler/measure.py
@@ -780,7 +780,7 @@ def register_task_input_check_func(func_name, f=None, override=False):
return register
-def prepare_input_map(args):
+def prepare_input_map(args, workload_key=None):
"""This function deals with special task inputs. Map the input Tensor of a TVM subgraph
to a specific buffer name in the global buffer map.
@@ -789,6 +789,11 @@ def prepare_input_map(args):
args : List[Tensor]
Input/output Tensor of a TVM subgraph.
+ workload_key : Optional[str]
+ The workload for which these inputs are being prepared. This
+ is used to check whether an input is provided by a registered
+ task input buffer (see `register_task_input_buffer`).
+
Returns
-------
Dict[Tensor, str] :
@@ -803,13 +808,19 @@ def prepare_input_map(args):
global TASK_INPUT_CHECK_FUNC_REGISTRY
+ from .search_task import TASK_INPUT_BUFFER_TABLE
+
# A dict that maps the input tensor arg to a buffer name
tensor_input_map = {}
# Case 0: Check placeholder name
for arg in args:
if isinstance(arg.op, tvm.te.PlaceholderOp):
- if arg.op.name != "placeholder":
+ if (
+ workload_key
+ and workload_key in TASK_INPUT_BUFFER_TABLE
+ and arg.op.name in TASK_INPUT_BUFFER_TABLE[workload_key]
+ ):
tensor_input_map[arg] = arg.op.name
# Case 1: Check specific tensor inputs
@@ -843,7 +854,7 @@ def prepare_runner_args(inp, build_res):
from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency
task_input_names = inp.task.task_input_names
- tensor_input_map = prepare_input_map(build_res.args)
+ tensor_input_map = prepare_input_map(build_res.args, inp.task.workload_key)
if not task_input_names:
tensor_input_map = {}
args = []
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index 9541232a6a..52c7f44fce 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -336,7 +336,8 @@ def auto_schedule_topi(func_name, outs):
logger.info("Failed to create a ComputeDAG for auto_scheduler: %s", str(err))
return None
- key = register_workload_tensors(dag.workload_key(), io_tensors)
+ workload_key = dag.workload_key()
+ key = register_workload_tensors(workload_key, io_tensors)
target = tvm.target.Target.current()
dispatch_ctx = DispatchContext.current
@@ -356,7 +357,7 @@ def auto_schedule_topi(func_name, outs):
# in the task extraction mode
if has_complex_op or env.tracing_mode == TracingMode.EXTRACT_TASK:
env.add_workload_key(func_name, key)
- input_map = prepare_input_map(io_tensors)
+ input_map = prepare_input_map(io_tensors, workload_key)
if input_map:
env.add_workload_input_names(key, list(input_map.values()))
elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
index a3d46170df..f5c8994bec 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
@@ -268,7 +268,7 @@ def extract_param_base_addresses(mod, buffer_info, scratch_region_map) -> List[u
size_bytes = element_size_bytes * np.prod(list(buffer.shape))
base_addresses.append(
util.BaseAddress(
- param.name,
+ param.name.replace("-", "_"),
idx,
_get_region(buffer_info[param].btype, param, scratch_region_map),
size_bytes,
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 1d7566ebe2..a8eb6a5810 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -131,7 +131,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor
for (Var param : relay_func->params) {
Array<tvm::te::Tensor> inputs;
for (const auto& ttype : FlattenTupleType(param->checked_type())) {
- tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+ tvm::te::Tensor tensor =
+ tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype, param->vid->name_hint);
inputs.push_back(tensor);
fn_inputs_.push_back(tensor);
}
@@ -478,7 +479,8 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
for (const auto& ttype : FlattenTupleType(param->checked_type())) {
// Add data placeholder (in case we discover we need it below)
Shape shape = GetShape(ttype->shape);
- tvm::te::Tensor data_tensor = tvm::te::placeholder(shape, ttype->dtype);
+ tvm::te::Tensor data_tensor =
+ tvm::te::placeholder(shape, ttype->dtype, "data_" + param->vid->name_hint);
data_inputs.push_back(data_tensor);
// Add shape placeholder (in case we discover we need it below)
int64_t ndim = shape.size();
@@ -486,7 +488,8 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
if (ndim > 0) {
sshape.push_back(tvm::Integer(ndim));
}
- tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape, DataType::Int(64));
+ tvm::te::Tensor shape_tensor =
+ tvm::te::placeholder(sshape, DataType::Int(64), "shape_" + param->vid->name_hint);
shape_inputs.push_back(shape_tensor);
}
param_data_[param] = data_inputs;