Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2023/01/08 12:45:10 UTC

[GitHub] [tvm] yzh119 opened a new pull request, #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

yzh119 opened a new pull request, #13726:
URL: https://github.com/apache/tvm/pull/13726

   # Motivation
   Currently, our default profiler (`time_evaluator`) does not flush the L2 cache between executions. This can distort time measurements: input data from the previous run may still reside in the L2 cache and shorten data-fetch time in the next run. Both [Triton](https://github.com/openai/triton/blob/ff399fbc2059a6f35cb93534dc29398f7b82dbc7/python/triton/testing.py#L156-L181) and [nvbench](https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/measure_cold.cuh#L123) account for this effect and therefore report more accurate measurements.
   
   # Solution
   `time_evaluator` has an argument `f_preproc` through which users can specify a preprocessing function to run per execution of the kernel being evaluated. TVM currently provides `cache_flush_cpu_non_first_arg`, which flushes the CPU cache, but equivalent functionality for GPUs is missing.
   
   This PR borrows the design of nvbench's [l2flush](https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh) struct entirely and allows users to specify `"l2_cache_flush_cuda"` as a preprocessing function that flushes the L2 cache of NVIDIA GPUs.
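As I understand nvbench's `l2flush`, it writes to a device buffer as large as the L2 cache before each measurement. That write evicts whatever the previous run left in L2. The toy model below is pure Python, and its names `LRUCache`, `run_kernel`, and `flush` are hypothetical. It assumes a fully associative LRU cache, which a real GPU L2 is not. The model only illustrates two things: why a second unflushed run looks "too fast", and why an L2-sized write restores cold-cache behavior.

```python
from collections import OrderedDict


class LRUCache:
    """Toy fully associative LRU cache that records which addresses are resident.

    Real GPU L2 caches are set-associative with hardware-specific replacement
    policies; this model only illustrates why an L2-sized write evicts old data.
    """

    def __init__(self, capacity_lines: int) -> None:
        self.capacity = capacity_lines
        self.lines: "OrderedDict[int, None]" = OrderedDict()

    def access(self, addr: int) -> bool:
        """Touch one cache line; return True on a hit, False on a miss."""
        if addr in self.lines:
            self.lines.move_to_end(addr)  # mark as most recently used
            return True
        self.lines[addr] = None
        if len(self.lines) > self.capacity:
            self.lines.popitem(last=False)  # evict the least recently used line
        return False


def run_kernel(cache: LRUCache, inputs) -> int:
    """Read every input line once; return the number of cache hits."""
    return sum(cache.access(a) for a in inputs)


def flush(cache: LRUCache, flush_base: int) -> None:
    """Write a buffer exactly as large as the cache, evicting every older line."""
    for offset in range(cache.capacity):
        cache.access(flush_base + offset)


if __name__ == "__main__":
    cache = LRUCache(capacity_lines=64)
    inputs = range(32)
    print(run_kernel(cache, inputs))  # cold run: 0 hits
    print(run_kernel(cache, inputs))  # warm run: 32 hits
    flush(cache, flush_base=1000)
    print(run_kernel(cache, inputs))  # after the flush: 0 hits again
```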
   
   Note that this PR also changes where `f_preproc` is triggered. Previously, `f_preproc` was triggered once per repeat. That doesn't seem correct to me, because most users specify `repeat=1`, and `f_preproc` needs to be triggered once per run.
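Below is a minimal pure-Python sketch of the loop structure this PR proposes. `time_evaluator_sketch` is a hypothetical name, not TVM's API. The sketch assumes `func` runs synchronously; TVM itself uses the device-aware `Timer::Start(dev)`, because CUDA launches are asynchronous. `f_preproc` runs before every call inside the `number` loop. Only the kernel's own time is accumulated, mirroring the `accum_t_nanos` accumulator in the `profiling.cc` change.

```python
import time


def time_evaluator_sketch(func, args, repeat, number, f_preproc=None):
    """Hypothetical model of the proposed loop (not TVM's implementation).

    f_preproc runs before every call of func, and only func's own execution
    time is accumulated, so the preprocessing cost stays out of the result.
    Returns the mean seconds per call for each of the `repeat` measurements.
    """
    func(*args)  # one untimed call, mirroring TVM's "skip first time call"
    results = []
    for _ in range(repeat):
        accum_ns = 0
        for _ in range(number):
            if f_preproc is not None:
                f_preproc(*args)  # e.g. flush the L2 cache; not timed
            start = time.perf_counter_ns()
            func(*args)
            accum_ns += time.perf_counter_ns() - start
        results.append(accum_ns / number / 1e9)
    return results
```

Timing each call separately is what keeps the flush out of the measurement. The cost is that every call now relies on the timer's resolution, which is the trade-off raised later in the review.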
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] echuraev commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
echuraev commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064305569


##########
tests/python/unittest/test_evaluator_flush_l2_cache.py:
##########
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import te
+from tvm.script import tir as T
+import tvm.testing
+import numpy as np
+
+
+@T.prim_func
+def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128])
+    B = T.match_buffer(b, [128, 128])
+    C = T.match_buffer(c, [128, 128])
+    for i, j, k in T.grid(128, 128, 128):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                C[vi, vj] = 0.0
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+
+@tvm.testing.requires_cuda
+def test_evaluator_flush_l2_cache():
+    mod = tvm.IRModule.from_expr(matmul)
+    sch = tvm.tir.Schedule(mod)
+    blk = sch.get_block("matmul")
+    i, j, k = sch.get_loops(blk)
+    sch.bind(i, "blockIdx.x")
+    sch.bind(j, "threadIdx.x")
+    f = tvm.build(sch.mod["main"], target="cuda")
+    dev = tvm.cuda(0)
+    evaluator_no_flush = f.time_evaluator(f.entry_name, dev, repeat=1000, number=1)
+
+    a = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
+    b = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
+    c = tvm.nd.array(np.zeros((128, 128)).astype("float32"), device=dev)
+    args = [a, b, c]
+    print("Evaluator (w/o L2 flush):\t{:.5f}ms".format(evaluator_no_flush(*args).mean * 1000))
+
+    evaluator_with_flush = f.time_evaluator(
+        f.entry_name, dev, repeat=1000, number=1, f_preproc="l2_cache_flush_cuda"
+    )
+    print("Evaluator (w/ L2 flush):\t{:.5f}ms".format(evaluator_with_flush(*args).mean * 1000))

Review Comment:
   Are there any criteria under which the test should fail?



##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+// Acknowledgement: l2flush struct in nvbench project.
+// Reference:
+// https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "cuda_common.h"
+
+namespace tvm {
+
+namespace runtime {
+
+class L2Flush {
+ public:
+  L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
+
+  ~L2Flush() {
+    if (l2_size_ > 0) {

Review Comment:
   Do we really need `l2_size_`? Perhaps we can simply check that `l2_buffer_` is not a `nullptr`.




[GitHub] [tvm] Icemist commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
Icemist commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064645650


##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+// Acknowledgement: l2flush struct in nvbench project.
+// Reference:
+// https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "cuda_common.h"
+
+namespace tvm {
+
+namespace runtime {
+
+class L2Flush {
+ public:
+  L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
+
+  ~L2Flush() {
+    if (l2_size_ > 0) {
+      CUDA_CALL(cudaFree(l2_buffer_));
+    }
+  }
+
+  void Flush() {
+    if (!initialized_) {

Review Comment:
   We expect that after `cudaMalloc` the variable `l2_buffer_` will not be `nullptr`, right? Then can we use that check instead of `initialized_`?



##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+// Acknowledgement: l2flush struct in nvbench project.
+// Reference:
+// https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "cuda_common.h"
+
+namespace tvm {
+
+namespace runtime {
+
+class L2Flush {
+ public:
+  L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
+
+  ~L2Flush() {
+    if (l2_size_ > 0) {
+      CUDA_CALL(cudaFree(l2_buffer_));
+    }
+  }
+
+  void Flush() {
+    if (!initialized_) {
+      // initialize l2_buffer_ and l2_size_
+      initialized_ = true;
+      int device_id;
+      CUDA_CALL(cudaGetDevice(&device_id));
+      CUDA_CALL(cudaDeviceGetAttribute(&l2_size_, cudaDevAttrL2CacheSize, device_id));
+      if (l2_size_ > 0) {
+        void* buffer = l2_buffer_;

Review Comment:
   I'm not very familiar with CUDA, but can't we do without the temporary `buffer` variable? 
   `CUDA_CALL(cudaMalloc((void**)&l2_buffer_, l2_size_))`




[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064146163


##########
src/runtime/profiling.cc:
##########
@@ -894,15 +891,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
               std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
         }
 
+        int64_t accum_t_nanos = 0;
         // start timing
-        Timer t = Timer::Start(dev);
         for (int j = 0; j < number; ++j) {
+          // call preprocessing function
+          if (f_preproc != nullptr) {
+            f_preproc.CallPacked(args, &temp);

Review Comment:
   Another concern is whether `time_evaluator` should have a `warmup` argument; I think this is common in most frameworks' profiling tools.
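For illustration, a `warmup` knob might look like the sketch below. `measure_with_warmup` and its `warmup` parameter are hypothetical. TVM's `time_evaluator` has no such argument; it makes a single untimed call. Host wall-clock timing with `time.perf_counter` assumes `func` is synchronous.

```python
import statistics
import time


def measure_with_warmup(func, warmup=3, repeat=10):
    """Hypothetical helper: run `warmup` untimed calls, then time `repeat` calls.

    Returns (mean_seconds, samples). Warm-up results are discarded so that
    lazy initialization, JIT compilation, and cold caches do not skew samples.
    """
    if warmup < 0 or repeat < 1:
        raise ValueError("warmup must be >= 0 and repeat must be >= 1")
    for _ in range(warmup):
        func()  # untimed
    samples = []
    for _ in range(repeat):
        start = time.perf_counter()
        func()
        samples.append(time.perf_counter() - start)
    return statistics.mean(samples), samples
```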




[GitHub] [tvm] tkonolige commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064968656


##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,79 @@
+/*

Review Comment:
   We don't like to include code from other projects in TVM unless it lives in the `3rdparty` directory. As this code is basically the same as https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh, it should be moved into `3rdparty/nvbench` and the license header at the top should be changed to the one in the nvbench repository.
   
   @areusch is this acceptable?




[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064172973


##########
src/runtime/profiling.cc:
##########
@@ -894,15 +891,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
               std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
         }
 
+        int64_t accum_t_nanos = 0;
         // start timing
-        Timer t = Timer::Start(dev);
         for (int j = 0; j < number; ++j) {
+          // call preprocessing function
+          if (f_preproc != nullptr) {
+            f_preproc.CallPacked(args, &temp);

Review Comment:
   I also noticed that `ProfileFunction`, defined in `profiling.cc`, has a `warmup_iters` argument, which confuses me: in what cases should a user choose `ProfileFunction` over `WrapTimeEvaluator`?




[GitHub] [tvm] tkonolige merged pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
tkonolige merged PR #13726:
URL: https://github.com/apache/tvm/pull/13726



[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064590379


##########
src/runtime/profiling.cc:
##########
@@ -894,15 +891,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
               std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
         }
 
+        int64_t accum_t_nanos = 0;
         // start timing
-        Timer t = Timer::Start(dev);
         for (int j = 0; j < number; ++j) {
+          // call preprocessing function
+          if (f_preproc != nullptr) {
+            f_preproc.CallPacked(args, &temp);

Review Comment:
   Got it; I'd prefer to add the `warmup` argument in a separate PR. Thank you for the discussion! I'll mark this conversation as resolved.




[GitHub] [tvm] Icemist commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
Icemist commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064162325


##########
src/runtime/profiling.cc:
##########
@@ -894,15 +891,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
               std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
         }
 
+        int64_t accum_t_nanos = 0;
         // start timing
-        Timer t = Timer::Start(dev);
         for (int j = 0; j < number; ++j) {
+          // call preprocessing function
+          if (f_preproc != nullptr) {
+            f_preproc.CallPacked(args, &temp);

Review Comment:
   > Another concern is whether `time_evaluator` should have a `warmup` argument; I think this is common in most frameworks' profiling tools.
   
   There is already a call to the function before the measurement logic starts. Do you think that one call is not enough, or should the warm-up use something else?
   ```
   // skip first time call, to activate lazy compilation components.
   pf.CallPacked(args, &temp);
   ```
   




[GitHub] [tvm] Icemist commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
Icemist commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064141278


##########
src/runtime/profiling.cc:
##########
@@ -894,15 +891,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
               std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
         }
 
+        int64_t accum_t_nanos = 0;
         // start timing
-        Timer t = Timer::Start(dev);
         for (int j = 0; j < number; ++j) {
+          // call preprocessing function
+          if (f_preproc != nullptr) {
+            f_preproc.CallPacked(args, &temp);

Review Comment:
   > Thanks for the clarification, I'll fix it soon. Unfortunately it looks like there are a lot of misuse of `number` in the codebase, maybe we should create another PR fixing them.
   
   I don't think it matters much if we are satisfied with the averaged result and don't need to call `f_preproc`. So don't worry too much about it.
   




[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064468131


##########
tests/python/unittest/test_evaluator_flush_l2_cache.py:
##########
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import te
+from tvm.script import tir as T
+import tvm.testing
+import numpy as np
+
+
+@T.prim_func
+def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128])
+    B = T.match_buffer(b, [128, 128])
+    C = T.match_buffer(c, [128, 128])
+    for i, j, k in T.grid(128, 128, 128):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                C[vi, vj] = 0.0
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+
+@tvm.testing.requires_cuda
+def test_evaluator_flush_l2_cache():
+    mod = tvm.IRModule.from_expr(matmul)
+    sch = tvm.tir.Schedule(mod)
+    blk = sch.get_block("matmul")
+    i, j, k = sch.get_loops(blk)
+    sch.bind(i, "blockIdx.x")
+    sch.bind(j, "threadIdx.x")
+    f = tvm.build(sch.mod["main"], target="cuda")
+    dev = tvm.cuda(0)
+    evaluator_no_flush = f.time_evaluator(f.entry_name, dev, repeat=1000, number=1)
+
+    a = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
+    b = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
+    c = tvm.nd.array(np.zeros((128, 128)).astype("float32"), device=dev)
+    args = [a, b, c]
+    print("Evaluator (w/o L2 flush):\t{:.5f}ms".format(evaluator_no_flush(*args).mean * 1000))
+
+    evaluator_with_flush = f.time_evaluator(
+        f.entry_name, dev, repeat=1000, number=1, f_preproc="l2_cache_flush_cuda"
+    )
+    print("Evaluator (w/ L2 flush):\t{:.5f}ms".format(evaluator_with_flush(*args).mean * 1000))

Review Comment:
   There isn't, but I'd love to have one. Any ideas?
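One possible criterion is sketched below as a hypothetical helper. Flushing should only remove cached data, so the flushed mean should not come out noticeably *lower* than the unflushed one. In the test, the inputs would be `evaluator_no_flush(*args).mean` and `evaluator_with_flush(*args).mean`, converted to ms. The `rel_tol` default is an arbitrary assumption meant to absorb timer noise. This check cannot prove that a flush actually happened.

```python
def flush_not_faster(mean_no_flush_ms: float, mean_flush_ms: float, rel_tol: float = 0.05) -> bool:
    """Hypothetical pass/fail criterion for the L2-flush test.

    Returns True when the flushed mean is not more than `rel_tol` (relative)
    below the unflushed mean, i.e. flushing did not make the kernel look faster.
    """
    if mean_no_flush_ms <= 0 or mean_flush_ms <= 0:
        raise ValueError("means must be positive")
    return mean_flush_ms >= mean_no_flush_ms * (1.0 - rel_tol)
```

A strict "flushed must be slower" assertion would be simpler. However, for a kernel whose inputs barely benefit from L2 it would likely be flaky, so the tolerance-based check trades strength for stability.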




[GitHub] [tvm] Icemist commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
Icemist commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064497488


##########
tests/python/unittest/test_evaluator_flush_l2_cache.py:
##########
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import te
+from tvm.script import tir as T
+import tvm.testing
+import numpy as np
+
+
+@T.prim_func
+def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128])
+    B = T.match_buffer(b, [128, 128])
+    C = T.match_buffer(c, [128, 128])
+    for i, j, k in T.grid(128, 128, 128):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                C[vi, vj] = 0.0
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+
+@tvm.testing.requires_cuda
+def test_evaluator_flush_l2_cache():
+    mod = tvm.IRModule.from_expr(matmul)
+    sch = tvm.tir.Schedule(mod)
+    blk = sch.get_block("matmul")
+    i, j, k = sch.get_loops(blk)
+    sch.bind(i, "blockIdx.x")
+    sch.bind(j, "threadIdx.x")
+    f = tvm.build(sch.mod["main"], target="cuda")
+    dev = tvm.cuda(0)
+    evaluator_no_flush = f.time_evaluator(f.entry_name, dev, repeat=1000, number=1)
+
+    a = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
+    b = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
+    c = tvm.nd.array(np.zeros((128, 128)).astype("float32"), device=dev)
+    args = [a, b, c]
+    print("Evaluator (w/o L2 flush):\t{:.5f}ms".format(evaluator_no_flush(*args).mean * 1000))
+
+    evaluator_with_flush = f.time_evaluator(
+        f.entry_name, dev, repeat=1000, number=1, f_preproc="l2_cache_flush_cuda"
+    )
+    print("Evaluator (w/ L2 flush):\t{:.5f}ms".format(evaluator_with_flush(*args).mean * 1000))

Review Comment:
   > Are there any criteria under which the test should fail?
   
   This test checks for the presence of **l2_cache_flush_cuda**. It's a normal integration test :)
   
   I'm not sure we need the first run (`evaluator_no_flush`).
   In some tests, the baseline scenario is hidden behind an **if False** block;
   for example, see the `do_tune` block in `tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py::test_packed_8x8x32_resnet50`.
   Alternatively, use **@pytest.mark.parametrize** for **f_preproc**.




[GitHub] [tvm] Icemist commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
Icemist commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064139675


##########
src/runtime/profiling.cc:
##########
@@ -894,15 +891,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
               std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
         }
 
+        int64_t accum_t_nanos = 0;
         // start timing
-        Timer t = Timer::Start(dev);
         for (int j = 0; j < number; ++j) {
+          // call preprocessing function
+          if (f_preproc != nullptr) {
+            f_preproc.CallPacked(args, &temp);

Review Comment:
   From experience, this change needs to be reverted.
   The **repeat** and **number** parameters are designed for different scenarios. For your scenario, increase **repeat** and use **number=1**.
   Increasing **number** (or using **min_repeat_ms**) is meant for cases where single runs are hard to measure: "_The timer is specifically in the outer loop so that we can time functions with a very small runtime or use less precise timers._"
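The structure being described can be sketched in pure Python as follows. `time_evaluator_original_sketch` is a hypothetical name, and `time.perf_counter` stands in for TVM's device timer. One timer spans all `number` calls. As described in the PR, `f_preproc` runs once per repeat, outside the timed region. With `number=1` and a large `repeat`, `f_preproc` therefore already runs before every timed call.

```python
import time


def time_evaluator_original_sketch(func, repeat, number, f_preproc=None):
    """Hypothetical model of the pre-PR loop structure (not TVM's code).

    f_preproc runs once per repeat, and a single timer wraps all `number`
    calls, so very short functions or coarse timers still yield a usable
    interval. Returns the mean seconds per call for each repeat.
    """
    results = []
    for _ in range(repeat):
        if f_preproc is not None:
            f_preproc()  # once per repeat, outside the timed region
        start = time.perf_counter()
        for _ in range(number):
            func()
        results.append((time.perf_counter() - start) / number)
    return results
```

Per-run flushing is obtained with `repeat=N, number=1`. A large `number` instead amortizes timer overhead, at the price of letting later calls in the same repeat see a warm cache.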




[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1065519635


##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,79 @@
+/*

Review Comment:
   @tkonolige @tqchen I have refactored the code and added the corresponding license files. Let me know if they look good to you :)




[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1065499844


##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+// Acknowledgement: l2flush struct in nvbench project.
+// Reference:
+// https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "cuda_common.h"
+
+namespace tvm {
+
+namespace runtime {
+
+class L2Flush {
+ public:
+  L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
+
+  ~L2Flush() {
+    if (l2_size_ > 0) {
+      CUDA_CALL(cudaFree(l2_buffer_));
+    }
+  }
+
+  void Flush() {
+    if (!initialized_) {

Review Comment:
   Per [CUDA runtime documentation](https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html): 
   >[cudaDevAttrL2CacheSize](https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1gg49e2f8c2c0bd6fe264f2fc970912e5cd1b0342682d15910022ba3f383a851ad7): Size of L2 cache in bytes. 0 if the device doesn't have L2 cache.
   
   `l2_size_` would be set to 0 if the device does not have an L2 cache, in which case I suppose `l2_buffer_` is still a nullptr.





[GitHub] [tvm] echuraev commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
echuraev commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1065007662


##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+// Acknowledgement: l2flush struct in nvbench project.
+// Reference:
+// https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "cuda_common.h"
+
+namespace tvm {
+
+namespace runtime {
+
+class L2Flush {
+ public:
+  L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
+
+  ~L2Flush() {
+    if (l2_size_ > 0) {
+      CUDA_CALL(cudaFree(l2_buffer_));
+    }
+  }
+
+  void Flush() {
+    if (!initialized_) {

Review Comment:
   I think that @Icemist wanted to propose removing the `initialized_` variable and checking `l2_buffer_` instead:
   ```c++
     void Flush() {
       if (l2_buffer_ == nullptr) {
         // ...
   ```
   
   Following the same idea, I asked my question about the `l2_size_` variable in the destructor. 



##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+// Acknowledgement: l2flush struct in nvbench project.
+// Reference:
+// https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "cuda_common.h"
+
+namespace tvm {
+
+namespace runtime {
+
+class L2Flush {
+ public:
+  L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
+
+  ~L2Flush() {
+    if (l2_size_ > 0) {

Review Comment:
   My bad, I missed it. 





[GitHub] [tvm] Icemist commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
Icemist commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064507883


##########
src/runtime/profiling.cc:
##########
@@ -894,15 +891,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
               std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
         }
 
+        int64_t accum_t_nanos = 0;
         // start timing
-        Timer t = Timer::Start(dev);
         for (int j = 0; j < number; ++j) {
+          // call preprocessing function
+          if (f_preproc != nullptr) {
+            f_preproc.CallPacked(args, &temp);

Review Comment:
   Maybe a warmup option has its place; we already have `skip_first=1`, similar to PyTorch.
   But this would require changes not only in the C++ part, but also in the C and WebAssembly parts.
   
   For `ProfileFunction`, please look at the original PR, https://github.com/apache/tvm/pull/9553, which says a few words about the motivation for adding it.





[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1065275352


##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,79 @@
+/*

Review Comment:
   Thanks @tkonolige and @tqchen for the comments, I'll refactor the code and use nvbench's license.





[GitHub] [tvm] echuraev commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
echuraev commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1065019000


##########
tests/python/unittest/test_evaluator_flush_l2_cache.py:
##########
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import te
+from tvm.script import tir as T
+import tvm.testing
+import numpy as np
+
+
+@T.prim_func
+def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128])
+    B = T.match_buffer(b, [128, 128])
+    C = T.match_buffer(c, [128, 128])
+    for i, j, k in T.grid(128, 128, 128):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                C[vi, vj] = 0.0
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+
+@tvm.testing.requires_cuda
+def test_evaluator_flush_l2_cache():
+    mod = tvm.IRModule.from_expr(matmul)
+    sch = tvm.tir.Schedule(mod)
+    blk = sch.get_block("matmul")
+    i, j, k = sch.get_loops(blk)
+    sch.bind(i, "blockIdx.x")
+    sch.bind(j, "threadIdx.x")
+    f = tvm.build(sch.mod["main"], target="cuda")
+    dev = tvm.cuda(0)
+    evaluator_no_flush = f.time_evaluator(f.entry_name, dev, repeat=1000, number=1)
+
+    a = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
+    b = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
+    c = tvm.nd.array(np.zeros((128, 128)).astype("float32"), device=dev)
+    args = [a, b, c]
+    print("Evaluator (w/o L2 flush):\t{:.5f}ms".format(evaluator_no_flush(*args).mean * 1000))
+
+    evaluator_with_flush = f.time_evaluator(
+        f.entry_name, dev, repeat=1000, number=1, f_preproc="l2_cache_flush_cuda"
+    )
+    print("Evaluator (w/ L2 flush):\t{:.5f}ms".format(evaluator_with_flush(*args).mean * 1000))

Review Comment:
   > This test tests for the presence of **l2_cache_flush_cuda**. It's normal integration test :)
   
   Yes, but I was wondering whether we could check that `l2_cache_flush_cuda` actually works. Previously, the test made two runs of `time_evaluator`, and I thought that by comparing their execution times we could probably tell whether `l2_cache_flush_cuda` works or not.
   
   A better test for this functionality would probably check the state of the L2 cache after the `Flush` function runs. But such a test would probably have to be implemented with gtest, and I'm not sure that such infrastructure exists for CUDA. Anyway, I don't think we should spend a lot of time implementing such a test in this PR, and I don't want to block this PR only because of it. The existing test should be enough.
   





[GitHub] [tvm] tqchen commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1065198605


##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,79 @@
+/*

Review Comment:
   Please see the compiler-rt header as an example:
   - original file (with fixes): https://github.com/apache/tvm/blob/main/3rdparty/compiler-rt/builtin_fp16.h
   - license mentioned here: https://github.com/apache/tvm/blob/main/LICENSE#L235
   - license file added to this folder: https://github.com/apache/tvm/blob/main/licenses/LICENSE.builtin_fp16.txt





[GitHub] [tvm] tvm-bot commented on pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
tvm-bot commented on PR #13726:
URL: https://github.com/apache/tvm/pull/13726#issuecomment-1374827722

   <!---bot-comment-->
   
   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.
   
   <!--bot-comment-ccs-start-->
    * cc @echuraev, @icemist, @tkonolige <sub>See [#10317](https://github.com/apache/tvm/issues/10317) for details</sub><!--bot-comment-ccs-end-->
   
   <sub>Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)</sub>




[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064145960


##########
src/runtime/profiling.cc:
##########
@@ -894,15 +891,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
               std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
         }
 
+        int64_t accum_t_nanos = 0;
         // start timing
-        Timer t = Timer::Start(dev);
         for (int j = 0; j < number; ++j) {
+          // call preprocessing function
+          if (f_preproc != nullptr) {
+            f_preproc.CallPacked(args, &temp);

Review Comment:
   I have reverted `profiling.cc` and updated the unit test accordingly.





[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064715623


##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+// Acknowledgement: l2flush struct in nvbench project.
+// Reference:
+// https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "cuda_common.h"
+
+namespace tvm {
+
+namespace runtime {
+
+class L2Flush {
+ public:
+  L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
+
+  ~L2Flush() {
+    if (l2_size_ > 0) {
+      CUDA_CALL(cudaFree(l2_buffer_));
+    }
+  }
+
+  void Flush() {
+    if (!initialized_) {
+      // initialize l2_buffer_ and l2_size_
+      initialized_ = true;
+      int device_id;
+      CUDA_CALL(cudaGetDevice(&device_id));
+      CUDA_CALL(cudaDeviceGetAttribute(&l2_size_, cudaDevAttrL2CacheSize, device_id));
+      if (l2_size_ > 0) {
+        void* buffer = l2_buffer_;

Review Comment:
   I adopted your suggestion because it looks better :)





[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1065499844


##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+// Acknowledgement: l2flush struct in nvbench project.
+// Reference:
+// https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "cuda_common.h"
+
+namespace tvm {
+
+namespace runtime {
+
+class L2Flush {
+ public:
+  L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
+
+  ~L2Flush() {
+    if (l2_size_ > 0) {
+      CUDA_CALL(cudaFree(l2_buffer_));
+    }
+  }
+
+  void Flush() {
+    if (!initialized_) {

Review Comment:
   Per [CUDA runtime documentation](https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html): 
   >[cudaDevAttrL2CacheSize](https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1gg49e2f8c2c0bd6fe264f2fc970912e5cd1b0342682d15910022ba3f383a851ad7): Size of L2 cache in bytes. 0 if the device doesn't have L2 cache.
   
   `l2_size_` would be set to 0 if the device does not have an L2 cache. In that case no buffer is allocated, so `l2_buffer_` remains a nullptr and we would run into this branch again on every call.
   
   I suppose it's still better to keep `initialized_`?
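   To make the trade-off concrete, here is a minimal Python sketch (not TVM code) that models the lazy initialization of `L2Flush` on the CPU. It assumes a hypothetical `query_l2_size` callback standing in for `cudaDeviceGetAttribute(..., cudaDevAttrL2CacheSize, ...)`, a `bytearray` standing in for the buffer from `cudaMalloc`, and a byte-wise overwrite standing in for `cudaMemsetAsync`. It counts how often the attribute query runs under each guard.

```python
class L2FlushSketch:
    """CPU-side model of L2Flush's lazy initialization (illustrative only)."""

    def __init__(self, query_l2_size, use_initialized_flag=True):
        # Hypothetical stand-in for cudaDeviceGetAttribute(cudaDevAttrL2CacheSize).
        self._query_l2_size = query_l2_size
        # True models the `initialized_` flag; False models the proposed
        # `l2_buffer_ == nullptr` check.
        self._use_flag = use_initialized_flag
        self.initialized = False
        self.l2_size = 0
        self.l2_buffer = None  # stands in for the device pointer from cudaMalloc
        self.num_queries = 0  # how many times the device attribute was queried
        self.bytes_written = 0  # total bytes "memset" across all flushes

    def _needs_init(self):
        if self._use_flag:
            return not self.initialized
        return self.l2_buffer is None

    def flush(self):
        if self._needs_init():
            self.initialized = True
            self.num_queries += 1
            self.l2_size = self._query_l2_size()
            if self.l2_size > 0:
                # Allocation only happens when the device reports an L2 cache.
                self.l2_buffer = bytearray(self.l2_size)
        if self.l2_size > 0:
            # Overwrite the whole buffer, like cudaMemsetAsync(l2_buffer_, 0, l2_size_).
            self.l2_buffer[:] = bytes(self.l2_size)
            self.bytes_written += self.l2_size


# A device without an L2 cache (attribute reports 0 bytes), flushed 5 times.
with_flag = L2FlushSketch(lambda: 0, use_initialized_flag=True)
with_null_check = L2FlushSketch(lambda: 0, use_initialized_flag=False)
for _ in range(5):
    with_flag.flush()
    with_null_check.flush()
print(with_flag.num_queries, with_null_check.num_queries)  # 1 5
```

   With a nonzero L2 size both guards behave the same, because the buffer is allocated on the first call; the difference only shows up on devices that report an L2 size of 0.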





[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064598583


##########
tests/python/unittest/test_evaluator_flush_l2_cache.py:
##########
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import te
+from tvm.script import tir as T
+import tvm.testing
+import numpy as np
+
+
+@T.prim_func
+def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128])
+    B = T.match_buffer(b, [128, 128])
+    C = T.match_buffer(c, [128, 128])
+    for i, j, k in T.grid(128, 128, 128):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                C[vi, vj] = 0.0
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+
+@tvm.testing.requires_cuda
+def test_evaluator_flush_l2_cache():
+    mod = tvm.IRModule.from_expr(matmul)
+    sch = tvm.tir.Schedule(mod)
+    blk = sch.get_block("matmul")
+    i, j, k = sch.get_loops(blk)
+    sch.bind(i, "blockIdx.x")
+    sch.bind(j, "threadIdx.x")
+    f = tvm.build(sch.mod["main"], target="cuda")
+    dev = tvm.cuda(0)
+    evaluator_no_flush = f.time_evaluator(f.entry_name, dev, repeat=1000, number=1)
+
+    a = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
+    b = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
+    c = tvm.nd.array(np.zeros((128, 128)).astype("float32"), device=dev)
+    args = [a, b, c]
+    print("Evaluator (w/o L2 flush):\t{:.5f}ms".format(evaluator_no_flush(*args).mean * 1000))
+
+    evaluator_with_flush = f.time_evaluator(
+        f.entry_name, dev, repeat=1000, number=1, f_preproc="l2_cache_flush_cuda"
+    )
+    print("Evaluator (w/ L2 flush):\t{:.5f}ms".format(evaluator_with_flush(*args).mean * 1000))

Review Comment:
   `pytest.mark.parametrize` sounds good to me
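   A parametrized version of the test might look like the sketch below. It is an example under assumptions rather than the PR's final test: it uses a small elementwise kernel built with `te.create_prim_func` instead of the `matmul` above, passes `f_preproc=""` for the no-flush case (to my knowledge the default of `time_evaluator`), and skips when TVM or a CUDA device is unavailable.

```python
import numpy as np
import pytest


@pytest.mark.parametrize("f_preproc", ["", "l2_cache_flush_cuda"])
def test_evaluator_flush_l2_cache(f_preproc):
    tvm = pytest.importorskip("tvm")
    from tvm import te

    dev = tvm.cuda(0)
    if not dev.exist:
        pytest.skip("requires a CUDA device")

    # Small elementwise kernel, bound to CUDA blocks/threads.
    n = 1 << 20
    A = te.placeholder((n,), dtype="float32", name="A")
    B = te.compute((n,), lambda i: A[i] * 2.0, name="B")
    sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
    (i,) = sch.get_loops(sch.get_block("B"))
    bx, tx = sch.split(i, factors=[None, 256])
    sch.bind(bx, "blockIdx.x")
    sch.bind(tx, "threadIdx.x")
    f = tvm.build(sch.mod["main"], target="cuda")

    a = tvm.nd.array(np.random.rand(n).astype("float32"), device=dev)
    b = tvm.nd.empty((n,), "float32", dev)
    # number=1 so that f_preproc runs before every timed call.
    evaluator = f.time_evaluator(
        f.entry_name, dev, repeat=100, number=1, f_preproc=f_preproc
    )
    mean_ms = evaluator(a, b).mean * 1000
    print("f_preproc={!r}:\t{:.5f}ms".format(f_preproc, mean_ms))
    assert mean_ms > 0
```

   Each parameter value becomes its own test case, so a failure in the flushed configuration is reported separately from the baseline.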





[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1065275352


##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,79 @@
+/*

Review Comment:
   Thanks @tkonolige and @tqchen for the comments, I'll refactor the code to include nvbench.





[GitHub] [tvm] tqchen commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1065189279


##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,79 @@
+/*

Review Comment:
   @yzh119 Let's move it to `3rdparty/nvbench`; then we can include it from there.





[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064172623


##########
src/runtime/profiling.cc:
##########
@@ -894,15 +891,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
               std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
         }
 
+        int64_t accum_t_nanos = 0;
         // start timing
-        Timer t = Timer::Start(dev);
         for (int j = 0; j < number; ++j) {
+          // call preprocessing function
+          if (f_preproc != nullptr) {
+            f_preproc.CallPacked(args, &temp);

Review Comment:
   PyTorch profiler has a `warmup` argument: https://pytorch.org/docs/stable/profiler.html#torch.profiler.schedule, where users can specify a custom number of steps to warm up.
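   For reference, a warmup phase could be layered on top of a timing loop roughly as follows. This is a hypothetical Python sketch, not a TVM or PyTorch API: `benchmark`, its `warmup` parameter, and the injectable `timer` are illustrative names. Warmup calls run the function but are not recorded, similar in spirit to the warmup steps in `torch.profiler.schedule`.

```python
import time


def benchmark(fn, repeat, warmup=0, f_preproc=None, timer=time.perf_counter):
    """Hypothetical timing loop with an untimed warmup phase.

    - `warmup` calls of `fn` are executed first and discarded.
    - `f_preproc` (e.g. an L2 flush) runs before every timed call.
    Returns one elapsed-time sample per timed call.
    """
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeat):
        if f_preproc is not None:
            f_preproc()
        start = timer()
        fn()
        samples.append(timer() - start)
    return samples


calls = {"fn": 0, "preproc": 0}


def kernel():
    calls["fn"] += 1


def flush():
    calls["preproc"] += 1


samples = benchmark(kernel, repeat=10, warmup=3, f_preproc=flush)
print(len(samples), calls)  # 10 {'fn': 13, 'preproc': 10}
```

   Keeping `f_preproc` out of the warmup phase is a design choice in this sketch; warmup is only meant to trigger one-time costs such as lazy initialization.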





[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064140526


##########
src/runtime/profiling.cc:
##########
@@ -894,15 +891,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
               std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
         }
 
+        int64_t accum_t_nanos = 0;
         // start timing
-        Timer t = Timer::Start(dev);
         for (int j = 0; j < number; ++j) {
+          // call preprocessing function
+          if (f_preproc != nullptr) {
+            f_preproc.CallPacked(args, &temp);

Review Comment:
   Thanks for the clarification, I'll fix it soon.
   Unfortunately, it looks like there are many incorrect uses of `number` in the TVM codebase; maybe we should create another PR to fix them.
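   The reason `number` matters here can be modeled with a toy Python sketch of the `repeat`/`number` loop. It is a simplification that ignores `min_repeat_ms`, cooldowns, and the initial untimed call, and it assumes (as the PR description says of the original `profiling.cc`) that `f_preproc` runs once per repeat while the timer covers all `number` calls. `FakeDevice` and its cold/warm costs are made-up numbers. With `number > 1`, only the first call in each repeat sees a flushed cache, so the reported per-call average understates the cold-cache time.

```python
from dataclasses import dataclass


@dataclass
class FakeDevice:
    """Toy device: a kernel is slower when the cache is cold (made-up costs)."""

    cache_warm: bool = False
    clock: float = 0.0

    def run_kernel(self):
        self.clock += 4.0 if self.cache_warm else 10.0
        self.cache_warm = True

    def flush_l2(self):
        self.cache_warm = False


def time_evaluator_sketch(dev, number, repeat, f_preproc=None):
    """Simplified repeat/number loop: f_preproc runs once per repeat,
    and each repeat reports the average time of `number` calls."""
    results = []
    for _ in range(repeat):
        if f_preproc is not None:
            f_preproc()
        start = dev.clock
        for _ in range(number):
            dev.run_kernel()
        results.append((dev.clock - start) / number)
    return results


dev = FakeDevice()
print(time_evaluator_sketch(dev, number=1, repeat=3, f_preproc=dev.flush_l2))  # [10.0, 10.0, 10.0]
dev = FakeDevice()
print(time_evaluator_sketch(dev, number=5, repeat=2, f_preproc=dev.flush_l2))  # [5.2, 5.2]
```

   In this model, `number=1, repeat=N` is the setting that yields one cold-cache sample per repeat.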





[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064140526


##########
src/runtime/profiling.cc:
##########
@@ -894,15 +891,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
               std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
         }
 
+        int64_t accum_t_nanos = 0;
         // start timing
-        Timer t = Timer::Start(dev);
         for (int j = 0; j < number; ++j) {
+          // call preprocessing function
+          if (f_preproc != nullptr) {
+            f_preproc.CallPacked(args, &temp);

Review Comment:
   Thanks for the clarification, I'll fix it soon.
   Unfortunately, it looks like there is a lot of misuse of `number` in the codebase; maybe we should create another PR to fix it.





[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064491541


##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+// Acknowledgement: l2flush struct in nvbench project.
+// Reference:
+// https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "cuda_common.h"
+
+namespace tvm {
+
+namespace runtime {
+
+class L2Flush {
+ public:
+  L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
+
+  ~L2Flush() {
+    if (l2_size_ > 0) {

Review Comment:
   Yes, we need it: every time we call the flush function, we need `l2_size_` for `cudaMemsetAsync`.





[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064590379


##########
src/runtime/profiling.cc:
##########
@@ -894,15 +891,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
               std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
         }
 
+        int64_t accum_t_nanos = 0;
         // start timing
-        Timer t = Timer::Start(dev);
         for (int j = 0; j < number; ++j) {
+          // call preprocessing function
+          if (f_preproc != nullptr) {
+            f_preproc.CallPacked(args, &temp);

Review Comment:
   Got it. I'd prefer to add support for the `warmup` argument in a separate PR. Thank you for the discussion!
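   The following host-only sketch shows the loop shape under discussion. `f_preproc` is called before each of the `number` runs, its cost stays outside the accumulated time, and an optional warmup phase runs first. `TimeEvaluateSketch` and its `warmup` parameter are hypothetical, because real `warmup` support was deferred to another PR. The sketch also uses `std::chrono::steady_clock` in place of TVM's device `Timer`, so it does not account for asynchronous GPU execution.

```cpp
#include <cassert>
#include <chrono>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical timing loop: returns the mean time per run (in ms) for each repeat.
std::vector<double> TimeEvaluateSketch(const std::function<void()>& f,
                                       const std::function<void()>& f_preproc,
                                       int number, int repeat, int warmup) {
  assert(number > 0 && repeat >= 0 && warmup >= 0);
  using Clock = std::chrono::steady_clock;

  // Untimed warmup runs (assumed semantics of a future `warmup` argument).
  for (int w = 0; w < warmup; ++w) f();

  std::vector<double> mean_ms;
  for (int i = 0; i < repeat; ++i) {
    std::int64_t accum_t_nanos = 0;
    for (int j = 0; j < number; ++j) {
      // Preprocessing (e.g. an L2 flush) runs once per run, outside the timer.
      if (f_preproc) f_preproc();
      auto start = Clock::now();
      f();
      auto stop = Clock::now();
      accum_t_nanos +=
          std::chrono::duration_cast<std::chrono::nanoseconds>(stop - start).count();
    }
    mean_ms.push_back(static_cast<double>(accum_t_nanos) / number / 1e6);
  }
  return mean_ms;
}
```

   With `number=4`, `repeat=3`, and `warmup=2`, the preprocessing function runs 12 times and the measured function runs 14 times. Only 12 of those 14 runs are timed.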





[GitHub] [tvm] yzh119 commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
yzh119 commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1064709512


##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+// Acknowledgement: l2flush struct in nvbench project.
+// Reference:
+// https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "cuda_common.h"
+
+namespace tvm {
+
+namespace runtime {
+
+class L2Flush {
+ public:
+  L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
+
+  ~L2Flush() {
+    if (l2_size_ > 0) {
+      CUDA_CALL(cudaFree(l2_buffer_));
+    }
+  }
+
+  void Flush() {
+    if (!initialized_) {
+      // initialize l2_buffer_ and l2_size_
+      initialized_ = true;
+      int device_id;
+      CUDA_CALL(cudaGetDevice(&device_id));
+      CUDA_CALL(cudaDeviceGetAttribute(&l2_size_, cudaDevAttrL2CacheSize, device_id));
+      if (l2_size_ > 0) {
+        void* buffer = l2_buffer_;

Review Comment:
   I suppose they are equivalent.





[GitHub] [tvm] echuraev commented on a diff in pull request #13726: [Profiler] Allow user to flush L2 cache in `time_evalutor` function for profiling CUDA kernels

Posted by GitBox <gi...@apache.org>.
echuraev commented on code in PR #13726:
URL: https://github.com/apache/tvm/pull/13726#discussion_r1065517256


##########
src/runtime/cuda/l2_cache_flush.cc:
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+// Acknowledgement: l2flush struct in nvbench project.
+// Reference:
+// https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "cuda_common.h"
+
+namespace tvm {
+
+namespace runtime {
+
+class L2Flush {
+ public:
+  L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}
+
+  ~L2Flush() {
+    if (l2_size_ > 0) {
+      CUDA_CALL(cudaFree(l2_buffer_));
+    }
+  }
+
+  void Flush() {
+    if (!initialized_) {

Review Comment:
   Thank you for the clarification! Yes, in this case I think `initialized_` is the better solution.
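   The question that prompted this reply is not shown in the excerpt. An explicit `initialized_` flag has one concrete advantage over reusing a value as a sentinel. On a device that reports an L2 size of 0, a sentinel check never records that initialization already happened. The hypothetical sketch below compares the two approaches. The class names and the `query` callback are assumptions, with `query` standing in for the CUDA attribute query. The sketch only counts how often each variant queries the device.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// query stands in for cudaDeviceGetAttribute(..., cudaDevAttrL2CacheSize, ...).

// (a) Sentinel check: "l2_size_ == 0" doubles as "not initialized yet".
// If the device reports an L2 size of 0, the query repeats on every call.
class SentinelInit {
 public:
  explicit SentinelInit(std::function<int()> query) : query_(std::move(query)) {}
  void Flush() {
    if (l2_size_ == 0) l2_size_ = query_();
    // A real flusher would memset l2_size_ bytes here when l2_size_ > 0.
  }

 private:
  std::function<int()> query_;
  int l2_size_ = 0;
};

// (b) Explicit flag, as in the quoted diff: the query runs exactly once,
// whatever size the device reports.
class FlagInit {
 public:
  explicit FlagInit(std::function<int()> query) : query_(std::move(query)) {}
  void Flush() {
    if (!initialized_) {
      initialized_ = true;
      l2_size_ = query_();
    }
    // A real flusher would memset l2_size_ bytes here when l2_size_ > 0.
  }

 private:
  std::function<int()> query_;
  bool initialized_ = false;
  int l2_size_ = 0;
};
```

   After five flushes on a device that reports an L2 size of 0, the sentinel variant has queried five times and the flag variant once.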


