You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by kp...@apache.org on 2023/06/02 00:35:48 UTC
[tvm] branch main updated: [IR,TE,TIR] Use f-strings for string formatting, NFC (#14990)
This is an automated email from the ASF dual-hosted git repository.
kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 00126b0484 [IR,TE,TIR] Use f-strings for string formatting, NFC (#14990)
00126b0484 is described below
commit 00126b04841d2e7716e615a46fe2c9a78b17ee35
Author: Krzysztof Parzyszek <kp...@quicinc.com>
AuthorDate: Thu Jun 1 19:35:41 2023 -0500
[IR,TE,TIR] Use f-strings for string formatting, NFC (#14990)
* [IR,TE,TIR] Use f-strings for string formatting, NFC (no functional change)
Replace uses of %-formatting and str.format() with f-strings.
Reformat the modified files.
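* Rearrange pylint directives for better formatting: move trailing "# pylint: disable=no-member" comments onto a preceding "# pylint: disable-next=no-member" line so long lines stay within the line-length limit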
---
python/tvm/ir/container.py | 4 +-
python/tvm/ir/expr.py | 6 +-
python/tvm/ir/json_compact.py | 4 +-
python/tvm/te/hybrid/module.py | 2 +-
python/tvm/te/hybrid/parser.py | 10 +-
python/tvm/te/hybrid/preprocessor.py | 4 +-
python/tvm/te/hybrid/utils.py | 6 +-
python/tvm/te/operation.py | 11 +--
python/tvm/te/schedule.py | 2 +-
python/tvm/te/tag.py | 2 +-
python/tvm/te/tensor.py | 4 +-
python/tvm/tir/buffer.py | 10 +-
python/tvm/tir/ir_builder.py | 2 +-
python/tvm/tir/schedule/schedule.py | 170 +++++++++++---------------------
python/tvm/tir/schedule/state.py | 9 +-
python/tvm/tir/tensor_intrin/arm_cpu.py | 8 +-
python/tvm/tir/tensor_intrin/cuda.py | 17 ++--
17 files changed, 96 insertions(+), 175 deletions(-)
diff --git a/python/tvm/ir/container.py b/python/tvm/ir/container.py
index 3c7a57a830..e35c61c05c 100644
--- a/python/tvm/ir/container.py
+++ b/python/tvm/ir/container.py
@@ -46,7 +46,7 @@ class Array(Object):
raise AttributeError("handle is not set")
if name == "type_key":
return super().__getattr__(name)
- raise AttributeError("%s has no attribute %s" % (str(type(self)), name))
+ raise AttributeError(f"{type(self)} has no attribute {name}")
@tvm._ffi.register_object
@@ -77,7 +77,7 @@ class Map(Object):
raise AttributeError("handle is not set")
if name == "type_key":
return super().__getattr__(name)
- raise AttributeError("%s has no attribute %s" % (str(type(self)), name))
+ raise AttributeError(f"{type(self)} has no attribute {name}")
def keys(self):
return iter(self)
diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py
index 1c775b461e..5a83d8e5d9 100644
--- a/python/tvm/ir/expr.py
+++ b/python/tvm/ir/expr.py
@@ -50,7 +50,7 @@ class RelayExpr(BaseExpr):
"""
ret = self._checked_type_
if ret is None:
- raise ValueError("The type checker has not populated" " the checked_type for this node")
+ raise ValueError("The type checker has not populated the checked_type for this node")
return ret
@@ -92,9 +92,7 @@ class GlobalVar(RelayExpr):
return tvm.tir.call_tir(self, *args)
arg_types = [type(x) for x in args]
- raise RuntimeError(
- "Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types)
- )
+ raise RuntimeError(f"Do not know how to handle GlobalVar.__call__ for types {arg_types}")
def astext(self, show_meta_data=True, annotate=None):
"""Get the text format of the expression.
diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py
index 8e9d3550ca..6ce2a8b9e2 100644
--- a/python/tvm/ir/json_compact.py
+++ b/python/tvm/ir/json_compact.py
@@ -153,7 +153,7 @@ def create_updater_06_to_07():
val = jdata["nodes"][root_idx]
sidx = len(nodes)
nodes.append(val)
- item["attrs"][key] = "%d" % sidx
+ item["attrs"][key] = f"{sidx}"
return item
return _convert
@@ -260,5 +260,5 @@ def upgrade_json(json_str):
elif from_version.startswith("0.8"):
data = create_updater_08_to_09()(data)
else:
- raise ValueError("Cannot update from version %s" % from_version)
+ raise ValueError(f"Cannot update from version {from_version}")
return json.dumps(data, indent=2)
diff --git a/python/tvm/te/hybrid/module.py b/python/tvm/te/hybrid/module.py
index af6270045b..729805b31b 100644
--- a/python/tvm/te/hybrid/module.py
+++ b/python/tvm/te/hybrid/module.py
@@ -51,7 +51,7 @@ class HybridModule(object):
temp = utils.tempdir()
dst = temp.relpath("script.py")
with open(dst, "w") as f:
- f.write("import tvm\n@tvm.te.hybrid.script\n%s" % src)
+ f.write(f"import tvm\n@tvm.te.hybrid.script\n{src}")
if name is not None:
self.name = name
diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py
index bd47e41630..846ef818ea 100644
--- a/python/tvm/te/hybrid/parser.py
+++ b/python/tvm/te/hybrid/parser.py
@@ -161,9 +161,7 @@ class HybridParser(ast.NodeVisitor):
if key in self.symbols.keys():
old = str(self.symbols[key])
new = str((ty, val))
- _internal_assert(
- False, "Name conflict in symbol table! [%s] %s -> %s" % (key, old, new)
- )
+ _internal_assert(False, f"Name conflict in symbol table! [{key}] {old} -> {new}")
self.symbols[key] = ty, val
@@ -188,7 +186,7 @@ class HybridParser(ast.NodeVisitor):
continue
if level != node:
continue
- _internal_assert(key in self.symbols.keys(), "Unknown symbol %s!" % key)
+ _internal_assert(key in self.symbols.keys(), f"Unknown symbol {key}!")
ty, entry = self.symbols[key] # pylint: disable=invalid-name
if ty in [Symbol.Input, Symbol.OutputBuffer]:
@@ -254,7 +252,7 @@ class HybridParser(ast.NodeVisitor):
return tvm.runtime.convert(self.closure_vars[name])
ty, entry = self.symbols[name]
- _internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
+ _internal_assert(name in self.symbols, f"Unknown symbol {name}!")
if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
return entry
if ty is Symbol.ThreadBind:
@@ -473,7 +471,7 @@ class HybridParser(ast.NodeVisitor):
# Contexts'
_internal_assert(
func_id in self.symbols.keys(),
- "The function called (%s) is not in the context either!" % func_id,
+ f"The function called ({func_id}) is not in the context either!",
)
ty, entry = self.symbols[func_id]
_internal_assert(ty is Symbol.Callable, "Are you sure what you call is a function?!")
diff --git a/python/tvm/te/hybrid/preprocessor.py b/python/tvm/te/hybrid/preprocessor.py
index 295476f808..6af584060e 100644
--- a/python/tvm/te/hybrid/preprocessor.py
+++ b/python/tvm/te/hybrid/preprocessor.py
@@ -100,7 +100,7 @@ class PyVariableUsage(ast.NodeVisitor):
raise ValueError("Only support capturing constant values in closure")
return
- _internal_assert(isinstance(node.ctx, ast.Store), "Undeclared variable %s" % node.id)
+ _internal_assert(isinstance(node.ctx, ast.Store), f"Undeclared variable {node.id}")
if self.aug_assign_:
raise ValueError('"First store" cannot be an AugAssign')
self.status[node.id] = (node, self.scope_level[-1], set())
@@ -108,7 +108,7 @@ class PyVariableUsage(ast.NodeVisitor):
decl, loop, usage = self.status[node.id]
usage.add(type(node.ctx))
_internal_assert(
- loop in self.scope_level, "%s is used out of the scope it is defined!" % node.id
+ loop in self.scope_level, f"{node.id} is used out of the scope it is defined!"
)
self.status[node.id] = (decl, loop, usage)
diff --git a/python/tvm/te/hybrid/utils.py b/python/tvm/te/hybrid/utils.py
index 939cca45a3..f653b3e83d 100644
--- a/python/tvm/te/hybrid/utils.py
+++ b/python/tvm/te/hybrid/utils.py
@@ -95,15 +95,15 @@ def _is_tvm_arg_types(args):
for elem in args[1:]:
_internal_assert(
isinstance(elem, tvm_arg_types),
- "Expecting a Var, Tensor or ConstExpr instance but %s get!" % str(type(elem)),
+ f"Expecting a Var, Tensor or ConstExpr instance but {type(elem)} get!",
)
return True
_internal_assert(
- isinstance(args[0], np_arg_types), "Expect a numpy type but %s get!" % str(type(args[0]))
+ isinstance(args[0], np_arg_types), f"Expect a numpy type but {type(args[0])} get!"
)
for elem in args[1:]:
_internal_assert(
- isinstance(elem, np_arg_types), "Expect a numpy type but %s get!" % str(type(elem))
+ isinstance(elem, np_arg_types), f"Expect a numpy type but {type(elem)} get!"
)
return False
diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index 59bc76f504..1a28f9bb3d 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -100,7 +100,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None, varargs_names=N
argspec = inspect.getfullargspec(fcompute)
if len(argspec.args) == 0 and argspec.varargs is None:
- arg_names = ["i%d" % i for i in range(out_ndim)]
+ arg_names = [f"i{i}" for i in range(out_ndim)]
elif argspec.varargs is not None:
# if there is a varargs, it takes the remaining dimensions of out_ndim
num_remaining_args = out_ndim - len(argspec.args)
@@ -125,7 +125,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None, varargs_names=N
if out_ndim != len(arg_names):
raise ValueError(
"Number of args to fcompute does not match dimension, "
- "args=%d, dimension=%d" % (len(arg_names), out_ndim)
+ f"args={len(arg_names)}, dimension={out_ndim}"
)
dim_var = [tvm.tir.IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
@@ -218,7 +218,7 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attr
inputs = []
if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length")
- axis = tvm.tir.IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
+ axis = tvm.tir.IterVar((init[0].shape[0], update[0].shape[0]), f"{name}.idx", 3)
op = _ffi_api.ScanOp(name, tag, attrs, axis, init, update, state_placeholder, inputs)
res = [op.output(i) for i in range(len(update))]
return res[0] if len(res) == 1 else res
@@ -351,9 +351,8 @@ def extern(
body = tvm.tir.Evaluate(body)
if not isinstance(body, tvm.tir.Stmt):
raise ValueError(
- "Function '{}' should return PrimExpr or Stmt, but it returned '{}'".format(
- fcompute.__name__, type(body)
- )
+ f"Function '{fcompute.__name__}' should return PrimExpr or Stmt, but it returned "
+ f"'{type(body)}'"
)
op = _ffi_api.ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body)
diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py
index 50f9a22ec2..3dbf2cefe4 100644
--- a/python/tvm/te/schedule.py
+++ b/python/tvm/te/schedule.py
@@ -74,7 +74,7 @@ class Schedule(Object):
if not isinstance(k, _tensor.Operation):
raise ValueError("Expect schedule key to be Tensor or Operation")
if k not in self.stage_map:
- raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
+ raise ValueError(f"Cannot find the operation {k} in schedule")
return self.stage_map[k]
def normalize(self):
diff --git a/python/tvm/te/tag.py b/python/tvm/te/tag.py
index 42d2134a03..87a0034400 100644
--- a/python/tvm/te/tag.py
+++ b/python/tvm/te/tag.py
@@ -47,7 +47,7 @@ class TagScope(object):
def __exit__(self, ptype, value, trace):
assert self._old_scope is None
if not self.accessed:
- warnings.warn("Tag '%s' declared via TagScope was not used." % (self.tag,))
+ warnings.warn(f"Tag '{self.tag}' declared via TagScope was not used.")
TagScope._current = self._old_scope
def __call__(self, fdecl):
diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py
index fc85d830c9..d435e821ac 100644
--- a/python/tvm/te/tensor.py
+++ b/python/tvm/te/tensor.py
@@ -61,7 +61,7 @@ class Tensor(DataProducer, _expr.ExprOp):
ndim = self.ndim
if len(indices) != ndim:
raise ValueError(
- "Need to provide %d index in tensor but %d was provided" % (ndim, len(indices))
+ f"Need to provide {ndim} index in tensor but {len(indices)} was provided"
)
indices = convert_to_object(indices)
args = []
@@ -124,7 +124,7 @@ class Tensor(DataProducer, _expr.ExprOp):
op = self.op
if op.num_outputs == 1:
return op.name
- return "%s.v%d" % (op.name, self.value_index)
+ return f"{op.name}.v{self.value_index}"
class Operation(Object):
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index 11db28e20a..764b8a3dd3 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -93,7 +93,7 @@ class Buffer(Object, Scriptable):
elif value == "w":
mask = mask | Buffer.WRITE
else:
- raise ValueError("Unknown access_mask %s" % access_mask)
+ raise ValueError(f"Unknown access_mask {access_mask}")
access_mask = mask
offset = convert(offset)
extent = convert(extent)
@@ -179,11 +179,7 @@ class Buffer(Object, Scriptable):
def __getitem__(self, indices):
from ..arith import Analyzer # pylint: disable=import-outside-toplevel
- from .expr import ( # pylint: disable=import-outside-toplevel
- BufferLoad,
- Ramp,
- const,
- )
+ from .expr import BufferLoad, Ramp, const # pylint: disable=import-outside-toplevel
from .stmt import BufferRegion # pylint: disable=import-outside-toplevel
if not isinstance(indices, (tuple, list)):
@@ -344,7 +340,7 @@ def decl_buffer(
if offset_factor != 0 and elem_offset is None:
shape_dtype = shape[0].dtype if shape and hasattr(shape[0], "dtype") else "int32"
- elem_offset = Var("%s_elem_offset" % name, shape_dtype)
+ elem_offset = Var(f"{name}_elem_offset", shape_dtype)
if data is None:
# Bool is represented as uint1 in the IR, but stored as int8
storage_type = PrimType(dtype)
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index ce8cd1b403..50de995a91 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -112,7 +112,7 @@ class BufferVar(ObjectGeneric):
content_element = self._content_type.split("x", maxsplit=1)[0]
if value_element != content_element:
raise ValueError(
- "data type does not match content type %s vs %s" % (value.dtype, self._content_type)
+ f"data type does not match content type {value.dtype} vs {self._content_type}"
)
self._builder.emit(_stmt.BufferStore(self._buffer, value, index))
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 5cfb12b1f4..8996ad49a6 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -65,11 +65,7 @@ ExprRV = Union[PrimExpr] # A random variable that evaluates to an integer
RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # pylint: disable=invalid-name
# Update to `Literal["detail", "fast", "none"]` once upgraded to python3.8
-_ERROR_RENDER_LEVEL: Dict[str, int] = {
- "detail": 0,
- "fast": 1,
- "none": 2,
-}
+_ERROR_RENDER_LEVEL: Dict[str, int] = {"detail": 0, "fast": 1, "none": 2}
def _parse_error_render_level(error_render_level: str) -> int:
@@ -83,9 +79,7 @@ def _parse_error_render_level(error_render_level: str) -> int:
def _parse_enable_checks(enable_checks: bool) -> bool:
if not isinstance(enable_checks, bool):
- raise TypeError(
- "enable_checks only accepts bool value, got {} instead".format(type(enable_checks))
- )
+ raise TypeError(f"enable_checks only accepts bool value, got {type(enable_checks)} instead")
return enable_checks
@@ -290,8 +284,7 @@ class Schedule(Object):
@type_checked
def get(
- self,
- rand_var_or_sref: Union[RAND_VAR_TYPE, StmtSRef],
+ self, rand_var_or_sref: Union[RAND_VAR_TYPE, StmtSRef]
) -> Optional[Union[int, Block, For]]:
"""Returns:
- the corresponding Block that a BlockRV evaluates to;
@@ -312,7 +305,8 @@ class Schedule(Object):
"""
if isinstance(rand_var_or_sref, StmtSRef):
return rand_var_or_sref.stmt
- result = _ffi_api.ScheduleGet(self, rand_var_or_sref) # type: ignore # pylint: disable=no-member
+ # pylint: disable-next=no-member
+ result = _ffi_api.ScheduleGet(self, rand_var_or_sref) # type: ignore
if isinstance(result, IntImm):
result = result.value
return result
@@ -354,10 +348,7 @@ class Schedule(Object):
@type_checked
def sample_categorical(
- self,
- candidates: List[int],
- probs: List[float],
- decision: Optional[int] = None,
+ self, candidates: List[int], probs: List[float], decision: Optional[int] = None
) -> ExprRV:
"""Sample an integer given the probability distribution
@@ -376,10 +367,7 @@ class Schedule(Object):
The random variable sampled from candidates
"""
return _ffi_api.ScheduleSampleCategorical( # type: ignore # pylint: disable=no-member
- self,
- candidates,
- probs,
- decision,
+ self, candidates, probs, decision
)
@type_checked
@@ -410,19 +398,13 @@ class Schedule(Object):
"""
return list(
_ffi_api.ScheduleSamplePerfectTile( # type: ignore # pylint: disable=no-member
- self,
- loop,
- n,
- max_innermost_factor,
- decision,
+ self, loop, n, max_innermost_factor, decision
)
)
@type_checked
def sample_compute_location(
- self,
- block: Union[BlockRV, str],
- decision: Optional[int] = None,
+ self, block: Union[BlockRV, str], decision: Optional[int] = None
) -> LoopRV:
"""Sample a compute-at location of the given block
@@ -441,18 +423,12 @@ class Schedule(Object):
block = self._normalize_block_arg(block)
return _ffi_api.ScheduleSampleComputeLocation( # type: ignore # pylint: disable=no-member
- self,
- block,
- decision,
+ self, block, decision
)
########## Schedule: Get blocks & loops ##########
@type_checked
- def get_block(
- self,
- name: str,
- func_name: Optional[str] = None,
- ) -> BlockRV:
+ def get_block(self, name: str, func_name: Optional[str] = None) -> BlockRV:
"""Retrieve a block in a specific function with its name
By default, if `func_name` is not specified, the schedule will search for the block in the
@@ -473,9 +449,7 @@ class Schedule(Object):
IndexError is raised if 0 or multiple blocks exist with the specific name.
"""
return _ffi_api.ScheduleGetBlock( # type: ignore # pylint: disable=no-member
- self,
- name,
- func_name,
+ self, name, func_name
)
@type_checked
@@ -493,7 +467,8 @@ class Schedule(Object):
A list of loops above the given block in its scope, from outer to inner
"""
block = self._normalize_block_arg(block)
- return list(_ffi_api.ScheduleGetLoops(self, block)) # type: ignore # pylint: disable=no-member
+ # pylint: disable-next=no-member
+ return list(_ffi_api.ScheduleGetLoops(self, block)) # type: ignore
@type_checked
def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) -> List[BlockRV]:
@@ -509,7 +484,8 @@ class Schedule(Object):
blocks : List[LoopRV]
A list of leaf blocks inside a specific block/loop
"""
- return list(_ffi_api.ScheduleGetChildBlocks(self, block_or_loop)) # type: ignore # pylint: disable=no-member
+ # pylint: disable-next=no-member
+ return list(_ffi_api.ScheduleGetChildBlocks(self, block_or_loop)) # type: ignore
@type_checked
def get_producers(self, block: Union[BlockRV, str]) -> List[BlockRV]:
@@ -526,7 +502,8 @@ class Schedule(Object):
A list of producers of the given block
"""
block = self._normalize_block_arg(block)
- return list(_ffi_api.ScheduleGetProducers(self, block)) # type: ignore # pylint: disable=no-member
+ # pylint: disable-next=no-member
+ return list(_ffi_api.ScheduleGetProducers(self, block)) # type: ignore
@type_checked
def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]:
@@ -543,13 +520,11 @@ class Schedule(Object):
A list of consumers of the given block
"""
block = self._normalize_block_arg(block)
- return list(_ffi_api.ScheduleGetConsumers(self, block)) # type: ignore # pylint: disable=no-member
+ # pylint: disable-next=no-member
+ return list(_ffi_api.ScheduleGetConsumers(self, block)) # type: ignore
@type_checked
- def get_output_blocks(
- self,
- scope_block: Union[BlockRV, str],
- ) -> List[BlockRV]:
+ def get_output_blocks(self, scope_block: Union[BlockRV, str]) -> List[BlockRV]:
"""Get the list of output blocks within the given scope
An output block is a block which has atleast one buffer being written
to, but is not allocated within the PrimFunc
@@ -566,14 +541,12 @@ class Schedule(Object):
"""
scope_block = self._normalize_block_arg(scope_block)
- return list(_ffi_api.ScheduleGetOutputBlocks(self, scope_block)) # type: ignore # pylint: disable=no-member
+ # pylint: disable-next=no-member
+ return list(_ffi_api.ScheduleGetOutputBlocks(self, scope_block)) # type: ignore
########## Schedule: Transform loops ##########
@type_checked
- def merge(
- self,
- *loops: List[LoopRV],
- ) -> LoopRV:
+ def merge(self, *loops: List[LoopRV]) -> LoopRV:
"""Merge a list of loops into one. The loops under their LCA requires:
1) Under the same scope.
2) Can't have annotations or thread bindings.
@@ -648,11 +621,7 @@ class Schedule(Object):
return _ffi_api.ScheduleMerge(self, loops) # type: ignore # pylint: disable=no-member
@type_checked
- def fuse(
- self,
- *loops: List[LoopRV],
- preserve_unit_iters: bool = True,
- ) -> LoopRV:
+ def fuse(self, *loops: List[LoopRV], preserve_unit_iters: bool = True) -> LoopRV:
"""Fuse a list of consecutive loops into one. It requires:
1) The loops can't have annotations or thread bindings.
2) The (i+1)-th loop must be the only child of the i-th loop.
@@ -710,7 +679,8 @@ class Schedule(Object):
B[vi, vj] = A[vi, vj] * 2.0
"""
- return _ffi_api.ScheduleFuse(self, loops, preserve_unit_iters) # type: ignore # pylint: disable=no-member
+ # pylint: disable-next=no-member
+ return _ffi_api.ScheduleFuse(self, loops, preserve_unit_iters) # type: ignore
@type_checked
def split(
@@ -791,10 +761,7 @@ class Schedule(Object):
# that there is at most one None in `factors`
return list(
_ffi_api.ScheduleSplit( # type: ignore # pylint: disable=no-member
- self,
- loop,
- factors,
- preserve_unit_iters,
+ self, loop, factors, preserve_unit_iters
)
)
@@ -920,7 +887,8 @@ class Schedule(Object):
--------
reorder
"""
- _ffi_api.ScheduleReorderBlockIterVar(self, block, new_order) # type: ignore # pylint: disable=no-member
+ # pylint: disable-next=no-member
+ _ffi_api.ScheduleReorderBlockIterVar(self, block, new_order) # type: ignore
@type_checked
def add_unit_loop(self, block_or_loop: Union[LoopRV, BlockRV]) -> LoopRV:
@@ -976,7 +944,8 @@ class Schedule(Object):
vi = T.axis.spatial(1, 0)
C[()] = A[()] + B[()]
"""
- return _ffi_api.ScheduleAddUnitLoop(self, block_or_loop) # type: ignore # pylint: disable=no-member
+ # pylint: disable-next=no-member
+ return _ffi_api.ScheduleAddUnitLoop(self, block_or_loop) # type: ignore
########## Schedule: Manipulate ForKind ##########
@@ -1676,10 +1645,7 @@ class Schedule(Object):
@type_checked
def cache_index(
- self,
- block: Union[BlockRV, str],
- storage_scope: str,
- cse_thresh: int = 0,
+ self, block: Union[BlockRV, str], storage_scope: str, cse_thresh: int = 0
) -> List[BlockRV]:
"""Create a block to cache precomputed index for later use.
if there is no index computation, keep unchanged.
@@ -1765,9 +1731,7 @@ class Schedule(Object):
@type_checked
def reindex(
- self,
- block: Union[BlockRV, str],
- buffer: Union[Tuple[str, int], str, Buffer],
+ self, block: Union[BlockRV, str], buffer: Union[Tuple[str, int], str, Buffer]
) -> BlockRV:
"""Create a block that read/write a buffer region into a read/write cache with reindexing.
The layout of the cache will be the same as by the iterators of the block that reads/writes
@@ -1860,22 +1824,14 @@ class Schedule(Object):
########## Schedule: Data movement ##########
def read_at(
- self,
- loop: LoopRV,
- block: BlockRV,
- read_buffer_index: int,
- storage_scope: str,
+ self, loop: LoopRV, block: BlockRV, read_buffer_index: int, storage_scope: str
) -> BlockRV:
return _ffi_api.ScheduleReadAt( # type: ignore # pylint: disable=no-member
self, loop, block, read_buffer_index, storage_scope
)
def write_at(
- self,
- loop: LoopRV,
- block: BlockRV,
- write_buffer_index: int,
- storage_scope: str,
+ self, loop: LoopRV, block: BlockRV, write_buffer_index: int, storage_scope: str
) -> BlockRV:
return _ffi_api.ScheduleWriteAt( # type: ignore # pylint: disable=no-member
self, loop, block, write_buffer_index, storage_scope
@@ -1978,11 +1934,7 @@ class Schedule(Object):
"""
block = self._normalize_block_arg(block)
_ffi_api.ScheduleComputeAt( # type: ignore # pylint: disable=no-member
- self,
- block,
- loop,
- preserve_unit_loops,
- index,
+ self, block, loop, preserve_unit_loops, index
)
@type_checked
@@ -2077,11 +2029,7 @@ class Schedule(Object):
"""
block = self._normalize_block_arg(block)
_ffi_api.ScheduleReverseComputeAt( # type: ignore # pylint: disable=no-member
- self,
- block,
- loop,
- preserve_unit_loops,
- index,
+ self, block, loop, preserve_unit_loops, index
)
@type_checked
@@ -2215,7 +2163,8 @@ class Schedule(Object):
"""
block = self._normalize_block_arg(block)
- _ffi_api.ScheduleReverseComputeInline(self, block) # type: ignore # pylint: disable=no-member
+ # pylint: disable-next=no-member
+ _ffi_api.ScheduleReverseComputeInline(self, block) # type: ignore
########## Schedule: Reduction ##########
@@ -2295,7 +2244,8 @@ class Schedule(Object):
"""
block = self._normalize_block_arg(block)
- return _ffi_api.ScheduleDecomposeReduction(self, block, loop) # type: ignore # pylint: disable=no-member
+ # pylint: disable-next=no-member
+ return _ffi_api.ScheduleDecomposeReduction(self, block, loop) # type: ignore
@type_checked
def rfactor(self, loop: LoopRV, factor_axis: int) -> BlockRV:
@@ -2444,18 +2394,14 @@ class Schedule(Object):
where `B` is the buffer that the reduction block writes to.
Negative indexing is normalized according to numpy convention.
"""
- return _ffi_api.ScheduleRFactor(self, loop, factor_axis) # type: ignore # pylint: disable=no-member
+ # pylint: disable-next=no-member
+ return _ffi_api.ScheduleRFactor(self, loop, factor_axis) # type: ignore
######## Schedule: Block annotation ########
@type_checked
def storage_align( # pylint: disable=too-many-arguments
- self,
- block: Union[BlockRV, str],
- buffer_index: int,
- axis: int,
- factor: int,
- offset: int,
+ self, block: Union[BlockRV, str], buffer_index: int, axis: int, factor: int, offset: int
) -> None:
"""Set alignment requirement for specific dimension such that
stride[axis] == k * factor + offset for some k. This is useful to set memory layout for more
@@ -2766,7 +2712,8 @@ class Schedule(Object):
block are divisible by the subspace represented by the loops starting at the given loop.
"""
- return _ffi_api.ScheduleBlockize(self, target, preserve_unit_iters) # type: ignore # pylint: disable=no-member
+ # pylint: disable-next=no-member
+ return _ffi_api.ScheduleBlockize(self, target, preserve_unit_iters) # type: ignore
@type_checked
def tensorize(
@@ -2932,10 +2879,7 @@ class Schedule(Object):
@type_checked
def annotate(
- self,
- block_or_loop: Union[BlockRV, LoopRV],
- ann_key: str,
- ann_val: AnnotationValueT,
+ self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str, ann_val: AnnotationValueT
) -> None:
"""Annotate a block/loop with a key value pair
@@ -3102,7 +3046,7 @@ class Schedule(Object):
elif isinstance(buffer, tuple):
buffer_index_type, buffer_index = buffer
- assert buffer_index_type in ["read", "write",], (
+ assert buffer_index_type in ["read", "write"], (
f"Invalid buffer_index_type. "
f"Expected 'read' or 'write', "
f"but received {buffer_index_type}"
@@ -3300,9 +3244,7 @@ class Schedule(Object):
@type_checked
def transform_block_layout(
- self,
- block: Union[BlockRV, str],
- index_map: Union[IndexMap, Callable],
+ self, block: Union[BlockRV, str], index_map: Union[IndexMap, Callable]
) -> None:
"""Apply a transformation represented by IndexMap to block
@@ -3534,7 +3476,8 @@ class Schedule(Object):
@type_checked
def can_decompose_padding(self, block: Union[BlockRV, str], loop: LoopRV) -> bool:
"""Check whether the block match padding pattern and can be decomposed."""
- return _ffi_api.CanDecomposePadding(self, block, loop) # type: ignore # pylint: disable=no-member
+ # pylint: disable-next=no-member
+ return _ffi_api.CanDecomposePadding(self, block, loop) # type: ignore
@type_checked
def pad_einsum(self, block: Union[BlockRV, str], padding: List[int]) -> None:
@@ -3661,11 +3604,7 @@ class Schedule(Object):
######## Schedule: Buffer transformation ########
@type_checked
- def rolling_buffer(
- self,
- block: Union[BlockRV, str],
- write_buffer_index: int,
- ) -> None:
+ def rolling_buffer(self, block: Union[BlockRV, str], write_buffer_index: int) -> None:
"""Compute the target buffer via rolling buffering, select the outermost rollable
axis with a positive bound overlap that appears in the block's ancestor loops
as `rolling axis`, fold and circularize the buffer along the rolling dimension,
@@ -3764,7 +3703,8 @@ class Schedule(Object):
The region_cover property of the consumer block of the target buffer will become false.
"""
block = self._normalize_block_arg(block)
- return _ffi_api.ScheduleRollingBuffer(self, block, write_buffer_index) # type: ignore # pylint: disable=no-member
+ # pylint: disable-next=no-member
+ return _ffi_api.ScheduleRollingBuffer(self, block, write_buffer_index) # type: ignore
########## Schedule: Misc ##########
diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py
index 8b49a4bcfc..0e67411103 100644
--- a/python/tvm/tir/schedule/state.py
+++ b/python/tvm/tir/schedule/state.py
@@ -73,9 +73,7 @@ def _parse_debug_mask(debug_mask: Union[str, int]) -> int:
def _parse_enable_checks(enable_checks: bool) -> bool:
if not isinstance(enable_checks, bool):
- raise TypeError(
- "enable_checks only accepts bool value, got {} instead".format(type(enable_checks))
- )
+ raise TypeError(f"enable_checks only accepts bool value, got {type(enable_checks)} instead")
return enable_checks
@@ -235,8 +233,5 @@ class ScheduleState(Object):
if block_sref_reuse is None:
block_sref_reuse = {}
_ffi_api.ScheduleStateReplace( # type: ignore # pylint: disable=no-member
- self,
- src_sref,
- tgt_stmt,
- block_sref_reuse,
+ self, src_sref, tgt_stmt, block_sref_reuse
)
diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py
index 521d882e24..c518f64f5a 100644
--- a/python/tvm/tir/tensor_intrin/arm_cpu.py
+++ b/python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -108,9 +108,9 @@ def get_dotprod_intrin(in_dtype, out_dtype):
else: # if in_dtype == "int8"
instr = "sdot.v4i32.v16i8"
- in_dtype_x4 = "{TYPE}x4".format(TYPE=in_dtype)
- out_dtype_x4 = "{TYPE}x4".format(TYPE=out_dtype)
- in_dtype_x16 = "{TYPE}x16".format(TYPE=in_dtype)
+ in_dtype_x4 = f"{in_dtype}x4"
+ out_dtype_x4 = f"{out_dtype}x4"
+ in_dtype_x16 = f"{in_dtype}x16"
@T.prim_func
def dot_prod_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
@@ -147,7 +147,7 @@ def get_dotprod_intrin(in_dtype, out_dtype):
vec_c = C.vload([0], dtype=out_dtype_x4)
C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
- T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.{INSTR}".format(INSTR=instr)),
+ T.llvm_lookup_intrinsic_id(f"llvm.aarch64.neon.{instr}"),
T.uint32(3),
vec_c,
vec_a,
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 8d12a39ca7..5d87f1954c 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -566,7 +566,7 @@ def get_wmma_load_intrin(
is_col_major: bool,
) -> Tuple[PrimFunc, PrimFunc]:
"""Generator of wmma_load intrins"""
- wmma_fragment_scope = "wmma.matrix_{}".format("b" if is_b else "a")
+ wmma_fragment_scope = f"wmma.matrix_{'b' if is_b else 'a'}"
layout = "col_major" if is_col_major else "row_major"
offset_factor = get_tensor_core_load_offset_factor(dtype)
@@ -906,8 +906,7 @@ TensorIntrin.register(
WMMA_SYNC_16x16x16_s8s8s32_INTRIN = "wmma_sync_16x16x16_s8s8s32"
TensorIntrin.register(
- WMMA_SYNC_16x16x16_s8s8s32_INTRIN,
- *get_wmma_sync_intrin(16, 16, 16, "int8", "int32", False),
+ WMMA_SYNC_16x16x16_s8s8s32_INTRIN, *get_wmma_sync_intrin(16, 16, 16, "int8", "int32", False)
)
WMMA_SYNC_16x16x16_s8s8s32_TRANS_INTRIN = "wmma_sync_16x16x16_s8s8s32_trans"
@@ -918,8 +917,7 @@ TensorIntrin.register(
WMMA_SYNC_8x8x32_s4s4s32_TRANS_INTRIN = "wmma_sync_8x8x32_s4s4s32_trans"
TensorIntrin.register(
- WMMA_SYNC_8x8x32_s4s4s32_TRANS_INTRIN,
- *get_wmma_sync_intrin(8, 8, 32, "int4", "int32", True),
+ WMMA_SYNC_8x8x32_s4s4s32_TRANS_INTRIN, *get_wmma_sync_intrin(8, 8, 32, "int4", "int32", True)
)
WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a_shared"
@@ -984,8 +982,7 @@ TensorIntrin.register(
WMMA_LOAD_16x16x16_S8_B_INTRIN = "wmma_load_16x16x16_s8_b_shared"
TensorIntrin.register(
- WMMA_LOAD_16x16x16_S8_B_INTRIN,
- *get_wmma_load_intrin(16, 16, 16, "int8", "shared", True, False),
+ WMMA_LOAD_16x16x16_S8_B_INTRIN, *get_wmma_load_intrin(16, 16, 16, "int8", "shared", True, False)
)
WMMA_LOAD_16x16x16_S8_B_DYN_INTRIN = "wmma_load_16x16x16_s8_b_shared_dyn"
@@ -1020,8 +1017,7 @@ TensorIntrin.register(
WMMA_LOAD_8x8x32_S4_A_INTRIN = "wmma_load_8x8x32_s4_a_shared"
TensorIntrin.register(
- WMMA_LOAD_8x8x32_S4_A_INTRIN,
- *get_wmma_load_intrin(8, 8, 32, "int4", "shared", False, False),
+ WMMA_LOAD_8x8x32_S4_A_INTRIN, *get_wmma_load_intrin(8, 8, 32, "int4", "shared", False, False)
)
WMMA_LOAD_8x8x32_S4_A_DYN_INTRIN = "wmma_load_8x8x32_s4_a_shared_dyn"
@@ -1094,8 +1090,7 @@ TensorIntrin.register(
WMMA_STORE_8x8x32_S32_SHARED_DYN_INTRIN = "wmma_store_8x8x32_s32_shared_dyn"
TensorIntrin.register(
- WMMA_STORE_8x8x32_S32_SHARED_DYN_INTRIN,
- *get_wmma_store_intrin(8, 8, 32, "int32", "shared.dyn"),
+ WMMA_STORE_8x8x32_S32_SHARED_DYN_INTRIN, *get_wmma_store_intrin(8, 8, 32, "int32", "shared.dyn")
)
WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN = "wmma_store_16x16x16_f32_global"