You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/04/15 00:03:24 UTC
[incubator-tvm] branch master updated: [RELAY] Remove re-exports of tvm.transform (#5337)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 275e317 [RELAY] Remove re-exports of tvm.transform (#5337)
275e317 is described below
commit 275e317c568a75db8a13960bcb9112f7859ef9aa
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Tue Apr 14 17:03:15 2020 -0700
[RELAY] Remove re-exports of tvm.transform (#5337)
---
docs/api/python/ir.rst | 8 ++
docs/dev/convert_layout.rst | 2 +-
docs/dev/relay_pass_infra.rst | 4 +-
include/tvm/ir/transform.h | 4 +-
python/tvm/ir/json_compact.py | 2 +-
python/tvm/ir/transform.py | 7 +-
python/tvm/relay/__init__.py | 11 ---
python/tvm/relay/backend/interpreter.py | 8 +-
python/tvm/relay/qnn/transform.py | 4 +-
python/tvm/relay/quantize/quantize.py | 33 ++++----
python/tvm/relay/testing/__init__.py | 2 +-
python/tvm/relay/testing/py_converter.py | 4 +-
python/tvm/relay/transform/transform.py | 89 +++++++++-------------
src/ir/transform.cc | 6 +-
src/relay/transforms/print_ir.cc | 49 ------------
tests/python/relay/test_op_level10.py | 8 +-
tests/python/relay/test_pass_alter_op_layout.py | 4 +-
tests/python/relay/test_pass_annotation.py | 4 +-
tests/python/relay/test_pass_canonicalize_cast.py | 4 +-
.../relay/test_pass_combine_parallel_conv2d.py | 2 +-
.../relay/test_pass_combine_parallel_dense.py | 2 +-
tests/python/relay/test_pass_convert_op_layout.py | 4 +-
.../relay/test_pass_dead_code_elimination.py | 2 +-
.../relay/test_pass_eliminate_common_subexpr.py | 2 +-
tests/python/relay/test_pass_eta_expand.py | 8 +-
tests/python/relay/test_pass_fold_constant.py | 4 +-
tests/python/relay/test_pass_fold_scale_axis.py | 2 +-
tests/python/relay/test_pass_lazy_gradient_init.py | 26 +++----
tests/python/relay/test_pass_legalize.py | 4 +-
tests/python/relay/test_pass_mac_count.py | 2 +-
tests/python/relay/test_pass_manager.py | 34 ++++-----
tests/python/relay/test_pass_partial_eval.py | 6 +-
tests/python/relay/test_pass_partition_graph.py | 8 +-
tests/python/relay/test_pass_qnn_legalize.py | 4 +-
tests/python/relay/test_pass_to_a_normal_form.py | 4 +-
tests/python/relay/test_pass_to_cps.py | 3 +-
tutorials/dev/relay_pass_infra.py | 26 +++----
vta/python/vta/top/graphpack.py | 2 +-
38 files changed, 169 insertions(+), 229 deletions(-)
diff --git a/docs/api/python/ir.rst b/docs/api/python/ir.rst
index 1f0dc0c..c2a1a1e 100644
--- a/docs/api/python/ir.rst
+++ b/docs/api/python/ir.rst
@@ -21,3 +21,11 @@ tvm.ir
:members:
:imported-members:
:autosummary:
+
+
+tvm.transform
+-------------
+.. automodule:: tvm.transform
+ :members:
+ :imported-members:
+ :autosummary:
diff --git a/docs/dev/convert_layout.rst b/docs/dev/convert_layout.rst
index 715d810..7345c15 100644
--- a/docs/dev/convert_layout.rst
+++ b/docs/dev/convert_layout.rst
@@ -227,7 +227,7 @@ ConvertLayout pass is extremely easy to use. The pass is not a part of default r
# Convert the layout to NCHW
# RemoveUnunsedFunctions is used to clean up the graph.
- seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
+ seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
relay.transform.ConvertLayout('NCHW')])
with relay.transform.PassContext(opt_level=3):
mod = seq(mod)
diff --git a/docs/dev/relay_pass_infra.rst b/docs/dev/relay_pass_infra.rst
index b42e128..3b443fa 100644
--- a/docs/dev/relay_pass_infra.rst
+++ b/docs/dev/relay_pass_infra.rst
@@ -582,7 +582,7 @@ using ``Sequential`` associated with other types of passes.
func = relay.Function([x], z2)
# Customize the optimization pipeline.
- seq = _transform.Sequential([
+ seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.EliminateCommonSubexpr(),
@@ -609,7 +609,7 @@ sequential pass example could be like the following to enable IR dumping for
.. code:: python
- seq = _transform.Sequential([
+ seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.PrintIR(),
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index 4c55204..3680f6d 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -361,9 +361,11 @@ TVM_DLL Pass CreateModulePass(
/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
+ * \param header The header to be attached to the output.
+ * \param show_meta_data Whether should we show meta data.
* \return The pass.
*/
-TVM_DLL Pass PrintIR(std::string header);
+TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false);
} // namespace transform
} // namespace tvm
diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py
index e091cd1..9a881cf 100644
--- a/python/tvm/ir/json_compact.py
+++ b/python/tvm/ir/json_compact.py
@@ -106,7 +106,7 @@ def create_updater_06_to_07():
"relay.PassInfo": _rename("transform.PassInfo"),
"relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"),
- "relay.Sequantial": _rename("transform.Sequantial"),
+ "relay.Sequential": _rename("transform.Sequential"),
# TIR
"Variable": _update_tir_var("tir.Var"),
"SizeVar": _update_tir_var("tir.SizeVar"),
diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py
index da74fb2..614f969 100644
--- a/python/tvm/ir/transform.py
+++ b/python/tvm/ir/transform.py
@@ -329,7 +329,7 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
return create_module_pass
-def PrintIR(header):
+def PrintIR(header="", show_meta_data=False):
"""A special trace pass that prints the header and IR.
Parameters
@@ -337,8 +337,11 @@ def PrintIR(header):
header : str
The header to be displayed along with the dump.
+ show_meta_data : bool
+ A boolean flag to indicate if meta data should be printed.
+
Returns
--------
The pass
"""
- return _ffi_transform_api.PrintIR(header)
+ return _ffi_transform_api.PrintIR(header, show_meta_data)
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 1517cf9..4e52019 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -128,20 +128,9 @@ Prelude = prelude.Prelude
# Scope builder
ScopeBuilder = scope_builder.ScopeBuilder
-module_pass = transform.module_pass
-function_pass = transform.function_pass
-
# Parser
fromtext = parser.fromtext
# Param Serialization
save_param_dict = param_dict.save_param_dict
load_param_dict = param_dict.load_param_dict
-
-# Pass manager
-PassInfo = transform.PassInfo
-PassContext = transform.PassContext
-Pass = transform.Pass
-ModulePass = transform.ModulePass
-FunctionPass = transform.FunctionPass
-Sequential = transform.Sequential
diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py
index 9c4be29..213a6c6 100644
--- a/python/tvm/relay/backend/interpreter.py
+++ b/python/tvm/relay/backend/interpreter.py
@@ -210,10 +210,10 @@ class Interpreter(Executor):
opt_mod : tvm.IRModule
The optimized module.
"""
- seq = transform.Sequential([transform.SimplifyInference(),
- transform.FuseOps(0),
- transform.ToANormalForm(),
- transform.InferType()])
+ seq = tvm.transform.Sequential([transform.SimplifyInference(),
+ transform.FuseOps(0),
+ transform.ToANormalForm(),
+ transform.InferType()])
return seq(self.mod)
def _make_executor(self, expr=None):
diff --git a/python/tvm/relay/qnn/transform.py b/python/tvm/relay/qnn/transform.py
index 6d38490..492c739 100644
--- a/python/tvm/relay/qnn/transform.py
+++ b/python/tvm/relay/qnn/transform.py
@@ -60,7 +60,7 @@ def CanonicalizeOps():
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that canonicalizes QNN ops to Relay ops.
"""
@@ -108,7 +108,7 @@ def Legalize():
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that legalizes QNN ops.
"""
diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py
index 2ad4e18..958d0dc 100644
--- a/python/tvm/relay/quantize/quantize.py
+++ b/python/tvm/relay/quantize/quantize.py
@@ -17,6 +17,7 @@
#pylint: disable=unused-argument, not-context-manager
"""Automatic quantization toolkit."""
import tvm.ir
+import tvm
from tvm.runtime import Object
from . import _quantize
@@ -240,7 +241,7 @@ def partition():
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass for VTA rewrite.
"""
return _quantize.QuantizePartition()
@@ -253,7 +254,7 @@ def annotate():
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass for quantization annotation.
"""
return _quantize.QuantizeAnnotate()
@@ -267,7 +268,7 @@ def realize():
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass for quantization realization.
"""
return _quantize.QuantizeRealize()
@@ -298,11 +299,12 @@ def prerequisite_optimize(mod, params=None):
""" Prerequisite optimization passes for quantization. Perform
"SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization. """
- optimize = _transform.Sequential([_transform.SimplifyInference(),
- _transform.FoldConstant(),
- _transform.FoldScaleAxis(),
- _transform.CanonicalizeOps(),
- _transform.FoldConstant()])
+ optimize = tvm.transform.Sequential(
+ [_transform.SimplifyInference(),
+ _transform.FoldConstant(),
+ _transform.FoldScaleAxis(),
+ _transform.CanonicalizeOps(),
+ _transform.FoldConstant()])
if params:
mod['main'] = _bind_params(mod['main'], params)
@@ -336,19 +338,20 @@ def quantize(mod, params=None, dataset=None):
"""
mod = prerequisite_optimize(mod, params)
- calibrate_pass = _transform.module_pass(calibrate(dataset), opt_level=1,
- name="QuantizeCalibrate")
+ calibrate_pass = tvm.transform.module_pass(
+ calibrate(dataset), opt_level=1,
+ name="QuantizeCalibrate")
quant_passes = [partition(),
annotate(),
calibrate_pass]
if not current_qconfig().do_simulation:
quant_passes.append(realize())
quant_passes.append(_transform.FoldConstant())
- quantize_seq = _transform.Sequential(quant_passes)
- with _transform.PassContext(opt_level=3,
- required_pass=["QuantizeAnnotate",
- "QuantizeCalibrate",
- "QuantizeRealize"]):
+ quantize_seq = tvm.transform.Sequential(quant_passes)
+ with tvm.transform.PassContext(opt_level=3,
+ required_pass=["QuantizeAnnotate",
+ "QuantizeCalibrate",
+ "QuantizeRealize"]):
with quantize_context():
mod = quantize_seq(mod)
diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py
index 54c9091..58c6fe8 100644
--- a/python/tvm/relay/testing/__init__.py
+++ b/python/tvm/relay/testing/__init__.py
@@ -47,7 +47,7 @@ from .py_converter import to_python, run_as_python
from ..transform import gradient
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py
index eec5e16..61a04ec 100644
--- a/python/tvm/relay/testing/py_converter.py
+++ b/python/tvm/relay/testing/py_converter.py
@@ -95,8 +95,8 @@ class PythonConverter(ExprFunctor):
# necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
# and fusion (to get primitive functions)
- opts = relay.transform.Sequential([relay.transform.SimplifyInference(),
- relay.transform.FuseOps(fuse_opt_level=0)])
+ opts = tvm.transform.Sequential([relay.transform.SimplifyInference(),
+ relay.transform.FuseOps(fuse_opt_level=0)])
mod = opts(mod)
optimized = mod['main']
return optimized if isinstance(unwrapped, Function) else optimized.body
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 918894f..292c5fd 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -22,10 +22,9 @@ import types
import inspect
import functools
-import tvm
+import tvm.ir
from tvm import te
from tvm.runtime import ndarray as _nd
-from tvm.ir.transform import PassInfo, PassContext, Pass, ModulePass, Sequential, module_pass
from tvm import relay
from . import _ffi_api
@@ -78,12 +77,13 @@ def build_config(opt_level=2,
pass_context: PassContext
The pass context for optimizations.
"""
- return PassContext(opt_level, fallback_device, required_pass,
- disabled_pass, trace)
+ return tvm.ir.transform.PassContext(
+ opt_level, fallback_device, required_pass,
+ disabled_pass, trace)
@tvm._ffi.register_object("relay.FunctionPass")
-class FunctionPass(Pass):
+class FunctionPass(tvm.ir.transform.Pass):
"""A pass that works on each tvm.relay.Function in a module. A function
pass class should be created through `function_pass`.
"""
@@ -94,7 +94,7 @@ def InferType():
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered type inference pass.
"""
return _ffi_api.InferType()
@@ -106,7 +106,7 @@ def FoldScaleAxis():
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass to fold expressions.
Note
@@ -123,7 +123,7 @@ def BackwardFoldScaleAxis():
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass to backward fold expressions.
Note
@@ -144,7 +144,7 @@ def RemoveUnusedFunctions(entry_functions=None):
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass to remove unused functions.
"""
if entry_functions is None:
@@ -156,7 +156,7 @@ def ForwardFoldScaleAxis():
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass to forward fold expressions.
Note
@@ -174,7 +174,7 @@ def SimplifyInference():
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass to perform operator simplification.
"""
return _ffi_api.SimplifyInference()
@@ -185,7 +185,7 @@ def FastMath():
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass to perform fast math operations.
"""
return _ffi_api.FastMath()
@@ -198,7 +198,7 @@ def CanonicalizeOps():
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass performing the canonicalization.
"""
return _ffi_api.CanonicalizeOps()
@@ -214,7 +214,7 @@ def DeadCodeElimination(inline_once=False):
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that eliminates the dead code in a Relay program.
"""
return _ffi_api.DeadCodeElimination(inline_once)
@@ -227,7 +227,7 @@ def LazyGradientInit():
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
A pass which delays and/or reduces memory allocation,
by lazily allocating 0 or one filled tensors.
"""
@@ -238,7 +238,7 @@ def FoldConstant():
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass for constant folding.
"""
return _ffi_api.FoldConstant()
@@ -255,7 +255,7 @@ def FuseOps(fuse_opt_level=-1):
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass for operator fusion.
"""
return _ffi_api.FuseOps(fuse_opt_level)
@@ -272,7 +272,7 @@ def CombineParallelConv2D(min_num_branches=3):
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that combines parallel conv2d operators.
"""
return _ffi_api.CombineParallelConv2D(min_num_branches)
@@ -304,7 +304,7 @@ def CombineParallelDense(min_num_branches=3):
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that combines parallel dense operators.
"""
return _ffi_api.CombineParallelDense(min_num_branches)
@@ -318,7 +318,7 @@ def AlterOpLayout():
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that alters the layout of operators.
"""
return _ffi_api.AlterOpLayout()
@@ -366,7 +366,7 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"):
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that rewrites an expr.
"""
return _ffi_api.Legalize(legalize_map_attr_name)
@@ -387,7 +387,7 @@ def MergeComposite(pattern_table):
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that merges operators into a single composite
relay function.
"""
@@ -413,7 +413,7 @@ def MergeCompilerRegions():
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that merges compiler regions.
"""
return _ffi_api.MergeCompilerRegions()
@@ -433,7 +433,7 @@ def RewriteAnnotatedOps(fallback_device):
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that rewrites an expression with annotated
`on_device` operators.
"""
@@ -448,7 +448,7 @@ def ToANormalForm():
Returns
-------
- ret: Union[tvm.relay.Pass, tvm.relay.Expr]
+ ret: Union[tvm.transform.Pass, tvm.relay.Expr]
The registered pass that transforms an expression into A Normal Form.
"""
return _ffi_api.ToANormalForm()
@@ -462,7 +462,7 @@ def ToCPS(expr, mod=None):
Returns
-------
- result: tvm.relay.Pass
+ result: tvm.transform.Pass
The registered pass that transforms an expression into CPS.
"""
return _ffi_api.to_cps(expr, mod)
@@ -481,7 +481,7 @@ def EtaExpand(expand_constructor=False, expand_global_var=False):
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that eta expands an expression.
"""
return _ffi_api.EtaExpand(expand_constructor, expand_global_var)
@@ -492,7 +492,7 @@ def ToGraphNormalForm():
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that transforms an expression into Graph Normal Form.
"""
return _ffi_api.ToGraphNormalForm()
@@ -509,7 +509,7 @@ def EliminateCommonSubexpr(fskip=None):
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that eliminates common subexpressions.
"""
return _ffi_api.EliminateCommonSubexpr(fskip)
@@ -527,7 +527,7 @@ def PartialEvaluate():
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that performs partial evaluation on an expression.
"""
return _ffi_api.PartialEvaluate()
@@ -539,7 +539,7 @@ def CanonicalizeCast():
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that canonicalizes cast expression.
"""
return _ffi_api.CanonicalizeCast()
@@ -551,36 +551,19 @@ def LambdaLift():
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The registered pass that lifts the lambda function.
"""
return _ffi_api.LambdaLift()
-def PrintIR(show_meta_data=True):
- """
- Print the IR for a module to help debugging.
-
- Parameters
- ----------
- show_meta_data : bool
- A boolean flag to indicate if meta data should be printed.
-
- Returns
- -------
- ret : tvm.relay.Pass
- The registered pass that prints the module IR.
- """
- return _ffi_api.PrintIR(show_meta_data)
-
-
def PartitionGraph():
"""Partition a Relay program into regions that can be executed on different
backends.
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that partitions the Relay program.
"""
return _ffi_api.PartitionGraph()
@@ -598,7 +581,7 @@ def AnnotateTarget(targets):
Returns
-------
- ret : tvm.relay.Pass
+ ret : tvm.transform.Pass
The annotated pass that wrapps ops with subgraph_start and
subgraph_end.
"""
@@ -614,7 +597,7 @@ def Inline():
Returns
-------
- ret: tvm.relay.Pass
+ ret: tvm.transform.Pass
The registered pass that performs inlining for a Relay IR module.
"""
return _ffi_api.Inline()
@@ -809,7 +792,7 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
def create_function_pass(pass_arg):
"""Internal function that creates a function pass"""
fname = name if name else pass_arg.__name__
- info = PassInfo(opt_level, fname, required)
+ info = tvm.transform.PassInfo(opt_level, fname, required)
if inspect.isclass(pass_arg):
return _wrap_class_function_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index 0161cb3..c1547d5 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -474,10 +474,10 @@ TVM_REGISTER_GLOBAL("transform.ExitPassContext")
.set_body_typed(PassContext::Internal::ExitScope);
-Pass PrintIR(std::string header) {
- auto pass_func =[header](IRModule mod, const PassContext& ctx) {
+Pass PrintIR(std::string header, bool show_meta_data) {
+ auto pass_func =[header, show_meta_data](IRModule mod, const PassContext& ctx) {
LOG(INFO) << "PrintIR(" << header << "):\n"
- << mod;
+ << AsText(mod, show_meta_data);
return mod;
};
return CreateModulePass(pass_func, 0, "PrintIR", {});
diff --git a/src/relay/transforms/print_ir.cc b/src/relay/transforms/print_ir.cc
deleted file mode 100644
index cf06b50..0000000
--- a/src/relay/transforms/print_ir.cc
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *
- * \file src/relay/transforms/print_ir.cc
- *
- * \brief Print the module IR to help debugging.
- */
-#include <tvm/relay/expr.h>
-#include <tvm/relay/transform.h>
-
-namespace tvm {
-namespace relay {
-
-namespace transform {
-
-Pass PrintIR(bool show_meta_data) {
- runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
- [=](IRModule m, PassContext pc) {
- LOG(INFO) << "Dumping the module IR: " << std::endl << AsText(m, show_meta_data);
- return m;
- };
- return CreateModulePass(pass_func, 0, "PrintIR", {});
-}
-
-TVM_REGISTER_GLOBAL("relay._transform.PrintIR")
-.set_body_typed(PrintIR);
-
-} // namespace transform
-
-} // namespace relay
-} // namespace tvm
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 30e2506..5e57c80 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -53,10 +53,10 @@ def test_checkpoint_alpha_equal():
df = transform.gradient(run_infer_type(f))
# run PE and DCE
- with transform.PassContext(opt_level=3):
+ with tvm.transform.PassContext(opt_level=3):
passes = [transform.PartialEvaluate(),
transform.DeadCodeElimination(inline_once=True)]
- mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df))
+ mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
df = mod["main"]
df_parsed = relay.parser.fromtext(
@@ -109,10 +109,10 @@ def test_checkpoint_alpha_equal_tuple():
df = transform.gradient(run_infer_type(f))
# run PE and DCE
- with transform.PassContext(opt_level=3):
+ with tvm.transform.PassContext(opt_level=3):
passes = [transform.PartialEvaluate(),
transform.DeadCodeElimination(inline_once=True)]
- mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df))
+ mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
df = mod["main"]
df_parsed = relay.parser.fromtext(
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index a30492f..2a2e265 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -26,8 +26,8 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
- seq = transform.Sequential(passes)
- with transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential(passes)
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py
index ea92546..582d46a 100644
--- a/tests/python/relay/test_pass_annotation.py
+++ b/tests/python/relay/test_pass_annotation.py
@@ -28,8 +28,8 @@ from tvm.relay import transform
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
- seq = transform.Sequential(passes)
- with transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential(passes)
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
return mod["main"]
diff --git a/tests/python/relay/test_pass_canonicalize_cast.py b/tests/python/relay/test_pass_canonicalize_cast.py
index 7b6617a..e13547b 100644
--- a/tests/python/relay/test_pass_canonicalize_cast.py
+++ b/tests/python/relay/test_pass_canonicalize_cast.py
@@ -54,9 +54,9 @@ def test_canonicalize_cast():
bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32")
y = before(data, conv_weight, bias1, bias2)
mod = tvm.IRModule.from_expr(y)
- seq = _transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
+ seq = tvm.transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
_transform.InferType()])
- with _transform.PassContext(opt_level=3):
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
y = mod["main"]
y_expected = expected(data, conv_weight, bias1, bias2)
diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py
index 345f068..7f7f185 100644
--- a/tests/python/relay/test_pass_combine_parallel_conv2d.py
+++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py
@@ -26,7 +26,7 @@ def run_combine_parallel(expr, min_num_branches=3):
return mod["main"]
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
return mod["main"]
diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py
index f0f2e18..12beafb 100644
--- a/tests/python/relay/test_pass_combine_parallel_dense.py
+++ b/tests/python/relay/test_pass_combine_parallel_dense.py
@@ -26,7 +26,7 @@ def run_combine_parallel(expr, min_num_branches=3):
return mod["main"]
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
return mod["main"]
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index c783971..c5a7b0e 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -26,8 +26,8 @@ from tvm.relay import transform, analysis
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
- seq = transform.Sequential(passes)
- with transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential(passes)
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py
index 60dfa62..35fd444 100644
--- a/tests/python/relay/test_pass_dead_code_elimination.py
+++ b/tests/python/relay/test_pass_dead_code_elimination.py
@@ -47,7 +47,7 @@ e = env()
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
diff --git a/tests/python/relay/test_pass_eliminate_common_subexpr.py b/tests/python/relay/test_pass_eliminate_common_subexpr.py
index 89e3b67..7af524d 100644
--- a/tests/python/relay/test_pass_eliminate_common_subexpr.py
+++ b/tests/python/relay/test_pass_eliminate_common_subexpr.py
@@ -24,7 +24,7 @@ from tvm.relay import transform, analysis
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
diff --git a/tests/python/relay/test_pass_eta_expand.py b/tests/python/relay/test_pass_eta_expand.py
index 84ff54a..e0a189b 100644
--- a/tests/python/relay/test_pass_eta_expand.py
+++ b/tests/python/relay/test_pass_eta_expand.py
@@ -33,8 +33,8 @@ def test_eta_expand_global_var():
@aux
}
""")
- seq = _transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
- with _transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
expected = relay.fromtext(r"""
v0.0.4
@@ -62,8 +62,8 @@ def test_eta_expand_constructor():
Cons
}
""")
- seq = _transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
- with _transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
expected = relay.fromtext(r"""
v0.0.4
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index 3ddafd7..4f44d2b 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -24,7 +24,7 @@ from tvm.relay.testing import run_infer_type, create_workload
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
@@ -174,7 +174,7 @@ def test_fold_batch_norm():
add = relay.add(conv, bias)
return relay.Function(relay.analysis.free_vars(add), add)
- remove_bn_pass = transform.Sequential([
+ remove_bn_pass = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.SimplifyInference(),
relay.transform.FoldConstant(),
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py
index bf2a708..d7c437a 100644
--- a/tests/python/relay/test_pass_fold_scale_axis.py
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -26,7 +26,7 @@ def _get_positive_scale(size):
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
diff --git a/tests/python/relay/test_pass_lazy_gradient_init.py b/tests/python/relay/test_pass_lazy_gradient_init.py
index f9c762e..4149268 100644
--- a/tests/python/relay/test_pass_lazy_gradient_init.py
+++ b/tests/python/relay/test_pass_lazy_gradient_init.py
@@ -80,7 +80,7 @@ def test_add_tuple():
mod["main"] = y
mod = transform.LazyGradientInit()(mod)
- mod = transform.PrintIR(show_meta_data=True)(mod)
+ mod = tvm.transform.PrintIR(show_meta_data=True)(mod)
y = mod["main"]
assert mod["main"].checked_type == relay.FuncType([t], tensor_type)
@@ -116,7 +116,7 @@ def test_mult():
def test_ret_tuple():
"""Test tuple return type. Check types and semantic equivalence."""
mod = tvm.IRModule()
-
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
@@ -141,7 +141,7 @@ def test_ret_tuple():
def test_add_broadcast():
"""Test adding matrices of different size. Check types and semantic equivalence."""
mod = tvm.IRModule()
-
+
shape1 = (3, 4, 1)
shape2 = (1, 5)
dtype = 'float32'
@@ -173,7 +173,7 @@ def test_reverse_ad_identity():
"""Simple test with reverse mode ad."""
# of f(x) = x
mod = tvm.IRModule()
-
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
@@ -201,7 +201,7 @@ def test_reverse_ad_identity():
def test_multivar_reverse_ad():
"""Simple test with multivariate reverse mode ad."""
mod = tvm.IRModule()
-
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
@@ -232,7 +232,7 @@ def test_multivar_reverse_ad():
def test_after_partial_eval():
"""Test transformation following reverse mode ad and PartialEval"""
mod = tvm.IRModule()
-
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
@@ -248,7 +248,7 @@ def test_after_partial_eval():
mod["main"] = back_func
back_func = mod["main"]
- seq = transform.Sequential([
+ seq = tvm.transform.Sequential([
transform.PartialEvaluate(),
transform.LazyGradientInit(),
transform.DeadCodeElimination()
@@ -270,7 +270,7 @@ def test_after_partial_eval():
def test_before_partial_eval():
"""Test transformation before PartialEval"""
mod = tvm.IRModule()
-
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
@@ -284,7 +284,7 @@ def test_before_partial_eval():
back_func = run_infer_type(back_func)
mod["main"] = back_func
- seq = transform.Sequential([
+ seq = tvm.transform.Sequential([
transform.LazyGradientInit(),
transform.PartialEvaluate(),
transform.DeadCodeElimination()
@@ -306,7 +306,7 @@ def test_before_partial_eval():
def test_zeros():
"""Simple test using "zeros" op"""
mod = tvm.IRModule()
-
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
@@ -328,7 +328,7 @@ def test_zeros():
def test_ones():
"""Simple test using "ones" op"""
mod = tvm.IRModule()
-
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
@@ -350,7 +350,7 @@ def test_ones():
def test_zeros_like():
"""Simple test using "zeros_like" op"""
mod = tvm.IRModule()
-
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
@@ -372,7 +372,7 @@ def test_zeros_like():
def test_ones_like():
"""Simple test using "ones_like" op"""
mod = tvm.IRModule()
-
+
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py
index 1456700..0882149 100644
--- a/tests/python/relay/test_pass_legalize.py
+++ b/tests/python/relay/test_pass_legalize.py
@@ -28,8 +28,8 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
- seq = transform.Sequential(passes)
- with transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential(passes)
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
diff --git a/tests/python/relay/test_pass_mac_count.py b/tests/python/relay/test_pass_mac_count.py
index 697aad8..d490ac7 100644
--- a/tests/python/relay/test_pass_mac_count.py
+++ b/tests/python/relay/test_pass_mac_count.py
@@ -23,7 +23,7 @@ from tvm.relay import analysis, transform
def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py
index 0a6555b..28ccf6f 100644
--- a/tests/python/relay/test_pass_manager.py
+++ b/tests/python/relay/test_pass_manager.py
@@ -129,13 +129,13 @@ def test_module_pass():
opt_tester = OptTester(mod)
pass_ctx = None
- @_transform.module_pass(opt_level=opt_level, name=pass_name)
+ @tvm.transform.module_pass(opt_level=opt_level, name=pass_name)
def transform(expr, ctx):
return opt_tester.transform(expr, ctx)
def test_pass_registration():
mod_pass = transform
- assert isinstance(mod_pass, _transform.ModulePass)
+ assert isinstance(mod_pass, tvm.transform.ModulePass)
pass_info = mod_pass.info
assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level
@@ -143,8 +143,8 @@ def test_module_pass():
def test_pass_registration_no_decorator():
def direct_transform(expr, ctx):
return opt_tester.transform(expr, ctx)
- mod_pass = _transform.module_pass(direct_transform, opt_level=3)
- assert isinstance(mod_pass, _transform.ModulePass)
+ mod_pass = tvm.transform.module_pass(direct_transform, opt_level=3)
+ assert isinstance(mod_pass, tvm.transform.ModulePass)
pass_info = mod_pass.info
assert pass_info.name == "direct_transform"
assert pass_info.opt_level == 3
@@ -285,7 +285,7 @@ def test_function_pass():
def test_module_class_pass():
- @relay.transform.module_pass(opt_level=1)
+ @tvm.transform.module_pass(opt_level=1)
class TestPipeline:
"""Simple test function to replace one argument to another."""
def __init__(self, new_mod, replace):
@@ -309,7 +309,7 @@ def test_module_class_pass():
def test_pass_info():
- info = relay.transform.PassInfo(opt_level=1, name="xyz")
+ info = tvm.transform.PassInfo(opt_level=1, name="xyz")
assert info.opt_level == 1
assert info.name == "xyz"
@@ -350,7 +350,7 @@ def test_sequential_pass():
opt_tester = OptTester(mod)
pass_ctx = None
- @_transform.module_pass(opt_level=1)
+ @tvm.transform.module_pass(opt_level=1)
def mod_transform(expr, ctx):
return opt_tester.transform(expr, ctx)
@@ -367,21 +367,21 @@ def test_sequential_pass():
passes = [module_pass, function_pass]
opt_level = 2
pass_name = "sequential"
- sequential = _transform.Sequential(passes=passes, opt_level=opt_level)
+ sequential = tvm.transform.Sequential(passes=passes, opt_level=opt_level)
pass_info = sequential.info
assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level
def test_no_pass():
passes = []
- sequential = _transform.Sequential(opt_level=1, passes=passes)
+ sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
ret_mod = sequential(mod)
mod_func = ret_mod[v_sub]
check_func(sub, mod_func)
def test_only_module_pass():
passes = [module_pass]
- sequential = _transform.Sequential(opt_level=1, passes=passes)
+ sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
with relay.build_config(required_pass=["mod_transform"]):
ret_mod = sequential(mod)
# Check the subtract function.
@@ -396,7 +396,7 @@ def test_sequential_pass():
def test_only_function_pass():
# Check the subtract function.
passes = [function_pass]
- sequential = _transform.Sequential(opt_level=1, passes=passes)
+ sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
with relay.build_config(required_pass=["func_transform"]):
ret_mod = sequential(mod)
_, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
@@ -411,7 +411,7 @@ def test_sequential_pass():
# function pass.
mod = tvm.IRModule({v_sub: sub, v_log: log})
passes = [module_pass, function_pass]
- sequential = _transform.Sequential(opt_level=1, passes=passes)
+ sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
required = ["mod_transform", "func_transform"]
with relay.build_config(required_pass=required):
ret_mod = sequential(mod)
@@ -482,7 +482,7 @@ def test_sequential_with_scoping():
z1 = relay.add(z, z)
return relay.Function([x], z1)
- seq = _transform.Sequential([
+ seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.EliminateCommonSubexpr(),
@@ -507,10 +507,10 @@ def test_print_ir(capfd):
y = relay.multiply(y, relay.const(2, "float32"))
func = relay.Function([x], y)
- seq = _transform.Sequential([
+ seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
- relay.transform.PrintIR(),
+ tvm.transform.PrintIR(),
relay.transform.DeadCodeElimination()
])
@@ -520,7 +520,7 @@ def test_print_ir(capfd):
out = capfd.readouterr().err
- assert "Dumping the module IR" in out
+ assert "PrintIR" in out
assert "multiply" in out
__TRACE_COUNTER__ = 0
@@ -539,7 +539,7 @@ def test_print_debug_callback():
y = relay.multiply(y, relay.const(2, "float32"))
func = relay.Function([x], y)
- seq = _transform.Sequential([
+ seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.DeadCodeElimination()
diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py
index 0f3eea6..45593b4 100644
--- a/tests/python/relay/test_pass_partial_eval.py
+++ b/tests/python/relay/test_pass_partial_eval.py
@@ -38,8 +38,8 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07):
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
- seq = transform.Sequential(passes)
- with transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential(passes)
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
@@ -58,7 +58,7 @@ def dcpe(expr, mod=None, grad=False):
if mod:
assert isinstance(expr, Function)
mod["main"] = expr
- seq = transform.Sequential(passes)
+ seq = tvm.transform.Sequential(passes)
mod = seq(mod)
return mod["main"]
return run_opt_pass(expr, passes)
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 5148d4e..2ee8538 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -496,7 +496,7 @@ def test_function_lifting():
op_list = ["nn.batch_norm", "nn.conv2d"]
mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
- opt_pass = transform.Sequential([
+ opt_pass = tvm.transform.Sequential([
transform.InferType(),
transform.PartitionGraph(),
transform.SimplifyInference(),
@@ -578,7 +578,7 @@ def test_function_lifting_inline():
op_list = ["nn.batch_norm", "nn.conv2d"]
mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
- opt_pass = transform.Sequential([
+ opt_pass = tvm.transform.Sequential([
transform.InferType(),
transform.PartitionGraph(),
transform.SimplifyInference(),
@@ -878,13 +878,13 @@ def test_dnnl_fuse():
# This is required for constant folding
mod["main"] = bind_params_by_name(mod["main"], params)
- remove_bn_pass = transform.Sequential([
+ remove_bn_pass = tvm.transform.Sequential([
transform.InferType(),
transform.SimplifyInference(),
transform.FoldConstant(),
transform.FoldScaleAxis(),
])
- composite_partition = transform.Sequential([
+ composite_partition = tvm.transform.Sequential([
remove_bn_pass,
transform.MergeComposite(pattern_table),
transform.AnnotateTarget("dnnl"),
diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py
index c291c4e..5f7deff 100644
--- a/tests/python/relay/test_pass_qnn_legalize.py
+++ b/tests/python/relay/test_pass_qnn_legalize.py
@@ -37,8 +37,8 @@ def alpha_equal(x, y):
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
- seq = transform.Sequential(passes)
- with transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential(passes)
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py
index d7babf3..5a63db7 100644
--- a/tests/python/relay/test_pass_to_a_normal_form.py
+++ b/tests/python/relay/test_pass_to_a_normal_form.py
@@ -28,8 +28,8 @@ from tvm.relay.analysis import Feature
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
- seq = transform.Sequential(passes)
- with transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential(passes)
+ with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py
index 4aaa9a0..6edf185 100644
--- a/tests/python/relay/test_pass_to_cps.py
+++ b/tests/python/relay/test_pass_to_cps.py
@@ -71,7 +71,8 @@ def test_cps_pe():
x = run_infer_type(x)
y = un_cps(x)
y = run_infer_type(y)
- x = run_opt_pass(x, transform.Sequential([transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
+ x = run_opt_pass(x, tvm.transform.Sequential(
+ [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
assert Feature.fRefCreate not in detect_feature(x)
unit = relay.Function([], relay.const(0., dtype='float32'))
f_ref = relay.Var("f_ref")
diff --git a/tutorials/dev/relay_pass_infra.py b/tutorials/dev/relay_pass_infra.py
index 6b844ff..980d96c 100644
--- a/tutorials/dev/relay_pass_infra.py
+++ b/tutorials/dev/relay_pass_infra.py
@@ -29,7 +29,7 @@ introduced an infrastructure to manage the optimization passes.
The optimizations of a Relay program could be applied at various granularity,
namely function-level and module-level using :py:class:`tvm.relay.transform.FunctionPass`
and py:class:`tvm.relay.transform.ModulePass`
-respectively. Or users can rely on py:class:`tvm.relay.transform.Sequential` to apply a sequence of passes
+respectively. Or users can rely on py:class:`tvm.transform.Sequential` to apply a sequence of passes
on a Relay program where the dependencies between passes can be resolved by the
pass infra. For more details about each type of these passes, please refer to
the :ref:`relay-pass-infra`
@@ -130,22 +130,22 @@ print(mod)
# fusion, as this pass generates let bindings for each expression to
# canonicalize a Relay program.
#
-# Relay, hence, provides :py:class:`tvm.relay.transform.Sequential` to alleviate developers from handling
+# Relay, hence, provides :py:class:`tvm.transform.Sequential` to alleviate developers from handling
# these issues explicitly by specifying the required passes of each pass and
# packing them as a whole to execute. For example, the same passes can now be
-# applied using the sequential style as the following. :py:class:`tvm.relay.transform.Sequential` is
+# applied using the sequential style as the following. :py:class:`tvm.transform.Sequential` is
# similiar to `torch.nn.sequential <https://pytorch.org/docs/stable/nn.html#torch.nn.Sequential>`_
# and `mxnet.gluon.block <https://mxnet.incubator.apache.org/api/python/docs/_modules/mxnet/gluon/block.html>`_.
# For example, `torch.nn.sequential` is used to contain a sequence of PyTorch
# `Modules` that will be added to build a network. It focuses on the network
-# layers. Instead, the :py:class:`tvm.relay.transform.Sequential` in our pass infra works on the optimizing
+# layers. Instead, the :py:class:`tvm.transform.Sequential` in our pass infra works on the optimizing
# pass.
-# Now let's execute some passes through :py:class:`tvm.relay.transform.Sequential`
+# Now let's execute some passes through :py:class:`tvm.transform.Sequential`
f = example()
mod = tvm.IRModule.from_expr(f)
# Glob the interested passes.
-seq = relay.transform.Sequential([relay.transform.FoldConstant(),
+seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
relay.transform.EliminateCommonSubexpr(),
relay.transform.FuseOps(fuse_opt_level=2)])
mod1 = seq(mod)
@@ -156,7 +156,7 @@ print(mod1)
# identical addition operations. This is because `EliminateCommonSubexpr`
# was not actually performed. The reason is because only the passes that have
# optimization level less or equal to 2 will be executed by default under
-# :py:class:`tvm.relay.transform.Sequential`. The pass infra,
+# :py:class:`tvm.transform.Sequential`. The pass infra,
# however, provides a configuration interface
# for users to customize the optimization level that they want to execute.
@@ -186,7 +186,7 @@ with relay.build_config(opt_level=3):
mod4 = seq(mod)
print(mod4)
-seq1 = relay.transform.Sequential([relay.transform.AlterOpLayout()])
+seq1 = tvm.transform.Sequential([relay.transform.AlterOpLayout()])
with relay.build_config(opt_level=3):
with tvm.target.create("llvm"):
mod5 = seq1(mod)
@@ -237,11 +237,11 @@ print(mod3)
f = example()
mod = tvm.IRModule.from_expr(f)
-seq = relay.transform.Sequential([relay.transform.FoldConstant(),
- relay.transform.PrintIR(False),
- relay.transform.EliminateCommonSubexpr(),
- relay.transform.FuseOps(),
- relay.transform.PrintIR(False)])
+seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
+ tvm.transform.PrintIR(),
+ relay.transform.EliminateCommonSubexpr(),
+ relay.transform.FuseOps(),
+ tvm.transform.PrintIR()])
with relay.build_config(opt_level=3):
mod = seq(mod)
diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py
index aca00a6..2334de7 100644
--- a/vta/python/vta/top/graphpack.py
+++ b/vta/python/vta/top/graphpack.py
@@ -24,7 +24,7 @@ from tvm.relay import ExprMutator
def run_opt_pass(expr, opt_pass):
"""Exectue a relay pass."""
- assert isinstance(opt_pass, transform.Pass)
+ assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]