You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/03/29 02:36:55 UTC
[tvm] branch main updated: [TIR] Properly initialize PRNG seed when copying schedule (#10806)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 62e3d9d [TIR] Properly initialize PRNG seed when copying schedule (#10806)
62e3d9d is described below
commit 62e3d9d27e69988c21493d650fba9fda5002d1a7
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Tue Mar 29 11:35:41 2022 +0900
[TIR] Properly initialize PRNG seed when copying schedule (#10806)
* Make Schedule::Copy non-const, fork RNG seed in Copy
* fork seed in traced schedule copy too
commit eeb4a6d4b34909822ea5d56488afd11f254e53a9
Author: Masahiro Masuda <ma...@gmail.com>
Date: Tue Mar 29 06:39:38 2022 +0900
add more comments
commit 183b4cfe5d7938d5e440a9d77b7e8c3871544966
Author: Masahiro Masuda <ma...@gmail.com>
Date: Mon Mar 28 10:04:12 2022 +0900
skip flaky vk test
commit c19ecc17afc8ee1b54aa2260bffb4e1d431ab429
Author: Masahiro Masuda <ma...@gmail.com>
Date: Mon Mar 28 07:34:25 2022 +0900
move intrin decl for vector type
commit 3dd7f045f791b805012227ab4ee866995cc5297d
Author: Masahiro Masuda <ma...@gmail.com>
Date: Sat Mar 26 09:40:29 2022 +0900
disable default postprocessor; tuning now works with the compactness check
commit 2f6fdae675975e2bd95a086dce8a88d9f267746d
Author: Masahiro Masuda <ma...@gmail.com>
Date: Sat Mar 26 08:08:35 2022 +0900
more comments
commit c7ebfa904367885442f928c0cecf875190341930
Author: Masahiro Masuda <ma...@gmail.com>
Date: Sat Mar 26 07:42:46 2022 +0900
add comment
commit 78400bad77f5201b3cebfb8a7fee0642adead060
Author: Masahiro Masuda <ma...@gmail.com>
Date: Sat Mar 26 07:40:28 2022 +0900
disable tuning test for now
commit a33243fbf91863f0a834505cd4936c0fff228603
Author: Masahiro Masuda <ma...@gmail.com>
Date: Sat Mar 26 07:30:03 2022 +0900
remove annotation check in IR comparator
commit 105f98cc76081d46dddbda47dba2578a25cfadb2
Author: Masahiro Masuda <ma...@gmail.com>
Date: Sat Mar 26 07:28:36 2022 +0900
clean up
commit 8aa16f209ee709375d90f2c3a5883a47df6ce104
Author: Masahiro Masuda <ma...@gmail.com>
Date: Sat Mar 26 07:15:24 2022 +0900
Add test
* add test case that hangs without ForkSeed
---
include/tvm/support/random_engine.h | 2 +-
include/tvm/tir/schedule/schedule.h | 2 +-
src/tir/schedule/concrete_schedule.cc | 3 ++-
src/tir/schedule/concrete_schedule.h | 2 +-
src/tir/schedule/traced_schedule.cc | 3 ++-
src/tir/schedule/traced_schedule.h | 2 +-
tests/python/unittest/test_tir_schedule_sampling.py | 11 +++++++++++
7 files changed, 19 insertions(+), 6 deletions(-)
diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h
index 89b1e91..fe56bb5 100644
--- a/include/tvm/support/random_engine.h
+++ b/include/tvm/support/random_engine.h
@@ -115,7 +115,7 @@ class LinearCongruentialEngine {
* \return The forked seed.
*/
TRandState ForkSeed() {
- // In order for reproducibility, we computer the new seed using RNG's random state and a
+ // In order for reproducibility, we compute the new seed using RNG's random state and a
// different set of parameters. Note that both 32767 and 1999999973 are prime numbers.
return ((*this)() * 32767) % 1999999973;
}
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 0273ece..1d9bfc9 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -123,7 +123,7 @@ class ScheduleNode : public runtime::Object {
* 3) All the random variables are valid in the copy, pointing to the corresponding sref
* reconstructed
*/
- virtual Schedule Copy() const = 0;
+ virtual Schedule Copy() = 0;
/*!
* \brief Seed the randomness
* \param seed The new random seed, -1 if use device random, otherwise non-negative
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 331ae02..e261cf2 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -182,11 +182,12 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symb
new_state->get()->DebugVerify();
}
-Schedule ConcreteScheduleNode::Copy() const {
+Schedule ConcreteScheduleNode::Copy() {
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
n->error_render_level_ = this->error_render_level_;
ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
n->analyzer_ = std::make_unique<arith::Analyzer>(); // new analyzer needed because it is stateful
+ n->rand_state_ = ForkSeed();
return Schedule(std::move(n));
}
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 32aab1a..59764e3 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -61,7 +61,7 @@ class ConcreteScheduleNode : public ScheduleNode {
public:
ScheduleState state() const final { return state_; }
Optional<Trace> trace() const override { return NullOpt; }
- Schedule Copy() const override;
+ Schedule Copy() override;
void Seed(support::LinearCongruentialEngine::TRandState seed = -1) final;
support::LinearCongruentialEngine::TRandState ForkSeed() final;
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 8af66f1..417f80d 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -33,11 +33,12 @@ Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRand
return Schedule(std::move(n));
}
-Schedule TracedScheduleNode::Copy() const {
+Schedule TracedScheduleNode::Copy() {
ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>();
n->error_render_level_ = this->error_render_level_;
ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
n->analyzer_ = std::make_unique<arith::Analyzer>(); // new analyzer needed because it is stateful
+ n->rand_state_ = ForkSeed();
n->trace_ = Trace(this->trace_->insts, this->trace_->decisions);
return Schedule(std::move(n));
}
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 5d355bd..442b50a 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -43,7 +43,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
public:
Optional<Trace> trace() const final { return trace_; }
- Schedule Copy() const final;
+ Schedule Copy() final;
public:
/******** Schedule: Sampling ********/
diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py
index cc2b114..d8f9670 100644
--- a/tests/python/unittest/test_tir_schedule_sampling.py
+++ b/tests/python/unittest/test_tir_schedule_sampling.py
@@ -194,5 +194,16 @@ def test_sample_compute_location():
numpy.testing.assert_allclose(expected_rate, cnt / n, atol=0.04)
+def test_sample_perfect_tile_after_copy():
+ sch = tir.Schedule(elementwise, debug_mask="all")
+ sch_copy = sch.copy()
+ _, _, i = sch.get_loops(sch.get_block("B"))
+ sch.sample_perfect_tile(i, n=4)
+
+ _, _, i = sch_copy.get_loops(sch_copy.get_block("B"))
+ # Hangs if ForkSeed is not invoked when copying a schedule
+ sch_copy.sample_perfect_tile(i, n=4)
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))