You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/09/27 23:35:59 UTC

[GitHub] [tvm] tkonolige commented on a change in pull request #9132: [Meta Schedule][M3a] SearchStrategy

tkonolige commented on a change in pull request #9132:
URL: https://github.com/apache/tvm/pull/9132#discussion_r717093945



##########
File path: src/meta_schedule/utils.h
##########
@@ -21,13 +21,91 @@
 
 #include <tvm/meta_schedule/arg_info.h>
 #include <tvm/meta_schedule/builder.h>
+#include <tvm/meta_schedule/runner.h>
+#include <tvm/meta_schedule/search_strategy.h>
 #include <tvm/meta_schedule/space_generator.h>
 #include <tvm/meta_schedule/tune_context.h>
+#include <tvm/support/parallel_for.h>
 
+#include <vector>
+
+#include "../printer/text_printer.h"
 #include "../support/array.h"
+#include "../tir/schedule/primitive.h"
 
 namespace tvm {
-namespace meta_schedule {}  // namespace meta_schedule
+namespace meta_schedule {
+
+/*!
+ * \brief Find the entry function of the given IRModule.

Review comment:
       In the doc comment, can you specify what determines the entry function? You have comments in the code, but this information may be useful for people calling this function.

##########
File path: src/meta_schedule/search_strategy/replay_trace.cc
##########
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*! \brief A search strategy that replays the trace. */
+class ReplayTraceNode : public SearchStrategyNode {
+ public:
+  using TRandState = support::LinearCongruentialEngine::TRandState;
+
+  /*! \brief The state of the search strategy. */
+  struct State {
+    /*! \brief The search strategy itself */
+    ReplayTraceNode* self;
+    /*! \brief The design spaces. */
+    Array<tir::Schedule> design_spaces;
+    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
+    int st;
+    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
+    int ed;
+
+    explicit State(ReplayTraceNode* self, Array<tir::Schedule> design_spaces)
+        : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {}
+
+    inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
+    inline void NotifyRunnerResults(const Array<RunnerResult>& results);
+  };
+
+  /*! \brief The number of trials per iteration. */
+  int num_trials_per_iter;
+  /*! \brief The number of total trials. */
+  int num_trials_total;
+
+  /*! \brief The module to be tuned. */
+  IRModule mod_{nullptr};
+  /*! \brief The metadata of the function arguments. */
+  Array<ArgInfo> args_info_{nullptr};
+  /*! \brief The number of threads to use. */
+  int num_threads_ = -1;
+  /*! \brief The random state */
+  TRandState rand_state_ = -1;

Review comment:
       In the doc comment, can you indicate what `-1` means for both of these arguments?

##########
File path: include/tvm/meta_schedule/search_strategy.h
##########
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
+#define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
+
+#include <tvm/tir/schedule/schedule.h>
+
+#include "./arg_info.h"
+#include "./runner.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+// Forward declaration
+class TuneContext;
+
+/*! \brief The measure candidate class. */
+class MeasureCandidateNode : public runtime::Object {
+ public:
+  /*! \brief The schedule for profiling. */
+  tir::Schedule sch;
+  /*! \brief The argument information. */
+  Array<ArgInfo> args_info;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("sch", &sch);
+    v->Visit("args_info", &args_info);
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.MeasureCandidate";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object);
+};
+
+/*!
+ * \brief Managed reference to MeasureCandidateNode.
+ * \sa MeasureCandidateNode
+ */
+class MeasureCandidate : public runtime::ObjectRef {
+ public:
+  /*!
+   * \brief Constructor of MeasureCandidate.
+   * \param sch The schedule for profiling.

Review comment:
       I'd recommend using "benchmarking" instead of "profiling" in this PR. In the TVM codebase, the term profiling is exclusively for when we want to get a breakdown of runtime spent in various parts of a program. This generally corresponds with what is used in the greater computer science literature. I don't think benchmarking is that good a word either, but I think it is a better fit for this.

##########
File path: include/tvm/meta_schedule/search_strategy.h
##########
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
+#define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
+
+#include <tvm/tir/schedule/schedule.h>
+
+#include "./arg_info.h"
+#include "./runner.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+// Forward declaration
+class TuneContext;
+
+/*! \brief The measure candidate class. */
+class MeasureCandidateNode : public runtime::Object {
+ public:
+  /*! \brief The schedule for profiling. */
+  tir::Schedule sch;
+  /*! \brief The argument information. */
+  Array<ArgInfo> args_info;

Review comment:
       ```suggestion
     /*! \brief Argument information (shapes, dtypes) passed to the schedule when running. */
     Array<ArgInfo> args_info;
   ```

##########
File path: include/tvm/meta_schedule/search_strategy.h
##########
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
+#define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
+
+#include <tvm/tir/schedule/schedule.h>
+
+#include "./arg_info.h"
+#include "./runner.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+// Forward declaration
+class TuneContext;
+
+/*! \brief The measure candidate class. */
+class MeasureCandidateNode : public runtime::Object {
+ public:
+  /*! \brief The schedule for profiling. */
+  tir::Schedule sch;
+  /*! \brief The argument information. */
+  Array<ArgInfo> args_info;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("sch", &sch);
+    v->Visit("args_info", &args_info);
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.MeasureCandidate";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object);
+};
+
+/*!
+ * \brief Managed reference to MeasureCandidateNode.
+ * \sa MeasureCandidateNode
+ */
+class MeasureCandidate : public runtime::ObjectRef {
+ public:
+  /*!
+   * \brief Constructor of MeasureCandidate.
+   * \param sch The schedule for profiling.
+   * \param args_info The argument information.
+   */
+  TVM_DLL MeasureCandidate(tir::Schedule sch, Array<ArgInfo> args_info);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode);
+};
+
+/*! \brief The search strategy for measure candidates generation. */
+class SearchStrategyNode : public runtime::Object {
+ public:
+  /*! \brief Virtual destructor */
+  virtual ~SearchStrategyNode() = default;
+
+  /*!
+   * \brief Initialize the search strategy with tuning context.
+   * \param tune_context The tuning context for initialization.
+   */
+  virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0;

Review comment:
       Is this function called only once? If so, can you document that? If not, what is the search strategy supposed to do when reinitialized? Should it drop all internal state or should it start creating candidates that are from both contexts?

##########
File path: include/tvm/meta_schedule/search_strategy.h
##########
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
+#define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
+
+#include <tvm/tir/schedule/schedule.h>
+
+#include "./arg_info.h"
+#include "./runner.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+// Forward declaration
+class TuneContext;
+
+/*! \brief The measure candidate class. */

Review comment:
       ```suggestion
   /*! \brief A schedule (with input shapes) to be measured. */
   ```

##########
File path: include/tvm/meta_schedule/runner.h
##########
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_RUNNER_H_
+#define TVM_META_SCHEDULE_RUNNER_H_
+
+#include <tvm/ir/expr.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+/*! \brief The runner's result. */
+class RunnerResultNode : public runtime::Object {
+ public:
+  /*! \brief The run time in seconds.*/
+  Optional<Array<FloatImm>> run_secs;
+  /*! \brief The error message, if any. */

Review comment:
       If an error message exists, then is `run_secs` empty? And vice versa? If so, could you document it? Also, maybe it would be helpful to provide a convenience method to check if the result is an error or not?
   
   If we want to go all the way, we could have the `RunnerResult` be an abstract base class and have `ErrorResult` and `SuccessResult` subclasses. But that might be a bit of overkill (it would map well to Python though).

##########
File path: include/tvm/meta_schedule/runner.h
##########
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_RUNNER_H_
+#define TVM_META_SCHEDULE_RUNNER_H_
+
+#include <tvm/ir/expr.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+/*! \brief The runner's result. */

Review comment:
       ```suggestion
   /*! \brief Output of benchmarking a MeasureCandidate with a Runner. */
   ```

##########
File path: src/meta_schedule/search_strategy/replay_trace.cc
##########
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*! \brief A search strategy that replays the trace. */
+class ReplayTraceNode : public SearchStrategyNode {
+ public:
+  using TRandState = support::LinearCongruentialEngine::TRandState;
+
+  /*! \brief The state of the search strategy. */
+  struct State {
+    /*! \brief The search strategy itself */
+    ReplayTraceNode* self;
+    /*! \brief The design spaces. */
+    Array<tir::Schedule> design_spaces;
+    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
+    int st;
+    /*! \brief `[st, ed)` are the indices of the next batch of candidates. */
+    int ed;
+
+    explicit State(ReplayTraceNode* self, Array<tir::Schedule> design_spaces)
+        : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {}
+
+    inline Optional<Array<MeasureCandidate>> GenerateMeasureCandidates();
+    inline void NotifyRunnerResults(const Array<RunnerResult>& results);
+  };
+
+  /*! \brief The number of trials per iteration. */
+  int num_trials_per_iter;
+  /*! \brief The number of total trials. */
+  int num_trials_total;
+
+  /*! \brief The module to be tuned. */
+  IRModule mod_{nullptr};
+  /*! \brief The metadata of the function arguments. */
+  Array<ArgInfo> args_info_{nullptr};
+  /*! \brief The number of threads to use. */
+  int num_threads_ = -1;
+  /*! \brief The random state */
+  TRandState rand_state_ = -1;
+  /*! \brief The state of the search strategy. */
+  std::unique_ptr<State> state_ = nullptr;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("num_trials_per_iter", &num_trials_per_iter);
+    v->Visit("num_trials_total", &num_trials_total);
+    // `mod_` is not visited
+    // `args_info_` is not visited
+    // `num_threads_` is not visited
+    // `rand_state_` is not visited
+    // `state_` is not visited
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.ReplayTrace";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode);
+
+ public:

Review comment:
       nit: duplicate `public` (line 26 above)?

##########
File path: include/tvm/meta_schedule/search_strategy.h
##########
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
+#define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_
+
+#include <tvm/tir/schedule/schedule.h>
+
+#include "./arg_info.h"
+#include "./runner.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+// Forward declaration
+class TuneContext;
+
+/*! \brief The measure candidate class. */
+class MeasureCandidateNode : public runtime::Object {
+ public:
+  /*! \brief The schedule for profiling. */
+  tir::Schedule sch;
+  /*! \brief The argument information. */
+  Array<ArgInfo> args_info;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("sch", &sch);
+    v->Visit("args_info", &args_info);
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.MeasureCandidate";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object);
+};
+
+/*!
+ * \brief Managed reference to MeasureCandidateNode.
+ * \sa MeasureCandidateNode
+ */
+class MeasureCandidate : public runtime::ObjectRef {
+ public:
+  /*!
+   * \brief Constructor of MeasureCandidate.
+   * \param sch The schedule for profiling.
+   * \param args_info The argument information.
+   */
+  TVM_DLL MeasureCandidate(tir::Schedule sch, Array<ArgInfo> args_info);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode);
+};
+
+/*! \brief The search strategy for measure candidates generation. */

Review comment:
       Given this is an interface, I think it would be helpful to have some documentation on how it is used during tuning. For example, in what order are the member functions called? You could provide high-level pseudocode of how the tuning loop is run with regard to the search strategy.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org