You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by cs...@apache.org on 2022/12/06 07:51:18 UTC
[tvm] branch main updated: [MetaSchedule][Hexagon] Add postproc for verifying VTCM usage (#13538)
This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6574e16034 [MetaSchedule][Hexagon] Add postproc for verifying VTCM usage (#13538)
6574e16034 is described below
commit 6574e1603452f6865949647bc8e3bed4dca5e55e
Author: masahi <ma...@gmail.com>
AuthorDate: Tue Dec 6 16:51:12 2022 +0900
[MetaSchedule][Hexagon] Add postproc for verifying VTCM usage (#13538)
* add new postproc VerifyVTCMLimit
* remove pass
* add test
* add doc, missing file
* Add back VectorizeLoop in prereq lowering pass
* fix lint
---
include/tvm/meta_schedule/postproc.h | 5 +
include/tvm/tir/analysis.h | 8 ++
python/tvm/meta_schedule/postproc/__init__.py | 1 +
.../postproc/{__init__.py => verify_vtcm_limit.py} | 25 ++--
src/meta_schedule/postproc/postproc.cc | 7 +-
src/meta_schedule/postproc/verify_vtcm_limit.cc | 104 +++++++++++++++++
src/tir/analysis/calculate_allocated_memory.cc | 9 ++
...est_meta_schedule_postproc_verify_vtcm_limit.py | 127 +++++++++++++++++++++
8 files changed, 272 insertions(+), 14 deletions(-)
diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h
index 13fe470587..76f8d71ad6 100644
--- a/include/tvm/meta_schedule/postproc.h
+++ b/include/tvm/meta_schedule/postproc.h
@@ -144,6 +144,11 @@ class Postproc : public runtime::ObjectRef {
* \return The postprocessor created
*/
TVM_DLL static Postproc VerifyGPUCode();
+ /*!
+ * \brief Creates a postprocessor that checks VTCM usage against the target's "vtcm-capacity".
+ * \return The postprocessor created
+ */
+ TVM_DLL static Postproc VerifyVTCMLimit();
/*!
* \brief Creates a postprocessor that rewrites the layout of input tensor
* \note Weight layout rewrite is supported so far, activation layout rewrite will be added.
diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index cb31a7e5ee..a8edc2675f 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -169,6 +169,14 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
*/
TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);
+/*!
+ * \brief Verifies that the VTCM usage of the given prim_func is within the provided limit.
+ * \param func The function to be checked.
+ * \param limit The VTCM limit in bytes. A value <= 0 disables the check.
+ * \return true if the VTCM usage is within the provided limit, or if the check is disabled.
+ */
+TVM_DLL bool VerifyVTCMLimit(const PrimFunc& func, Integer limit);
+
/*!
* \brief Auto detect the block access region according to its body stmt
* It will detect the access region as an array in order of appearance in AST
diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py
index f70b740d7b..0598a53e2a 100644
--- a/python/tvm/meta_schedule/postproc/__init__.py
+++ b/python/tvm/meta_schedule/postproc/__init__.py
@@ -24,3 +24,4 @@ from .rewrite_reduction_block import RewriteReductionBlock
from .rewrite_tensorize import RewriteTensorize
from .rewrite_unbound_block import RewriteUnboundBlock
from .verify_gpu_code import VerifyGPUCode
+from .verify_vtcm_limit import VerifyVTCMLimit
diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py
similarity index 59%
copy from python/tvm/meta_schedule/postproc/__init__.py
copy to python/tvm/meta_schedule/postproc/verify_vtcm_limit.py
index f70b740d7b..28d202d5b3 100644
--- a/python/tvm/meta_schedule/postproc/__init__.py
+++ b/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py
@@ -14,13 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""The tvm.meta_schedule.postproc package."""
-from .disallow_dynamic_loop import DisallowDynamicLoop
-from .postproc import Postproc, PyPostproc
-from .rewrite_cooperative_fetch import RewriteCooperativeFetch
-from .rewrite_layout import RewriteLayout
-from .rewrite_parallel_vectorize_unroll import RewriteParallelVectorizeUnroll
-from .rewrite_reduction_block import RewriteReductionBlock
-from .rewrite_tensorize import RewriteTensorize
-from .rewrite_unbound_block import RewriteUnboundBlock
-from .verify_gpu_code import VerifyGPUCode
+"""A postprocessor that verifies the VTCM usage of a given schedule."""
+
+from tvm._ffi.registry import register_object
+from .. import _ffi_api
+from .postproc import Postproc
+
+
+@register_object("meta_schedule.VerifyVTCMLimit")
+class VerifyVTCMLimit(Postproc):
+ """Postprocessor that checks a schedule's VTCM usage against the target's vtcm-capacity."""
+
+ def __init__(self) -> None:  # Takes no arguments; the node is built by the FFI constructor.
+ self.__init_handle_by_constructor__(
+ _ffi_api.PostprocVerifyVTCMLimit, # type: ignore # pylint: disable=no-member
+ )
diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc
index c614f3230d..dba523d094 100644
--- a/src/meta_schedule/postproc/postproc.cc
+++ b/src/meta_schedule/postproc/postproc.cc
@@ -94,10 +94,9 @@ Array<Postproc> Postproc::DefaultCUDATensorCore() {
Array<Postproc> Postproc::DefaultHexagon() {
return Array<Postproc>{
- Postproc::DisallowDynamicLoop(),
- Postproc::RewriteParallelVectorizeUnroll(),
- Postproc::RewriteReductionBlock(),
- Postproc::RewriteLayout(),
+ Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(),
+ Postproc::RewriteReductionBlock(), Postproc::RewriteLayout(),
+ Postproc::VerifyVTCMLimit(),
};
}
diff --git a/src/meta_schedule/postproc/verify_vtcm_limit.cc b/src/meta_schedule/postproc/verify_vtcm_limit.cc
new file mode 100644
index 0000000000..a6b577de9a
--- /dev/null
+++ b/src/meta_schedule/postproc/verify_vtcm_limit.cc
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/tir/transform.h>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+class VerifyVTCMLimitNode : public PostprocNode {  // Rejects schedules that use too much VTCM.
+ public:
+ Integer vtcm_capacity;  // Limit read from the target's "vtcm-capacity" attribute.
+
+ void InitializeWithTuneContext(const TuneContext& context) final {
+ ICHECK(context->target.defined());
+ Target target = context->target.value();
+ ICHECK(target->kind->name == "hexagon");  // Only Hexagon targets are accepted.
+ // A capacity of 0 (also the fallback when the attribute is absent) disables verification.
+ vtcm_capacity = target->GetAttr<Integer>("vtcm-capacity").value_or(0);
+ }
+
+ bool Verify(const IRModule& mod) const {  // True iff every PrimFunc in mod passes the check.
+ for (const auto& kv : mod->functions) {
+ if (const auto* prim_func = kv.second.as<tir::PrimFuncNode>()) {  // Skip non-PrimFuncs.
+ if (!tir::VerifyVTCMLimit(GetRef<tir::PrimFunc>(prim_func), vtcm_capacity)) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+ bool Apply(const tir::Schedule& sch) final {  // Lowers each PrimFunc, then verifies it.
+ IRModule mod = sch->mod();
+ for (const auto& kv : mod->functions) {
+ const GlobalVar& g_var = kv.first;
+ const BaseFunc& base_func = kv.second;
+ if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
+ IRModule lowered{nullptr};
+ try {  // A dmlc::Error thrown while lowering is treated as a rejected schedule.
+ auto pass_list = Array<tvm::transform::Pass>();  // Prerequisite lowering passes.
+ pass_list.push_back(tir::transform::LowerInitBlock());
+ pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+ pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+ pass_list.push_back(tir::transform::CompactBufferAllocation());
+ pass_list.push_back(tir::transform::LowerMatchBuffer());
+ pass_list.push_back(tir::transform::InjectSoftwarePipeline());
+ pass_list.push_back(tir::transform::LowerOpaqueBlock());
+ pass_list.push_back(tir::transform::FlattenBuffer());
+ pass_list.push_back(tir::transform::Simplify());
+ pass_list.push_back(tir::transform::VectorizeLoop(true));
+ pass_list.push_back(tir::transform::StorageRewrite());
+ transform::PassContext pass_ctx = transform::PassContext::Current();
+ tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
+ runtime::String(g_var->name_hint));
+ IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));  // Shadows outer mod.
+ lowered = tvm::transform::Sequential(pass_list)(std::move(mod));
+ } catch (const dmlc::Error& e) {
+ return false;
+ }
+ if (!Verify(lowered)) {  // Stop at the first function that fails the check.
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+ Postproc Clone() const {  // Copy-constructs a new node from this one.
+ ObjectPtr<VerifyVTCMLimitNode> n = make_object<VerifyVTCMLimitNode>(*this);
+ return Postproc(n);
+ }
+
+ static constexpr const char* _type_key = "meta_schedule.VerifyVTCMLimit";
+ TVM_DECLARE_FINAL_OBJECT_INFO(VerifyVTCMLimitNode, PostprocNode);
+};
+
+Postproc Postproc::VerifyVTCMLimit() {  // vtcm_capacity is set by InitializeWithTuneContext.
+ ObjectPtr<VerifyVTCMLimitNode> n = make_object<VerifyVTCMLimitNode>();
+ return Postproc(n);
+}
+
+TVM_REGISTER_NODE_TYPE(VerifyVTCMLimitNode);
+TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyVTCMLimit")
+ .set_body_typed(Postproc::VerifyVTCMLimit);  // Factory referenced by the Python class.
+
+} // namespace meta_schedule
+} // namespace tvm
diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc
index 01457508ab..9da8ec4355 100644
--- a/src/tir/analysis/calculate_allocated_memory.cc
+++ b/src/tir/analysis/calculate_allocated_memory.cc
@@ -87,6 +87,15 @@ TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes").set_body_typed([](
return CalculateAllocatedBytes(func);
});
+bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {  // False only when over the limit.
+ auto sizes = CalculateAllocatedBytes(func);
+ const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);  // 0 if no such entry.
+ if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {  // limit <= 0: off.
+ return false;
+ }
+ return true;
+}
+
namespace transform {
Pass VerifyVTCMLimit(const Integer& limit) {
diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_vtcm_limit.py b/tests/python/unittest/test_meta_schedule_postproc_verify_vtcm_limit.py
new file mode 100644
index 0000000000..55ea0a6ed8
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_postproc_verify_vtcm_limit.py
@@ -0,0 +1,127 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+import tvm
+import tvm.testing
+from tvm import meta_schedule as ms
+from tvm import tir
+from tvm.script import tir as T
+
+
+def _create_context(mod, target) -> ms.TuneContext:  # TuneContext with only VerifyVTCMLimit.
+ return ms.TuneContext(
+ mod=mod,
+ target=target,
+ space_generator=ms.space_generator.PostOrderApply(
+ sch_rules=[],  # No schedule rules: the test applies the postproc directly.
+ postprocs=[ms.postproc.VerifyVTCMLimit()],
+ mutator_probs={},
+ ),
+ task_name="test",
+ )
+
+
+# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant
+# fmt: off
+
+
+@tvm.script.ir_module
+class Conv2dNCHWcVTCM:  # Fixture: NCHWc conv2d whose inputs are staged in "global.vtcm" buffers.
+ @T.prim_func
+ def main(p0: T.Buffer[(T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32)), "uint8"], p1: T.Buffer[(T.int64(2), T.int64(2), T.int64(3), T.int64(3), T.int64(8), T.int64(32), T.int64(4)), "uint8"], conv2d_NCHWc_int8: T.Buffer[(T.int64(1), T.int64(2), T.int64(54), T.int64(54), T.int64(32)), "int32"]):
+ T.func_attr({"tir.noalias": True, "global_symbol": "main"})
+ p0_global_vtcm = T.alloc_buffer([T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32)], dtype="uint8", scope="global.vtcm")  # VTCM copy of p0.
+ p1_global_vtcm = T.alloc_buffer([T.int64(2), T.int64(2), T.int64(3), T.int64(3), T.int64(8), T.int64(32), T.int64(4)], dtype="uint8", scope="global.vtcm")  # VTCM copy of p1.
+ for n_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}):
+ for oc_chunk_0, oh_0, ow_0, oc_block_0_0 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(1)):
+ for oc_chunk_1_init, oh_1_init, ow_1_init, oc_chunk_2_init, oh_2_init, ow_2_init in T.grid(T.int64(1), T.int64(27), T.int64(3), T.int64(1), T.int64(1), T.int64(9)):
+ with T.block("conv2d_NCHWc_int8_o_init"):
+ v_n = T.axis.spatial(T.int64(1), T.int64(0))
+ v_oc_chunk = T.axis.spatial(T.int64(2), oc_chunk_1_init + oc_chunk_2_init + oc_chunk_0)
+ v_oh = T.axis.spatial(T.int64(54), oh_2_init + oh_0 * T.int64(27) + oh_1_init)
+ v_ow = T.axis.spatial(T.int64(54), ow_0 * T.int64(27) + ow_1_init * T.int64(9) + ow_2_init)
+ v_oc_block_o = T.axis.spatial(T.int64(1), T.int64(0))
+ T.reads()
+ T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)])
+ for oc_block_1 in T.vectorized(T.int64(32)):
+ with T.block("conv2d_NCHWc_int8_init"):
+ v_oc_block_i_init = T.axis.spatial(T.int64(32), oc_block_1)
+ T.reads()
+ T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init])
+ conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init] = 0
+ for kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused in T.serial(T.int64(2), annotations={"software_pipeline_async_stages":[0], "software_pipeline_order":[0, 1, 2], "software_pipeline_stage":[0, 0, 1]}):
+ for ax0_ax1_ax2_ax3_ax4_fused in T.serial(T.int64(26912)):
+ with T.block("p0_global.vtcm"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial(T.int64(2), ax0_ax1_ax2_ax3_ax4_fused // T.int64(13456))
+ v2 = T.axis.spatial(T.int64(56), oh_0 * T.int64(27) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(13456) // T.int64(464))
+ v3 = T.axis.spatial(T.int64(56), ow_0 * T.int64(27) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(464) // T.int64(16))
+ v4 = T.axis.spatial(T.int64(32), kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused * T.int64(16) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(16))
+ T.reads(p0[v0, v1, v2, v3, v4])
+ T.writes(p0_global_vtcm[v0, v1, v2, v3, v4])
+ p0_global_vtcm[v0, v1, v2, v3, v4] = p0[v0, v1, v2, v3, v4]
+ for ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused in T.serial(T.int64(9216)):
+ with T.block("p1_global.vtcm"):
+ v0 = T.axis.spatial(T.int64(2), oc_chunk_0)
+ v1 = T.axis.spatial(T.int64(2), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused // T.int64(4608))
+ v2 = T.axis.spatial(T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(4608) // T.int64(1536))
+ v3 = T.axis.spatial(T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(1536) // T.int64(512))
+ v4 = T.axis.spatial(T.int64(8), kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused * T.int64(4) + ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(512) // T.int64(128))
+ v5 = T.axis.spatial(T.int64(32), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(128) // T.int64(4))
+ v6 = T.axis.spatial(T.int64(4), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(4))
+ T.reads(p1[v0, v1, v2, v3, v4, v5, v6])
+ T.writes(p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6])
+ p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6] = p1[v0, v1, v2, v3, v4, v5, v6]
+ for n_1, oc_chunk_1, oh_1, ow_1, oc_block_0_1, kh_1, kw_1, ic_outer_1, ic_f_inner_1, ic_s_inner_0_1, n_2, oc_chunk_2, oh_2, ow_2, oc_block_0_2 in T.grid(T.int64(1), T.int64(1), T.int64(27), T.int64(3), T.int64(1), T.int64(3), T.int64(3), T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(9), T.int64(1)):
+ with T.block("conv2d_NCHWc_int8_o_update"):
+ v_n = T.axis.spatial(T.int64(1), T.int64(0))
+ v_oc_chunk = T.axis.spatial(T.int64(2), oc_chunk_1 + oc_chunk_2 + oc_chunk_0)
+ v_oh = T.axis.spatial(T.int64(54), oh_2 + oh_0 * T.int64(27) + oh_1)
+ v_ow = T.axis.spatial(T.int64(54), ow_0 * T.int64(27) + ow_1 * T.int64(9) + ow_2)
+ v_oc_block_o = T.axis.spatial(T.int64(1), T.int64(0))
+ v_kh, v_kw, v_ic_outer = T.axis.remap("RRR", [kh_1, kw_1, ic_outer_1])
+ v_ic_f_inner = T.axis.reduce(T.int64(8), kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused * T.int64(4) + ic_f_inner_1)
+ v_ic_s_inner_o = T.axis.reduce(T.int64(1), T.int64(0))
+ T.reads(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)], p0_global_vtcm[v_n, v_ic_outer, v_oh + v_kh, v_ow + v_kw, v_ic_f_inner * T.int64(4) : v_ic_f_inner * T.int64(4) + T.int64(4)], p1_global_vtcm[v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, T.int64(0) : T.int64(32), T.int64(0) : T.int64(4)])
+ T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)])
+ for oc_block_1, ic_s_inner_1 in T.grid(T.int64(32), T.int64(4)):
+ with T.block("conv2d_NCHWc_int8"):
+ v_oc_block_i, v_ic_s_inner_i = T.axis.remap("SR", [oc_block_1, ic_s_inner_1])
+ T.reads(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i], p0_global_vtcm[v_n, v_ic_outer, v_oh + v_kh, v_ow + v_kw, v_ic_f_inner * T.int64(4) + v_ic_s_inner_i], p1_global_vtcm[v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, v_oc_block_i, v_ic_s_inner_i])
+ T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i])
+ T.block_attr({"meta_schedule.tiling_structure":"SRSRS"})
+ conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i] = conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i] + T.Cast("int32", p0_global_vtcm[v_n, v_ic_outer, v_oh + v_kh, v_ow + v_kw, v_ic_f_inner * T.int64(4) + v_ic_s_inner_i]) * T.Cast("int32", p1_global_vtcm[v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, v_oc_block_i, v_ic_s_inner_i])
+
+# fmt: on
+
+
+def test_conv2d_vtcm():  # The postproc must reject a 70000 capacity and accept 75000.
+ def get_target(vtcm_cap):  # Hexagon v68 target, also used as its own host.
+ target = tvm.target.hexagon("v68", vtcm_capacity=vtcm_cap)
+ return tvm.target.Target(target, host=target)
+
+ sch = tir.Schedule(Conv2dNCHWcVTCM, debug_mask="all")
+
+ ctx = _create_context(Conv2dNCHWcVTCM, target=get_target(70000))
+ assert not ctx.space_generator.postprocs[0].apply(sch)  # Expected rejection: capacity too small.
+
+ ctx = _create_context(Conv2dNCHWcVTCM, target=get_target(75000))
+ assert ctx.space_generator.postprocs[0].apply(sch)  # Expected acceptance: capacity is enough.
+
+
+if __name__ == "__main__":
+ tvm.testing.main()