You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/10/10 05:39:06 UTC

[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6657: [AutoScheduler] Improve test cases

comaniac commented on a change in pull request #6657:
URL: https://github.com/apache/incubator-tvm/pull/6657#discussion_r502749481



##########
File path: python/tvm/auto_scheduler/measure_record.py
##########
@@ -159,3 +175,38 @@ def load_best(filename, workload_key=None, target=None):
             best_res = res
 
     return best_inp, best_res
+
+
+def correct_measure_input(inp, rebuild_state=False):

Review comment:
       The naming is confusing. I imagine the use case would look like this:
   ```python
   for inp, res in load_records(file):
       inp = correct_measure_input(inp)
   ```
   From this code snippet, it's hard to tell what a "correct" measure input means.
   
   One suggestion: call this function inside `load_records` and control it with an option, something like `load_records(file, full_model=True)`?

##########
File path: tests/python/unittest/test_auto_scheduler_measure.py
##########
@@ -167,53 +167,76 @@ def test_record_pragma_storage_align_rfactor():
     record_common(dag, s)
 
 
-def test_measure_local_builder_runner(enable_cpu_cache_flush=False):
+def test_correct_measure_input():
+    task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512, 512], "llvm")
+
+    inp = auto_scheduler.measure.MeasureInput(task, task.compute_dag.init_state)
+    res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)
+
+    with tempfile.NamedTemporaryFile() as fp:
+        auto_scheduler.save_records(fp.name, [inp], [res])
+
+        log_reader = auto_scheduler.RecordReader(fp.name)
+        inputs, results = log_reader.read_lines()
+        assert len(inputs) == 1
+
+        raw_inp = inputs[0]
+
+        correct_inp = auto_scheduler.measure_record.correct_measure_input(raw_inp)
+        assert str(correct_inp.task.compute_dag) == str(inp.task.compute_dag)
+
+        correct_inp = auto_scheduler.measure_record.correct_measure_input(
+            raw_inp, rebuild_state=True
+        )
+        assert str(correct_inp.state) == str(inp.state)
+
+
+def test_measure_local_builder_runner():
     if not tvm.testing.device_enabled("llvm"):
         return
 
-    dag, s0 = get_tiled_matmul()
-    tgt = tvm.target.Target("llvm")
-    task = auto_scheduler.SearchTask(dag, "test", tgt)
+    task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512, 512], "llvm")
 
-    minp = auto_scheduler.MeasureInput(task, s0)
-    local_builder = auto_scheduler.LocalBuilder()
-    local_runner = auto_scheduler.LocalRunner(
-        timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
-    )
+    for enable_cpu_cache_flush in [True, False]:
+        minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
+        local_builder = auto_scheduler.LocalBuilder()
+        local_runner = auto_scheduler.LocalRunner(
+            timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
+        )
 
-    bress = local_builder.build([minp])
-    assert bress[0].error_no == 0
-    mress = local_runner.run([minp], bress)
-    assert mress[0].error_no == 0
+        bress = local_builder.build([minp])
+        assert bress[0].error_no == 0
+        mress = local_runner.run([minp], bress)
+        assert mress[0].error_no == 0
 
 
-def test_measure_local_builder_rpc_runner(enable_cpu_cache_flush=False):
+def test_measure_local_builder_rpc_runner():
     if not tvm.testing.device_enabled("llvm"):
         return
 
-    dag, s0 = get_tiled_matmul()
-    tgt = tvm.target.Target("llvm")
-    task = auto_scheduler.SearchTask(dag, "test", tgt)
+    task = auto_scheduler.create_task(matmul_auto_scheduler_test, [512, 512, 512], "llvm")
 
-    minp = auto_scheduler.MeasureInput(task, s0)
-    local_builder = auto_scheduler.LocalBuilder()
-    measure_ctx = auto_scheduler.LocalRPCMeasureContext(
-        timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush
-    )
-    rpc_runner = measure_ctx.runner
+    for enable_cpu_cache_flush in [True, False]:

Review comment:
       Is this the only place testing `enable_cpu_cache_flush`?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org