You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/03/29 02:36:55 UTC

[tvm] branch main updated: [TIR] Properly initialize PRNG seed when copying schedule (#10806)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 62e3d9d  [TIR] Properly initialize PRNG seed when copying schedule (#10806)
62e3d9d is described below

commit 62e3d9d27e69988c21493d650fba9fda5002d1a7
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Tue Mar 29 11:35:41 2022 +0900

    [TIR] Properly initialize PRNG seed when copying schedule (#10806)
    
    * Make Schedule::Copy non-const, fork RNG seed in Copy
    
    * fork seed in traced schedule copy too
    
    commit eeb4a6d4b34909822ea5d56488afd11f254e53a9
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Tue Mar 29 06:39:38 2022 +0900
    
        add more comment
    
    commit 183b4cfe5d7938d5e440a9d77b7e8c3871544966
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Mon Mar 28 10:04:12 2022 +0900
    
        skip flaky vk test
    
    commit c19ecc17afc8ee1b54aa2260bffb4e1d431ab429
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Mon Mar 28 07:34:25 2022 +0900
    
        move intrin decl for vector type
    
    commit 3dd7f045f791b805012227ab4ee866995cc5297d
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Sat Mar 26 09:40:29 2022 +0900
    
        disable default post processor, tuning now works with compactness check
    
    commit 2f6fdae675975e2bd95a086dce8a88d9f267746d
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Sat Mar 26 08:08:35 2022 +0900
    
        more comment
    
    commit c7ebfa904367885442f928c0cecf875190341930
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Sat Mar 26 07:42:46 2022 +0900
    
        add comment
    
    commit 78400bad77f5201b3cebfb8a7fee0642adead060
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Sat Mar 26 07:40:28 2022 +0900
    
        disable tuning test for now
    
    commit a33243fbf91863f0a834505cd4936c0fff228603
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Sat Mar 26 07:30:03 2022 +0900
    
        remove annotation check in ir comparator
    
    commit 105f98cc76081d46dddbda47dba2578a25cfadb2
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Sat Mar 26 07:28:36 2022 +0900
    
        clean up
    
    commit 8aa16f209ee709375d90f2c3a5883a47df6ce104
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Sat Mar 26 07:15:24 2022 +0900
    
        Add test
    
    * add test case that hangs without ForkSeed
---
 include/tvm/support/random_engine.h                 |  2 +-
 include/tvm/tir/schedule/schedule.h                 |  2 +-
 src/tir/schedule/concrete_schedule.cc               |  3 ++-
 src/tir/schedule/concrete_schedule.h                |  2 +-
 src/tir/schedule/traced_schedule.cc                 |  3 ++-
 src/tir/schedule/traced_schedule.h                  |  2 +-
 tests/python/unittest/test_tir_schedule_sampling.py | 11 +++++++++++
 7 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h
index 89b1e91..fe56bb5 100644
--- a/include/tvm/support/random_engine.h
+++ b/include/tvm/support/random_engine.h
@@ -115,7 +115,7 @@ class LinearCongruentialEngine {
    * \return The forked seed.
    */
   TRandState ForkSeed() {
-    // In order for reproducibility, we computer the new seed using RNG's random state and a
+    // For reproducibility, the new seed is derived from the RNG's random state, using a
     // different set of parameters. Note that both 32767 and 1999999973 are prime numbers.
     return ((*this)() * 32767) % 1999999973;
   }
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 0273ece..1d9bfc9 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -123,7 +123,7 @@ class ScheduleNode : public runtime::Object {
    * 3) All the random variables are valid in the copy, pointing to the corresponding sref
    * reconstructed
    */
-  virtual Schedule Copy() const = 0;
+  virtual Schedule Copy() = 0;
   /*!
    * \brief Seed the randomness
    * \param seed The new random seed, -1 if use device random, otherwise non-negative
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 331ae02..e261cf2 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -182,11 +182,12 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symb
   new_state->get()->DebugVerify();
 }
 
-Schedule ConcreteScheduleNode::Copy() const {
+Schedule ConcreteScheduleNode::Copy() {
   ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
   n->error_render_level_ = this->error_render_level_;
   ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
   n->analyzer_ = std::make_unique<arith::Analyzer>();  // new analyzer needed because it is stateful
+  n->rand_state_ = ForkSeed();  // seed the copy; ForkSeed() is non-const, so Copy() is too
   return Schedule(std::move(n));
 }
 
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 32aab1a..59764e3 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -61,7 +61,7 @@ class ConcreteScheduleNode : public ScheduleNode {
  public:
   ScheduleState state() const final { return state_; }
   Optional<Trace> trace() const override { return NullOpt; }
-  Schedule Copy() const override;
+  Schedule Copy() override;
   void Seed(support::LinearCongruentialEngine::TRandState seed = -1) final;
   support::LinearCongruentialEngine::TRandState ForkSeed() final;
 
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 8af66f1..417f80d 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -33,11 +33,12 @@ Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRand
   return Schedule(std::move(n));
 }
 
-Schedule TracedScheduleNode::Copy() const {
+Schedule TracedScheduleNode::Copy() {
   ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>();
   n->error_render_level_ = this->error_render_level_;
   ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
   n->analyzer_ = std::make_unique<arith::Analyzer>();  // new analyzer needed because it is stateful
+  n->rand_state_ = ForkSeed();  // seed the copy; ForkSeed() is non-const, so Copy() is too
   n->trace_ = Trace(this->trace_->insts, this->trace_->decisions);
   return Schedule(std::move(n));
 }
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 5d355bd..442b50a 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -43,7 +43,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
 
  public:
   Optional<Trace> trace() const final { return trace_; }
-  Schedule Copy() const final;
+  Schedule Copy() final;
 
  public:
   /******** Schedule: Sampling ********/
diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py
index cc2b114..d8f9670 100644
--- a/tests/python/unittest/test_tir_schedule_sampling.py
+++ b/tests/python/unittest/test_tir_schedule_sampling.py
@@ -194,5 +194,16 @@ def test_sample_compute_location():
         numpy.testing.assert_allclose(expected_rate, cnt / n, atol=0.04)
 
 
+def test_sample_perfect_tile_after_copy():
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    sch_copy = sch.copy()  # copy is taken before the original samples anything
+    _, _, i = sch.get_loops(sch.get_block("B"))
+    sch.sample_perfect_tile(i, n=4)
+
+    _, _, i = sch_copy.get_loops(sch_copy.get_block("B"))
+    # Regression check for #10806: this call hangs if copying a schedule does not fork the seed
+    sch_copy.sample_perfect_tile(i, n=4)  # no assert: returning at all is the check
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))