You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/12/17 19:20:46 UTC
[tvm] branch main updated: [MetaSchedule] Random Feature Extractor (#9760)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e1255c9 [MetaSchedule] Random Feature Extractor (#9760)
e1255c9 is described below
commit e1255c9c3963acf1aeca95e15c6b0d934ec041ac
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Fri Dec 17 11:20:26 2021 -0800
[MetaSchedule] Random Feature Extractor (#9760)
Co-authored-by: Junru Shao <ju...@gmail.com>
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
Co-authored-by: Ruihang Lai <la...@qq.com>
Co-authored-by: Hongyi Jin <32...@qq.com>
Co-authored-by: Wuwei Lin <wu...@apache.org>
Co-authored-by: Siyuan Feng <Hz...@sjtu.edu.cn>
---
include/tvm/meta_schedule/cost_model.h | 182 +++++++++++++++++++++
include/tvm/meta_schedule/feature_extractor.h | 121 ++++++++++++++
.../{search_strategy => cost_model}/__init__.py | 9 +-
python/tvm/meta_schedule/cost_model/cost_model.py | 149 +++++++++++++++++
.../__init__.py => cost_model/metric.py} | 29 +++-
.../tvm/meta_schedule/cost_model/random_model.py | 123 ++++++++++++++
.../__init__.py | 11 +-
.../feature_extractor/feature_extractor.py | 81 +++++++++
.../feature_extractor/random_feature_extractor.py | 62 +++++++
.../tvm/meta_schedule/search_strategy/__init__.py | 3 +-
python/tvm/meta_schedule/utils.py | 17 +-
src/meta_schedule/cost_model/cost_model.cc | 65 ++++++++
.../feature_extractor/feature_extractor.cc | 51 ++++++
src/meta_schedule/utils.h | 2 +
.../unittest/test_meta_schedule_cost_model.py | 143 ++++++++++++++++
.../test_meta_schedule_feature_extractor.py | 58 +++++++
16 files changed, 1084 insertions(+), 22 deletions(-)
diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h
new file mode 100644
index 0000000..b05dc3c
--- /dev/null
+++ b/include/tvm/meta_schedule/cost_model.h
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_META_SCHEDULE_COST_MODEL_H_
+#define TVM_META_SCHEDULE_COST_MODEL_H_
+
+#include <tvm/meta_schedule/search_strategy.h>
+
+#include <vector>
+
+namespace tvm {
+namespace meta_schedule {
+
+class TuneContext;
+
+/*!
+ * \brief Cost model: assigns each measure candidate a normalized score (the larger the better).
+ *
+ * Subclasses implement persistence (Load/Save), refinement from measured results (Update) and
+ * scoring (Predict).
+ */
+class CostModelNode : public runtime::Object {
+ public:
+  /*! \brief Virtual destructor. */
+  virtual ~CostModelNode() = default;
+
+  /*! \brief The base cost model exposes no reflected attributes. */
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  /*!
+   * \brief Load the cost model from given file location.
+   * \param path The file path.
+   */
+  virtual void Load(const String& path) = 0;
+
+  /*!
+   * \brief Save the cost model to given file location.
+   * \param path The file path.
+   */
+  virtual void Save(const String& path) = 0;
+
+  /*!
+   * \brief Update the cost model given running results.
+   * \param tune_context The tuning context.
+   * \param candidates The measure candidates.
+   * \param results The running results of the measure candidates; presumably aligned
+   *        index-by-index with \p candidates (TODO confirm against callers).
+   */
+  virtual void Update(const TuneContext& tune_context, const Array<MeasureCandidate>& candidates,
+                      const Array<RunnerResult>& results) = 0;
+
+  /*!
+   * \brief Predict the normalized score (the larger the better) of given measure candidates.
+   * \param tune_context The tuning context.
+   * \param candidates The measure candidates.
+   * \return The predicted normalized scores; presumably one entry per candidate, in the same
+   *         order (TODO confirm for every implementation).
+   */
+  virtual std::vector<double> Predict(const TuneContext& tune_context,
+                                      const Array<MeasureCandidate>& candidates) = 0;
+
+  static constexpr const char* _type_key = "meta_schedule.CostModel";
+  TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object);
+};
+
+/*! \brief The cost model with customized methods on the python-side. */
+class PyCostModelNode : public CostModelNode {
+ public:
+  /*!
+   * \brief Load the cost model from given file location.
+   * \param path The file path.
+   */
+  using FLoad = runtime::TypedPackedFunc<void(String)>;
+  /*!
+   * \brief Save the cost model to given file location.
+   * \param path The file path.
+   */
+  using FSave = runtime::TypedPackedFunc<void(String)>;
+  /*!
+   * \brief Update the cost model given running results.
+   * \param tune_context The tuning context.
+   * \param candidates The measure candidates.
+   * \param results The running results of the measure candidates.
+   */
+  using FUpdate = runtime::TypedPackedFunc<void(const TuneContext&, const Array<MeasureCandidate>&,
+                                                const Array<RunnerResult>&)>;
+  /*!
+   * \brief Predict the normalized scores (the larger the better) of given measure candidates.
+   * \param tune_context The tuning context.
+   * \param candidates The measure candidates.
+   * \param p_addr The address of a caller-owned buffer of `candidates.size()` doubles into
+   *        which the predicted scores are written.
+   */
+  using FPredict = runtime::TypedPackedFunc<void(const TuneContext&, const Array<MeasureCandidate>&,
+                                                 void* p_addr)>;
+  /*!
+   * \brief Get the cost model as string with name.
+   * \return The string representation of the cost model.
+   */
+  using FAsString = runtime::TypedPackedFunc<String()>;
+
+  /*! \brief The packed function to the `Load` function. */
+  FLoad f_load;
+  /*! \brief The packed function to the `Save` function. */
+  FSave f_save;
+  /*! \brief The packed function to the `Update` function. */
+  FUpdate f_update;
+  /*! \brief The packed function to the `Predict` function. */
+  FPredict f_predict;
+  /*! \brief The packed function to the `AsString` function. */
+  FAsString f_as_string;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `f_load` is not visited
+    // `f_save` is not visited
+    // `f_update` is not visited
+    // `f_predict` is not visited
+    // `f_as_string` is not visited
+  }
+
+  void Load(const String& path) final {
+    ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!";
+    f_load(path);
+  }
+
+  void Save(const String& path) final {
+    ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!";
+    f_save(path);
+  }
+
+  void Update(const TuneContext& tune_context, const Array<MeasureCandidate>& candidates,
+              const Array<RunnerResult>& results) final {
+    ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!";
+    f_update(tune_context, candidates, results);
+  }
+
+  std::vector<double> Predict(const TuneContext& tune_context,
+                              const Array<MeasureCandidate>& candidates) final {
+    ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!";
+    std::vector<double> result(candidates.size(), 0.0);
+    // With no candidates there is nothing to score. Skip the callback: an empty std::vector may
+    // report a null data() pointer, which the callback cannot safely wrap or write through.
+    if (result.empty()) {
+      return result;
+    }
+    f_predict(tune_context, candidates, result.data());
+    return result;
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.PyCostModel";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PyCostModelNode, CostModelNode);
+};
+
+/*!
+ * \brief Managed reference to CostModelNode
+ * \sa CostModelNode
+ */
+class CostModel : public runtime::ObjectRef {
+ public:
+  /*!
+   * \brief Create a cost model with customized methods on the python-side.
+   * \param f_load The packed function of `Load`.
+   * \param f_save The packed function of `Save`.
+   * \param f_update The packed function of `Update`.
+   * \param f_predict The packed function of `Predict`.
+   * \param f_as_string The packed function of `AsString`.
+   * \return The cost model created.
+   */
+  TVM_DLL static CostModel PyCostModel(PyCostModelNode::FLoad f_load,        //
+                                       PyCostModelNode::FSave f_save,        //
+                                       PyCostModelNode::FUpdate f_update,    //
+                                       PyCostModelNode::FPredict f_predict,  //
+                                       PyCostModelNode::FAsString f_as_string);
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CostModel, ObjectRef, CostModelNode);
+};
+
+} // namespace meta_schedule
+} // namespace tvm
+
+#endif // TVM_META_SCHEDULE_COST_MODEL_H_
diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h
new file mode 100644
index 0000000..ee5d94c
--- /dev/null
+++ b/include/tvm/meta_schedule/feature_extractor.h
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_
+#define TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_
+
+#include <tvm/meta_schedule/search_strategy.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+class TuneContext;
+
+/*!
+ * \brief Extractor for features from measure candidates for use in cost model.
+ */
+class FeatureExtractorNode : public runtime::Object {
+ public:
+  /*! \brief Virtual destructor. */
+  virtual ~FeatureExtractorNode() = default;
+
+  /*! \brief The base feature extractor exposes no reflected attributes. */
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  /*!
+   * \brief Extract features from the given measure candidates.
+   * \param tune_context The tuning context for feature extraction.
+   * \param candidates The measure candidates to extract features from.
+   * \return The extracted feature NDArrays; presumably one per candidate, in order
+   *         (TODO confirm the contract for every implementation).
+   */
+  virtual Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& tune_context,
+                                                   const Array<MeasureCandidate>& candidates) = 0;
+
+  static constexpr const char* _type_key = "meta_schedule.FeatureExtractor";
+  TVM_DECLARE_BASE_OBJECT_INFO(FeatureExtractorNode, Object);
+};
+
+/*! \brief The feature extractor with customized methods on the python-side. */
+class PyFeatureExtractorNode : public FeatureExtractorNode {
+ public:
+  /*!
+   * \brief Extract features from the given measure candidates.
+   * \param tune_context The tuning context for feature extraction.
+   * \param candidates The measure candidates to extract features from.
+   * \return The extracted feature NDArrays.
+   */
+  using FExtractFrom = runtime::TypedPackedFunc<Array<tvm::runtime::NDArray>(
+      const TuneContext& tune_context, const Array<MeasureCandidate>& candidates)>;
+  /*!
+   * \brief Get the feature extractor as string with name.
+   * \return The string of the feature extractor.
+   */
+  using FAsString = runtime::TypedPackedFunc<String()>;
+
+  /*! \brief The packed function to the `ExtractFrom` function. */
+  FExtractFrom f_extract_from;
+  /*! \brief The packed function to the `AsString` function. */
+  FAsString f_as_string;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    // `f_extract_from` is not visited
+    // `f_as_string` is not visited
+  }
+
+  Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& tune_context,
+                                           const Array<MeasureCandidate>& candidates) final {
+    ICHECK(f_extract_from != nullptr) << "PyFeatureExtractor's ExtractFrom method not implemented!";
+    return f_extract_from(tune_context, candidates);
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.PyFeatureExtractor";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PyFeatureExtractorNode, FeatureExtractorNode);
+};
+
+/*!
+ * \brief Managed reference to FeatureExtractorNode
+ * \sa FeatureExtractorNode
+ */
+class FeatureExtractor : public runtime::ObjectRef {
+ public:
+  /*!
+   * \brief Create a feature extractor that extracts features from each BufferStore
+   * \param buffers_per_store The number of buffers in each BufferStore; features are padded or
+   *        truncated to this many buffers if necessary.
+   * \param arith_intensity_curve_num_samples The number of samples used in the arithmetic intensity
+   *        curve.
+   * \param cache_line_bytes The number of bytes in a cache line.
+   * \return The feature extractor created.
+   */
+  TVM_DLL static FeatureExtractor PerStoreFeature(int buffers_per_store = 5,
+                                                  int arith_intensity_curve_num_samples = 10,
+                                                  int cache_line_bytes = 64);
+  /*!
+   * \brief Create a feature extractor with customized methods on the python-side.
+   * \param f_extract_from The packed function of `ExtractFrom`.
+   * \param f_as_string The packed function of `AsString`.
+   * \return The feature extractor created.
+   */
+  TVM_DLL static FeatureExtractor PyFeatureExtractor(
+      PyFeatureExtractorNode::FExtractFrom f_extract_from,
+      PyFeatureExtractorNode::FAsString f_as_string);
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(FeatureExtractor, ObjectRef, FeatureExtractorNode);
+};
+
+} // namespace meta_schedule
+} // namespace tvm
+
+#endif // TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_
diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/cost_model/__init__.py
similarity index 76%
copy from python/tvm/meta_schedule/search_strategy/__init__.py
copy to python/tvm/meta_schedule/cost_model/__init__.py
index 609baa2..3d4a81e 100644
--- a/python/tvm/meta_schedule/search_strategy/__init__.py
+++ b/python/tvm/meta_schedule/cost_model/__init__.py
@@ -15,10 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""
-The tvm.meta_schedule.search_strategy package.
-Meta Schedule search strategy utilizes the design spaces given
-to generate measure candidates.
+The tvm.meta_schedule.cost_model package.
"""
-
-from .search_strategy import SearchStrategy, PySearchStrategy
-from .replay_trace import ReplayTrace
+from .cost_model import CostModel, PyCostModel
+from .random_model import RandomModel
diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py
new file mode 100644
index 0000000..f5bd601
--- /dev/null
+++ b/python/tvm/meta_schedule/cost_model/cost_model.py
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Meta Schedule CostModel."""
+import ctypes
+from typing import List
+
+import numpy as np # type: ignore
+from tvm._ffi import register_object
+from tvm.runtime import Object
+
+from .. import _ffi_api
+from ..runner import RunnerResult
+from ..search_strategy import MeasureCandidate
+from ..tune_context import TuneContext
+from ..utils import _get_hex_address, check_override
+
+
+@register_object("meta_schedule.CostModel")
+class CostModel(Object):
+    """Cost model."""
+
+    def load(self, path: str) -> None:
+        """Load the cost model from given file location.
+
+        Parameters
+        ----------
+        path : str
+            The file path.
+        """
+        _ffi_api.CostModelLoad(self, path)  # type: ignore # pylint: disable=no-member
+
+    def save(self, path: str) -> None:
+        """Save the cost model to given file location.
+
+        Parameters
+        ----------
+        path : str
+            The file path.
+        """
+        _ffi_api.CostModelSave(self, path)  # type: ignore # pylint: disable=no-member
+
+    def update(
+        self,
+        tune_context: TuneContext,
+        candidates: List[MeasureCandidate],
+        results: List[RunnerResult],
+    ) -> None:
+        """Update the cost model given running results.
+
+        Parameters
+        ----------
+        tune_context : TuneContext
+            The tuning context.
+        candidates : List[MeasureCandidate]
+            The measure candidates.
+        results : List[RunnerResult]
+            The running results of the measure candidates.
+        """
+        _ffi_api.CostModelUpdate(self, tune_context, candidates, results)  # type: ignore # pylint: disable=no-member
+
+    def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
+        """Predict the normalized score (the larger the better) of given measure candidates.
+
+        Parameters
+        ----------
+        tune_context : TuneContext
+            The tuning context.
+        candidates : List[MeasureCandidate]
+            The measure candidates.
+
+        Returns
+        -------
+        result : np.ndarray
+            The predicted normalized scores: a float64 array with one entry per candidate.
+        """
+        n = len(candidates)
+        results = np.zeros(shape=(n,), dtype="float64")
+        # The FFI function receives a raw pointer to `results` and is expected to fill it in
+        # place; the array is returned afterwards.
+        _ffi_api.CostModelPredict(  # type: ignore # pylint: disable=no-member
+            self,
+            tune_context,
+            candidates,
+            results.ctypes.data_as(ctypes.c_void_p),
+        )
+        return results
+
+
+@register_object("meta_schedule.PyCostModel")
+class PyCostModel(CostModel):
+    """An abstract CostModel with customized methods on the python-side."""
+
+    def __init__(self):
+        """Constructor."""
+
+        @check_override(self.__class__, CostModel)
+        def f_load(path: str) -> None:
+            self.load(path)
+
+        @check_override(self.__class__, CostModel)
+        def f_save(path: str) -> None:
+            self.save(path)
+
+        @check_override(self.__class__, CostModel)
+        def f_update(
+            tune_context: TuneContext,
+            candidates: List[MeasureCandidate],
+            results: List[RunnerResult],
+        ) -> None:
+            self.update(tune_context, candidates, results)
+
+        @check_override(self.__class__, CostModel)
+        def f_predict(
+            tune_context: TuneContext, candidates: List[MeasureCandidate], return_ptr
+        ) -> None:
+            n = len(candidates)
+            if n == 0:
+                # The caller's buffer is empty and may be a NULL pointer, which
+                # `np.ctypeslib.as_array` refuses to wrap; there is nothing to write.
+                return
+            return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_double))
+            # Zero-copy float64 view over the caller-owned buffer of `n` doubles. Assigning
+            # into it converts the prediction to float64 and raises if its shape does not
+            # broadcast to (n,), so no separate dtype check is needed.
+            array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,))
+            array_wrapper[:] = self.predict(tune_context, candidates)
+
+        def f_as_string() -> str:
+            return str(self)
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.CostModelPyCostModel,  # type: ignore # pylint: disable=no-member
+            f_load,
+            f_save,
+            f_update,
+            f_predict,
+            f_as_string,
+        )
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}({_get_hex_address(self.handle)})"
diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/cost_model/metric.py
similarity index 59%
copy from python/tvm/meta_schedule/search_strategy/__init__.py
copy to python/tvm/meta_schedule/cost_model/metric.py
index 609baa2..efd8dc6 100644
--- a/python/tvm/meta_schedule/search_strategy/__init__.py
+++ b/python/tvm/meta_schedule/cost_model/metric.py
@@ -14,11 +14,26 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""
-The tvm.meta_schedule.search_strategy package.
-Meta Schedule search strategy utilizes the design spaces given
-to generate measure candidates.
-"""
+"""Cost model metrics for meta schedule"""
+import numpy as np # type: ignore
-from .search_strategy import SearchStrategy, PySearchStrategy
-from .replay_trace import ReplayTrace
+
+def max_curve(trial_scores: np.ndarray) -> np.ndarray:
+    """f(n) = max([s[i] for i <= n]), i.e. the running maximum of the trial scores.
+
+    Parameters
+    ----------
+    trial_scores : Union[List[float], np.ndarray]
+        The score of the i-th trial, for each i.
+
+    Returns
+    -------
+    curve : np.ndarray
+        A float64 vector of the same length whose n-th entry is the best score among
+        trials 0..n (inclusive).
+    """
+    ret = np.empty(len(trial_scores))
+    # Start from -inf rather than a finite sentinel so very negative scores are reported
+    # faithfully instead of being clamped to the sentinel.
+    keep = -np.inf
+    for i, score in enumerate(trial_scores):
+        keep = max(keep, score)
+        ret[i] = keep
+    return ret
diff --git a/python/tvm/meta_schedule/cost_model/random_model.py b/python/tvm/meta_schedule/cost_model/random_model.py
new file mode 100644
index 0000000..23238d2
--- /dev/null
+++ b/python/tvm/meta_schedule/cost_model/random_model.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Random cost model
+"""
+from typing import List, Optional, Tuple, Union
+
+import numpy as np # type: ignore
+
+from ..cost_model import PyCostModel
+from ..runner import RunnerResult
+from ..search_strategy import MeasureCandidate
+from ..tune_context import TuneContext
+
+
+class RandomModel(PyCostModel):
+    """Random cost model
+
+    Parameters
+    ----------
+    seed : Optional[int]
+        The seed passed to ``np.random.seed`` when ``path`` is not given.
+    path : Optional[str]
+        The path of a random state previously written by ``save``; when given, the random
+        state is loaded from it and ``seed`` is ignored.
+    max_range : Optional[int]
+        The upper bound of the random results, which lie in [0, max_range). ``predict``
+        multiplies by this value, so it must not be None when predicting.
+
+    Attributes
+    ----------
+    random_state : Union[Tuple[str, np.ndarray, int, int, float], dict]
+        The state of numpy's global random number generator, restored before and saved after
+        each prediction.
+    path : Optional[str]
+        The path the random state was loaded from, if any.
+
+    Reference
+    ---------
+    https://numpy.org/doc/stable/reference/random/generated/numpy.random.get_state.html
+    """
+
+    random_state: Union[Tuple[str, np.ndarray, int, int, float], dict]
+    path: Optional[str]
+    max_range: Optional[int]
+
+    def __init__(
+        self,
+        *,
+        seed: Optional[int] = None,
+        path: Optional[str] = None,
+        max_range: Optional[int] = 100,
+    ):
+        super().__init__()
+        # `path` is declared as an attribute above; record it so the declaration is honoured.
+        self.path = path
+        if path is not None:
+            self.load(path)
+        else:
+            np.random.seed(seed)
+            self.random_state = np.random.get_state()
+        self.max_range = max_range
+
+    def load(self, path: str) -> None:
+        """Load the cost model from given file location.
+
+        Parameters
+        ----------
+        path : str
+            The file path.
+
+        Note
+        ----
+        The file is unpickled (``allow_pickle=True``); only load files from trusted sources.
+        """
+        self.random_state = tuple(np.load(path, allow_pickle=True))  # type: ignore
+
+    def save(self, path: str) -> None:
+        """Save the cost model to given file location.
+
+        Parameters
+        ----------
+        path : str
+            The file path.
+
+        Note
+        ----
+        ``np.save`` appends ``.npy`` to ``path`` when it lacks that suffix, so use a path
+        ending in ``.npy`` to be able to ``load`` it back with the same string.
+        """
+        np.save(path, np.array(self.random_state, dtype=object), allow_pickle=True)
+
+    def update(
+        self,
+        tune_context: TuneContext,
+        candidates: List[MeasureCandidate],
+        results: List[RunnerResult],
+    ) -> None:
+        """Update the cost model given running results. A random model learns nothing.
+
+        Parameters
+        ----------
+        tune_context : TuneContext
+            The tuning context.
+        candidates : List[MeasureCandidate]
+            The measure candidates.
+        results : List[RunnerResult]
+            The running results of the measure candidates.
+        """
+
+    def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray:
+        """Predict random scores for the given measure candidates.
+
+        Parameters
+        ----------
+        tune_context : TuneContext
+            The tuning context.
+        candidates : List[MeasureCandidate]
+            The measure candidates.
+
+        Returns
+        -------
+        result : np.ndarray
+            One uniformly random score in [0, max_range) per candidate.
+        """
+        # Restores this model's state into numpy's *global* generator (overwriting whatever
+        # state other code had there) and saves the advanced state afterwards.
+        np.random.set_state(self.random_state)
+        # TODO(@zxybazh): Use numpy's RandState object:
+        # https://numpy.org/doc/1.16/reference/generated/numpy.random.RandomState.html#numpy.random.RandomState
+        result = np.random.rand(len(candidates)) * self.max_range
+        self.random_state = np.random.get_state()
+        return result
diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/feature_extractor/__init__.py
similarity index 73%
copy from python/tvm/meta_schedule/search_strategy/__init__.py
copy to python/tvm/meta_schedule/feature_extractor/__init__.py
index 609baa2..f29c44b 100644
--- a/python/tvm/meta_schedule/search_strategy/__init__.py
+++ b/python/tvm/meta_schedule/feature_extractor/__init__.py
@@ -15,10 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""
-The tvm.meta_schedule.search_strategy package.
-Meta Schedule search strategy utilizes the design spaces given
-to generate measure candidates.
+The tvm.meta_schedule.feature_extractor package.
+Meta Schedule feature extractors that extracts features from
+measure candidates for use in cost model.
"""
-
-from .search_strategy import SearchStrategy, PySearchStrategy
-from .replay_trace import ReplayTrace
+from .feature_extractor import FeatureExtractor, PyFeatureExtractor
+from .random_feature_extractor import RandomFeatureExtractor
diff --git a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py
new file mode 100644
index 0000000..bd7656e
--- /dev/null
+++ b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Meta Schedule FeatureExtractor."""
+from typing import List
+
+from tvm._ffi import register_object
+from tvm.runtime import Object
+from tvm.runtime.ndarray import NDArray
+
+from .. import _ffi_api
+from ..utils import _get_hex_address, check_override
+from ..tune_context import TuneContext
+from ..search_strategy import MeasureCandidate
+
+
+@register_object("meta_schedule.FeatureExtractor")
+class FeatureExtractor(Object):
+    """Base class of extractors that turn measure candidates into cost-model features."""
+
+    def extract_from(
+        self, tune_context: TuneContext, candidates: List[MeasureCandidate]
+    ) -> List[NDArray]:
+        """Extract features from the given measure candidates via the FFI.
+
+        Parameters
+        ----------
+        tune_context : TuneContext
+            The tuning context for feature extraction.
+        candidates : List[MeasureCandidate]
+            The measure candidates to extract features from.
+
+        Returns
+        -------
+        features : List[NDArray]
+            The extracted features, as TVM NDArrays.
+        """
+        return _ffi_api.FeatureExtractorExtractFrom(  # type: ignore # pylint: disable=no-member
+            self, tune_context, candidates
+        )
+
+
+@register_object("meta_schedule.PyFeatureExtractor")
+class PyFeatureExtractor(FeatureExtractor):
+    """Feature extractor whose methods are implemented by a Python subclass."""
+
+    def __init__(self):
+        """Constructor: hands this object's Python methods to the C++ PyFeatureExtractor."""
+
+        def f_as_string() -> str:
+            return str(self)
+
+        @check_override(self.__class__, FeatureExtractor)
+        def f_extract_from(
+            tune_context: TuneContext, candidates: List[MeasureCandidate]
+        ) -> List[NDArray]:
+            # Forward straight to the subclass-provided implementation.
+            return self.extract_from(tune_context, candidates)
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.FeatureExtractorPyFeatureExtractor,  # type: ignore # pylint: disable=no-member
+            f_extract_from,
+            f_as_string,
+        )
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}({_get_hex_address(self.handle)})"
diff --git a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py
new file mode 100644
index 0000000..7c72a25
--- /dev/null
+++ b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Random Feature Extractor."""
+from typing import List, Union, Tuple
+
+import numpy as np # type: ignore
+from tvm.runtime.ndarray import NDArray, array
+
+from ..tune_context import TuneContext
+from ..search_strategy import MeasureCandidate
+from ..feature_extractor import PyFeatureExtractor
+
+
+class RandomFeatureExtractor(PyFeatureExtractor):
+    """Random Feature Extractor
+
+    Parameters
+    ----------
+    feature_size : int
+        The size of each block's feature vector.
+    max_block_num : int
+        The maximum number of blocks in each schedule; must be at least 1.
+    seed : int
+        The seed passed to ``np.random.seed`` at construction.
+
+    Attributes
+    ----------
+    random_state : Union[Tuple[str, np.ndarray, int, int, float], dict]
+        The current random state of the feature extractor, restored into numpy's global
+        generator before and saved after each extraction.
+    """
+
+    feature_size: int
+    max_block_num: int
+    random_state: Union[Tuple[str, np.ndarray, int, int, float], dict]
+
+    def __init__(self, *, feature_size: int = 30, max_block_num: int = 5, seed: int = 0):
+        super().__init__()
+        assert max_block_num >= 1, "Max block number must be greater or equal to one!"
+        self.max_block_num = max_block_num
+        self.feature_size = feature_size
+        np.random.seed(seed)
+        self.random_state = np.random.get_state()
+
+    def extract_from(
+        self, tune_context: TuneContext, candidates: List[MeasureCandidate]
+    ) -> List[NDArray]:
+        """Generate random features, one array per candidate.
+
+        Each array has shape ``(num_blocks, feature_size)`` with values in [0, 1), where
+        ``num_blocks`` is drawn uniformly from [1, max_block_num]. The candidates' contents
+        are not inspected.
+
+        Parameters
+        ----------
+        tune_context : TuneContext
+            The tuning context (unused).
+        candidates : List[MeasureCandidate]
+            The measure candidates; only their count matters.
+
+        Returns
+        -------
+        features : List[NDArray]
+            The random features, as TVM NDArrays.
+        """
+        # Uses numpy's global generator, so this overwrites the global state seen by other code.
+        np.random.set_state(self.random_state)
+        result = [
+            np.random.rand(np.random.randint(1, self.max_block_num + 1), self.feature_size)
+            for _ in candidates
+        ]
+        self.random_state = np.random.get_state()
+        return [array(x) for x in result]
diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py
index 609baa2..298cdae 100644
--- a/python/tvm/meta_schedule/search_strategy/__init__.py
+++ b/python/tvm/meta_schedule/search_strategy/__init__.py
@@ -19,6 +19,5 @@ The tvm.meta_schedule.search_strategy package.
Meta Schedule search strategy utilizes the design spaces given
to generate measure candidates.
"""
-
-from .search_strategy import SearchStrategy, PySearchStrategy
+from .search_strategy import MeasureCandidate, PySearchStrategy, SearchStrategy
from .replay_trace import ReplayTrace
diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py
index a9ef514..aaaa956 100644
--- a/python/tvm/meta_schedule/utils.py
+++ b/python/tvm/meta_schedule/utils.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Utilities for meta schedule"""
+import ctypes
import json
import os
import shutil
@@ -24,7 +25,7 @@ import psutil # type: ignore
import tvm
from tvm._ffi import get_global_func, register_func
from tvm.error import TVMError
-from tvm.ir import Array, Map, IRModule
+from tvm.ir import Array, IRModule, Map
from tvm.rpc import RPCSession
from tvm.runtime import PackedFunc, String
from tvm.tir import FloatImm, IntImm
@@ -245,3 +246,17 @@ def check_override(
return func
return inner
+
+
+def _get_hex_address(handle: ctypes.c_void_p) -> str:
+ """Get the hexadecimal address of a handle.
+ Parameters
+ ----------
+ handle : ctypes.c_void_p
+ The handle to be converted.
+ Returns
+ -------
+ result : str
+ The hexadecimal address of the handle.
+ """
+ return hex(ctypes.cast(handle, ctypes.c_void_p).value)
diff --git a/src/meta_schedule/cost_model/cost_model.cc b/src/meta_schedule/cost_model/cost_model.cc
new file mode 100644
index 0000000..5cd32b0
--- /dev/null
+++ b/src/meta_schedule/cost_model/cost_model.cc
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+CostModel CostModel::PyCostModel(PyCostModelNode::FLoad f_load,        //
+                                 PyCostModelNode::FSave f_save,        //
+                                 PyCostModelNode::FUpdate f_update,    //
+                                 PyCostModelNode::FPredict f_predict,  //
+                                 PyCostModelNode::FAsString f_as_string) {
+  // Package the supplied callbacks (presumably Python-side implementations) into a new
+  // PyCostModelNode and hand it out wrapped in a CostModel reference.
+  auto node = make_object<PyCostModelNode>();
+  node->f_load = std::move(f_load);
+  node->f_save = std::move(f_save);
+  node->f_update = std::move(f_update);
+  node->f_predict = std::move(f_predict);
+  node->f_as_string = std::move(f_as_string);
+  return CostModel(std::move(node));
+}
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<PyCostModelNode>([](const ObjectRef& ref, ReprPrinter* printer) {
+      // The printed form is exactly what the user-provided AsString callback returns.
+      const PyCostModelNode* node = ref.as<PyCostModelNode>();
+      ICHECK(node);
+      const PyCostModelNode::FAsString& f_as_string = node->f_as_string;
+      ICHECK(f_as_string != nullptr) << "PyCostModel's AsString method not implemented!";
+      printer->stream << f_as_string();
+    });
+
+TVM_REGISTER_OBJECT_TYPE(CostModelNode);
+TVM_REGISTER_NODE_TYPE(PyCostModelNode);
+
+TVM_REGISTER_GLOBAL("meta_schedule.CostModelLoad").set_body_method<CostModel>(&CostModelNode::Load);
+TVM_REGISTER_GLOBAL("meta_schedule.CostModelSave").set_body_method<CostModel>(&CostModelNode::Save);
+TVM_REGISTER_GLOBAL("meta_schedule.CostModelUpdate")
+    .set_body_method<CostModel>(&CostModelNode::Update);
+// `p_addr` is the raw address of a caller-owned buffer of doubles. Every score returned by
+// Predict is copied into it, so it must be large enough for all of them (presumably one per
+// candidate).
+TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict")
+    .set_body_typed([](CostModel model,                     //
+                       const TuneContext& tune_context,     //
+                       Array<MeasureCandidate> candidates,  //
+                       void* p_addr) -> void {
+      ICHECK(p_addr != nullptr) << "CostModelPredict: the output buffer address is null";
+      std::vector<double> result = model->Predict(tune_context, candidates);
+      std::copy(result.begin(), result.end(), static_cast<double*>(p_addr));
+    });
+TVM_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel").set_body_typed(CostModel::PyCostModel);
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/meta_schedule/feature_extractor/feature_extractor.cc b/src/meta_schedule/feature_extractor/feature_extractor.cc
new file mode 100644
index 0000000..84d2249
--- /dev/null
+++ b/src/meta_schedule/feature_extractor/feature_extractor.cc
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+FeatureExtractor FeatureExtractor::PyFeatureExtractor(
+    PyFeatureExtractorNode::FExtractFrom f_extract_from,  //
+    PyFeatureExtractorNode::FAsString f_as_string) {
+  // Wrap the supplied callbacks in a fresh PyFeatureExtractorNode.
+  auto node = make_object<PyFeatureExtractorNode>();
+  node->f_as_string = std::move(f_as_string);
+  node->f_extract_from = std::move(f_extract_from);
+  return FeatureExtractor(std::move(node));
+}
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<PyFeatureExtractorNode>([](const ObjectRef& ref, ReprPrinter* printer) {
+      // The printed form is exactly what the user-provided AsString callback returns.
+      const PyFeatureExtractorNode* node = ref.as<PyFeatureExtractorNode>();
+      ICHECK(node);
+      const PyFeatureExtractorNode::FAsString& f_as_string = node->f_as_string;
+      ICHECK(f_as_string != nullptr) << "PyFeatureExtractor's AsString method not implemented!";
+      printer->stream << f_as_string();
+    });
+
+// Register the abstract base node and its Python-callback-backed subclass with TVM's
+// runtime type registry.
+TVM_REGISTER_OBJECT_TYPE(FeatureExtractorNode);
+TVM_REGISTER_NODE_TYPE(PyFeatureExtractorNode);
+
+// Global packed functions looked up by name (presumably from the Python frontend via the FFI).
+TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorExtractFrom")
+ .set_body_method<FeatureExtractor>(&FeatureExtractorNode::ExtractFrom);
+TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPyFeatureExtractor")
+ .set_body_typed(FeatureExtractor::PyFeatureExtractor);
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 83e65a5..9b0a371 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -22,7 +22,9 @@
#include <dmlc/memory_io.h>
#include <tvm/meta_schedule/arg_info.h>
#include <tvm/meta_schedule/builder.h>
+#include <tvm/meta_schedule/cost_model.h>
#include <tvm/meta_schedule/database.h>
+#include <tvm/meta_schedule/feature_extractor.h>
#include <tvm/meta_schedule/runner.h>
#include <tvm/meta_schedule/search_strategy.h>
#include <tvm/meta_schedule/space_generator.h>
diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py
new file mode 100644
index 0000000..3f98d71
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_cost_model.py
@@ -0,0 +1,143 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import os
+import re
+import shutil
+import sys
+import tempfile
+from typing import List
+
+import numpy as np
+import pytest
+import tvm
+from tvm.meta_schedule.cost_model import PyCostModel, RandomModel
+from tvm.meta_schedule.runner import RunnerResult
+from tvm.meta_schedule.search_strategy import MeasureCandidate
+from tvm.meta_schedule.tune_context import TuneContext
+from tvm.script import tir as T
+from tvm.tir.schedule.schedule import Schedule
+
+# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring
+# Shared workload for the tests below: C = A @ B on 1024x1024 float32 buffers, written in
+# TVMScript as an IRModule whose single PrimFunc is named "main".
+@tvm.script.ir_module
+class Matmul:
+ @T.prim_func
+ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ A = T.match_buffer(a, (1024, 1024), "float32")
+ B = T.match_buffer(b, (1024, 1024), "float32")
+ C = T.match_buffer(c, (1024, 1024), "float32")
+ # i and j are spatial axes; k is the reduction axis ("SSR" in the remap below).
+ for i, j, k in T.grid(1024, 1024, 1024):
+ with T.block("matmul"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ # C[vi, vj] is zeroed before the first accumulation step.
+ with T.init():
+ C[vi, vj] = 0.0
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+
+# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,disable=unused-argument
+
+
+def test_meta_schedule_cost_model():
+ class FancyCostModel(PyCostModel):
+ def load(self, path: str) -> None:
+ pass
+
+ def save(self, path: str) -> None:
+ pass
+
+ def update(
+ self,
+ tune_context: TuneContext,
+ candidates: List[MeasureCandidate],
+ results: List[RunnerResult],
+ ) -> None:
+ pass
+
+ def predict(
+ self, tune_context: TuneContext, candidates: List[MeasureCandidate]
+ ) -> np.ndarray:
+ return np.random.rand(10)
+
+ model = FancyCostModel()
+ model.save("fancy_test_location")
+ model.load("fancy_test_location")
+ model.update(TuneContext(), [], [])
+ results = model.predict(TuneContext, [MeasureCandidate(Schedule(mod=Matmul), [])])
+ assert results.shape == (10,)
+
+
+def test_meta_schedule_cost_model_as_string():
+ class NotSoFancyCostModel(PyCostModel):
+ def load(self, path: str) -> None:
+ pass
+
+ def save(self, path: str) -> None:
+ pass
+
+ def update(
+ self,
+ tune_context: TuneContext,
+ candidates: List[MeasureCandidate],
+ results: List[RunnerResult],
+ ) -> None:
+ pass
+
+ def predict(
+ self, tune_context: TuneContext, candidates: List[MeasureCandidate]
+ ) -> np.ndarray:
+ return np.random.rand(10)
+
+ cost_model = NotSoFancyCostModel()
+ pattern = re.compile(r"NotSoFancyCostModel\(0x[a-f|0-9]*\)")
+ assert pattern.match(str(cost_model))
+
+
+def test_meta_schedule_random_model():
+ model = RandomModel()
+ model.update(TuneContext(), [], [])
+ res = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(10)])
+ assert len(res) == 10
+ assert min(res) >= 0 and max(res) <= model.max_range
+
+
+def test_meta_schedule_random_model_reseed():
+ model = RandomModel(seed=100)
+ res = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(20)])
+ new_model = RandomModel(seed=100)
+ new_res = new_model.predict(
+ TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(20)]
+ )
+ assert (res == new_res).all()
+
+
+def test_meta_schedule_random_model_reload():
+ model = RandomModel(seed=25973)
+ model.predict(
+ TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(30)]
+ ) # change state
+ path = os.path.join(tempfile.mkdtemp(), "test_output_meta_schedule_random_model.npy")
+ model.save(path)
+ res1 = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(70)])
+ model.load(path)
+ res2 = model.predict(TuneContext(), [MeasureCandidate(Schedule(Matmul), []) for i in range(70)])
+ shutil.rmtree(os.path.dirname(path))
+ assert (res1 == res2).all()
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor.py b/tests/python/unittest/test_meta_schedule_feature_extractor.py
new file mode 100644
index 0000000..143d446
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_feature_extractor.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+import re
+from typing import List
+
+import numpy as np
+from tvm.meta_schedule import TuneContext
+from tvm.meta_schedule.feature_extractor import PyFeatureExtractor
+from tvm.meta_schedule.search_strategy import MeasureCandidate
+
+
+def test_meta_schedule_feature_extractor():
+ class FancyFeatureExtractor(PyFeatureExtractor):
+ def extract_from(
+ self,
+ tune_context: TuneContext, # pylint: disable = unused-argument
+ candidates: List[MeasureCandidate], # pylint: disable = unused-argument
+ ) -> List[np.ndarray]:
+ return [np.random.rand(4, 5)]
+
+ extractor = FancyFeatureExtractor()
+ features = extractor.extract_from(TuneContext(), [])
+ assert len(features) == 1
+ assert features[0].shape == (4, 5)
+
+
+def test_meta_schedule_feature_extractor_as_string():
+ class NotSoFancyFeatureExtractor(PyFeatureExtractor):
+ def extract_from(
+ self,
+ tune_context: TuneContext, # pylint: disable = unused-argument
+ candidates: List[MeasureCandidate], # pylint: disable = unused-argument
+ ) -> List[np.ndarray]:
+ return []
+
+ feature_extractor = NotSoFancyFeatureExtractor()
+ pattern = re.compile(r"NotSoFancyFeatureExtractor\(0x[a-f|0-9]*\)")
+ assert pattern.match(str(feature_extractor))
+
+
+if __name__ == "__main__":
+ test_meta_schedule_feature_extractor()
+ test_meta_schedule_feature_extractor_as_string()