This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 48200fc3d7 [TOPI] Use f-strings for string formatting, NFC (#14822)
48200fc3d7 is described below
commit 48200fc3d789efc62d5a0aa872280985aaf88b4b
Author: Krzysztof Parzyszek <[email protected]>
AuthorDate: Thu May 11 02:13:10 2023 -0500
[TOPI] Use f-strings for string formatting, NFC (#14822)
* [TOPI] Use f-strings for string formatting, NFC
Replace uses of % and .format() with f-strings.
* Format updated files
---
python/tvm/topi/arm_cpu/conv2d.py | 15 ++----
python/tvm/topi/arm_cpu/conv2d_alter_op.py | 26 +++++-----
python/tvm/topi/cuda/batch_matmul.py | 2 +-
python/tvm/topi/cuda/conv2d_int8.py | 4 +-
python/tvm/topi/cuda/conv3d_direct.py | 2 +-
python/tvm/topi/cuda/dense.py | 18 ++-----
python/tvm/topi/cuda/group_conv2d_nchw.py | 12 ++---
python/tvm/topi/cuda/scan.py | 2 +-
python/tvm/topi/cuda/softmax.py | 5 +-
python/tvm/topi/cuda/sparse.py | 7 ++-
python/tvm/topi/generic/conv2d.py | 34 ++++++-------
python/tvm/topi/hexagon/conv2d_alter_op.py | 6 +--
python/tvm/topi/hexagon/qnn/conv2d_alter_op.py | 6 +--
python/tvm/topi/hls/nn.py | 15 +++---
python/tvm/topi/image/resize.py | 55 ++++----------------
python/tvm/topi/nn/conv2d.py | 70 +++++++-------------------
python/tvm/topi/nn/depthwise_conv2d.py | 32 ++++--------
python/tvm/topi/nn/fifo_buffer.py | 10 ++--
python/tvm/topi/nn/upsampling.py | 4 +-
python/tvm/topi/nn/utils.py | 17 ++++---
python/tvm/topi/nn/winograd_util.py | 4 +-
python/tvm/topi/reduction.py | 2 +-
python/tvm/topi/testing/poolnd_python.py | 12 ++---
python/tvm/topi/testing/resize_python.py | 2 +-
python/tvm/topi/transform.py | 11 ++--
python/tvm/topi/x86/conv2d_alter_op.py | 18 +++----
python/tvm/topi/x86/conv2d_avx_1x1.py | 4 +-
python/tvm/topi/x86/conv2d_int8.py | 6 +--
python/tvm/topi/x86/conv3d.py | 31 +++---------
python/tvm/topi/x86/nn.py | 5 +-
30 files changed, 143 insertions(+), 294 deletions(-)
diff --git a/python/tvm/topi/arm_cpu/conv2d.py
b/python/tvm/topi/arm_cpu/conv2d.py
index ab489161a8..a478818084 100644
--- a/python/tvm/topi/arm_cpu/conv2d.py
+++ b/python/tvm/topi/arm_cpu/conv2d.py
@@ -33,10 +33,7 @@ from .conv2d_spatial_pack import (
schedule_conv2d_spatial_pack_nchw,
schedule_conv2d_spatial_pack_nhwc,
)
-from .mprofile.dsp.conv2d import (
- conv2d_nhwc_dsp_compute,
- conv2d_nhwc_dsp_schedule,
-)
+from .mprofile.dsp.conv2d import conv2d_nhwc_dsp_compute,
conv2d_nhwc_dsp_schedule
@autotvm.register_topi_compute("conv2d_nchw_spatial_pack.arm_cpu")
@@ -267,13 +264,7 @@ def _schedule_winograd(cfg, s, output, last):
if isinstance(U.op, tvm.te.ComputeOp):
kernel, G = U.op.input_tensors
s[G].compute_inline()
- (
- eps,
- nu,
- k,
- c,
- kk,
- ) = s[U].op.axis
+ (eps, nu, k, c, kk) = s[U].op.axis
if autotvm.GLOBAL_SCOPE.in_tuning:
# kernel transformation will be pre-computed during compilation,
so we skip
# this part to make tuning records correct
@@ -364,7 +355,7 @@ def conv2d_nchw_winograd_nnpack(cfg, data, kernel, strides,
padding, dilation, o
tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8_FP16,
)
else:
- raise ValueError("Unsupported data type {} for conv2d winograd
nnpack".format(dtype))
+ raise ValueError(f"Unsupported data type {dtype} for conv2d winograd
nnpack")
@autotvm.register_topi_schedule("conv2d_nchw_winograd_nnpack.arm_cpu")
diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py
b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
index d4878f4b69..b0fdb99cbe 100644
--- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py
+++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -152,9 +152,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
)
dispatch_ctx.update(target, new_workload, cfg)
return relay.nn.conv2d(
- inputs[0],
- relay.Constant(tvm.nd.array(reshaped_new_kernel)),
- **new_attrs,
+ inputs[0], relay.Constant(tvm.nd.array(reshaped_new_kernel)),
**new_attrs
)
# Only microTVM does layout alteration for NHWC layout with real data types
@@ -167,7 +165,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
CO, _, KH, KW = get_const_tuple(kernel.shape)
VC = cfg["tile_co"].size[-1]
- new_attrs["kernel_layout"] = "OIHW%do" % VC
+ new_attrs["kernel_layout"] = f"OIHW{VC}o"
new_data = data
new_kernel = te.placeholder((idxd(CO, VC), CI, KH, KW, VC),
dtype=kernel.dtype)
@@ -275,7 +273,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
CO, M, KH, KW = get_const_tuple(kernel.shape)
VC = cfg["tile_co"].size[-1]
- new_attrs["kernel_layout"] = "OIHW%do" % (cfg["tile_co"].size[-1])
+ new_attrs["kernel_layout"] = f"OIHW{cfg['tile_co'].size[-1]}o"
# Store the same config for the altered operator (workload)
new_data = data
@@ -309,10 +307,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
# update new attrs
new_attrs["channels"] = out_channel
- new_attrs["data_layout"] = "NCHW%dc" % ic_bn
+ new_attrs["data_layout"] = f"NCHW{ic_bn}c"
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
- new_attrs["kernel_layout"] = "OIHW%di%do" % (ic_bn, oc_bn)
- new_attrs["out_layout"] = "NCHW%dc" % oc_bn
+ new_attrs["kernel_layout"] = f"OIHW{ic_bn}i{oc_bn}o"
+ new_attrs["out_layout"] = f"NCHW{oc_bn}c"
# Store altered operator's config
new_data = te.placeholder(
@@ -353,9 +351,9 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
# update new attrs
new_attrs["channels"] = out_channel
- new_attrs["data_layout"] = "NCHW%dc" % ic_bn
- new_attrs["kernel_layout"] = "OIHW1i%do" % oc_bn
- new_attrs["out_layout"] = "NCHW%dc" % oc_bn
+ new_attrs["data_layout"] = f"NCHW{ic_bn}c"
+ new_attrs["kernel_layout"] = f"OIHW1i{oc_bn}o"
+ new_attrs["out_layout"] = f"NCHW{oc_bn}c"
# Store altered operator's config.
new_data = te.placeholder(
@@ -407,9 +405,9 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
# update new attrs
new_attrs["channels"] = out_channel
- new_attrs["data_layout"] = "NCHW%dc" % ic_bn
- new_attrs["kernel_layout"] = "OIHW{:n}i{:n}o{:n}i".format(ic_bn //
n_elems, oc_bn, n_elems)
- new_attrs["out_layout"] = "NCHW%dc" % oc_bn
+ new_attrs["data_layout"] = f"NCHW{ic_bn}c"
+ new_attrs["kernel_layout"] = f"OIHW{ic_bn //
n_elems:n}i{oc_bn:n}o{n_elems:n}i"
+ new_attrs["out_layout"] = f"NCHW{oc_bn}c"
# Store altered operator's config.
new_data = te.placeholder(
diff --git a/python/tvm/topi/cuda/batch_matmul.py
b/python/tvm/topi/cuda/batch_matmul.py
index d2f5c9b9c5..83b000a4b9 100644
--- a/python/tvm/topi/cuda/batch_matmul.py
+++ b/python/tvm/topi/cuda/batch_matmul.py
@@ -342,7 +342,7 @@ def _schedule_batch_matmul_int8(cfg, s, output):
_, N, _ = get_const_tuple(input_y.shape)
k_factor = 4
- assert K % k_factor == 0, "Input dimension must divide {}".format(k_factor)
+ assert K % k_factor == 0, f"Input dimension must divide {k_factor}"
if K % 16 == 0:
k_factor = 16
diff --git a/python/tvm/topi/cuda/conv2d_int8.py
b/python/tvm/topi/cuda/conv2d_int8.py
index 0edd64e0e3..b959136999 100644
--- a/python/tvm/topi/cuda/conv2d_int8.py
+++ b/python/tvm/topi/cuda/conv2d_int8.py
@@ -90,7 +90,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding,
dilation, layout, out_
batch, channels, height, width = get_const_tuple(data.shape)
assert (
channels % ic_block_factor == 0
- ), "Number of input channels should be multiple of
{}".format(ic_block_factor)
+ ), f"Number of input channels should be multiple of {ic_block_factor}"
packed_data = te.compute(
(batch, channels // ic_block_factor, height, width,
ic_block_factor),
lambda n, c, h, w, vc: data[n, c * ic_block_factor + vc, h, w],
@@ -100,7 +100,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding,
dilation, layout, out_
out_channels, in_channels, kernel_h, kernel_w =
get_const_tuple(kernel.shape)
assert (
out_channels % oc_block_factor == 0
- ), "Number of output channels should be multiple of
{}".format(oc_block_factor)
+ ), f"Number of output channels should be multiple of {oc_block_factor}"
packed_kernel = te.compute(
(
out_channels // oc_block_factor,
diff --git a/python/tvm/topi/cuda/conv3d_direct.py
b/python/tvm/topi/cuda/conv3d_direct.py
index faccb75bad..2a8e573621 100644
--- a/python/tvm/topi/cuda/conv3d_direct.py
+++ b/python/tvm/topi/cuda/conv3d_direct.py
@@ -31,7 +31,7 @@ def schedule_direct_conv3d_cuda(cfg, s, conv, layout,
workload_name):
elif layout == "NDHWC":
n, d, y, x, f = s[conv].op.axis
else:
- raise ValueError("not support this layout {} yet".format(layout))
+ raise ValueError(f"not support this layout {layout} yet")
rc, rd, ry, rx = s[conv].op.reduce_axis
cfg.define_split("tile_f", f, num_outputs=4)
cfg.define_split("tile_d", d, num_outputs=4)
diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py
index 258b22d5c8..fa2c4a0f9d 100644
--- a/python/tvm/topi/cuda/dense.py
+++ b/python/tvm/topi/cuda/dense.py
@@ -29,13 +29,7 @@ logger = logging.getLogger("topi")
def _matmul_cublas_common(
- cfg,
- tensor_a,
- tensor_b,
- bias=None,
- out_dtype=None,
- transpose_a=False,
- transpose_b=False,
+ cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False,
transpose_b=False
):
assert len(tensor_a.shape) == 2 and len(tensor_b.shape) == 2, "only
support 2-dim matmul"
if bias is not None:
@@ -58,13 +52,7 @@ def _matmul_cublas_common(
@autotvm.register_topi_compute("matmul_cublas.cuda")
def matmul_cublas(
- cfg,
- tensor_a,
- tensor_b,
- bias=None,
- out_dtype=None,
- transpose_a=False,
- transpose_b=False,
+ cfg, tensor_a, tensor_b, bias=None, out_dtype=None, transpose_a=False,
transpose_b=False
):
"""Matmul operator on CUDA with CUBLAS"""
return _matmul_cublas_common(cfg, tensor_a, tensor_b, bias, out_dtype,
transpose_a, transpose_b)
@@ -142,7 +130,7 @@ def _schedule_dense_int8(cfg, s, output):
out_dim, _ = get_const_tuple(weight.shape)
in_dim_factor = 4
- assert in_dim % in_dim_factor == 0, "Input dimension must divide
{}".format(in_dim_factor)
+ assert in_dim % in_dim_factor == 0, f"Input dimension must divide
{in_dim_factor}"
if in_dim % 16 == 0:
in_dim_factor = 16
diff --git a/python/tvm/topi/cuda/group_conv2d_nchw.py
b/python/tvm/topi/cuda/group_conv2d_nchw.py
index b48ea3a5f8..ba0fa8a4c4 100644
--- a/python/tvm/topi/cuda/group_conv2d_nchw.py
+++ b/python/tvm/topi/cuda/group_conv2d_nchw.py
@@ -245,10 +245,10 @@ def group_conv2d_NCHWc_int8(
assert out_channels % groups == 0, "output channels must divide group
size"
assert (
channels % ic_block_factor == 0
- ), "Number of input channels per group must divide
{}".format(ic_block_factor)
+ ), f"Number of input channels per group must divide {ic_block_factor}"
assert (
out_channels % oc_block_factor == 0
- ), "Number of output channels per group must divide
{}".format(oc_block_factor)
+ ), f"Number of output channels per group must divide {oc_block_factor}"
packed_data = te.compute(
(batch, channels // ic_block_factor, height, width,
ic_block_factor),
@@ -282,14 +282,10 @@ def group_conv2d_NCHWc_int8(
# Shall we pad the channels to avoid raising assertions?
assert (
groups <= oc_chunk
- ), "Number of groups {} should be less than " "output channel chunk size
{}".format(
- groups, oc_chunk
- )
+ ), f"Number of groups {groups} should be less than output channel chunk
size {oc_chunk}"
assert (
groups <= ic_chunk
- ), "Number of groups {} should be less than " "input channel chunk size
{}".format(
- groups, ic_chunk
- )
+ ), f"Number of groups {groups} should be less than input channel chunk
size {ic_chunk}"
if isinstance(stride, int):
stride_h = stride_w = stride
diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py
index 3be13d7711..f697302961 100644
--- a/python/tvm/topi/cuda/scan.py
+++ b/python/tvm/topi/cuda/scan.py
@@ -31,7 +31,7 @@ from .injective import schedule_injective_from_existing
def _get_thrust_func_name(tvmop):
tvmop_to_thrust_func_name = {tvm.tir.generic.add:
"tvm.contrib.thrust.sum_scan"}
- assert tvmop in tvmop_to_thrust_func_name, "{} not supported by
thrust".format(tvmop)
+ assert tvmop in tvmop_to_thrust_func_name, f"{tvmop} not supported by
thrust"
return tvmop_to_thrust_func_name[tvmop]
diff --git a/python/tvm/topi/cuda/softmax.py b/python/tvm/topi/cuda/softmax.py
index a3c3e431e7..3919eac816 100644
--- a/python/tvm/topi/cuda/softmax.py
+++ b/python/tvm/topi/cuda/softmax.py
@@ -44,10 +44,7 @@ def _schedule_softmax(softmax_op, s, outs, tgt):
expsum = softmax_op.input_tensors[2]
else:
raise ValueError(
- "Tag is expected to be softmax_output or log_softmax_output. \
- Got {0}".format(
- op_tag
- )
+ f"Tag is expected to be softmax_output or log_softmax_output. Got
{op_tag}"
)
# The nvptx and rocm backends only supports 32-bits warp shuffle
diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py
index 6ee4f470f3..921075601e 100644
--- a/python/tvm/topi/cuda/sparse.py
+++ b/python/tvm/topi/cuda/sparse.py
@@ -159,10 +159,9 @@ def sparse_dense_tir(data, w_data, w_indices, w_indptr):
bs_m = bs_n
mb = m // bs_m
mi = warp_size
- assert (
- mb >= mi
- ), "Number of block rows in dense matrix must be larger than warp
size: {} vs {}.".format(
- warp_size, mb
+ assert mb >= mi, (
+ f"Number of block rows in dense matrix must be larger than warp
size: "
+ f"{warp_size} vs {mb}."
)
mo = ceil_div(mb, mi)
ni = 1 # TODO(tkonolige): how do I compute the number of warps per
block?
diff --git a/python/tvm/topi/generic/conv2d.py
b/python/tvm/topi/generic/conv2d.py
index a4a37247c8..189bdf9cbd 100644
--- a/python/tvm/topi/generic/conv2d.py
+++ b/python/tvm/topi/generic/conv2d.py
@@ -45,14 +45,12 @@ def fallback_schedule_cpu_common_int8(cfg, wkl,
int32_lanes, num_int8_elements):
dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1
out_width = (wkl.width + pl + pr - dilated_kernel_w) // WSTR + 1
- assert wkl.out_filter % int32_lanes == 0, "wkl.out_filter=%d,
int32_lanes=%d" % (
- wkl.out_filter,
- int32_lanes,
- )
- assert wkl.in_filter % num_int8_elements == 0, "wkl.in_filter=%d,
num_int8_elements=%d" % (
- wkl.in_filter,
- num_int8_elements,
- )
+ assert (
+ wkl.out_filter % int32_lanes == 0
+ ), f"wkl.out_filter={wkl.out_filter}, int32_lanes={int32_lanes}"
+ assert (
+ wkl.in_filter % num_int8_elements == 0
+ ), f"wkl.in_filter={wkl.in_filter}, num_int8_elements={num_int8_elements}"
oc_bn = int32_lanes if int32_lanes >= num_int8_elements else
num_int8_elements
ic_bn = 1
@@ -93,14 +91,12 @@ def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes,
num_int8_elements):
out_height = (wkl.height + pt + pb - wkl.kernel_h) // HSTR + 1
out_width = (wkl.width + pl + pr - wkl.kernel_w) // WSTR + 1
- assert wkl.out_filter % int32_lanes == 0, "wkl.out_filter=%d,
int32_lanes=%d" % (
- wkl.out_filter,
- int32_lanes,
- )
- assert wkl.in_filter % num_int8_elements == 0, "wkl.in_filter=%d,
num_int8_elements=%d" % (
- wkl.in_filter,
- num_int8_elements,
- )
+ assert (
+ wkl.out_filter % int32_lanes == 0
+ ), f"wkl.out_filter={wkl.out_filter}, int32_lanes={int32_lanes}"
+ assert (
+ wkl.in_filter % num_int8_elements == 0
+ ), f"wkl.in_filter={wkl.in_filter}, num_int8_elements={num_int8_elements}"
oc_bn = int32_lanes if int32_lanes >= num_int8_elements else
num_int8_elements
ic_bn = 1
@@ -118,7 +114,7 @@ def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes,
num_int8_elements):
cfg["tile_oh"] = OtherOptionEntity(oh_factor)
cfg["tile_ow"] = SplitEntity([out_width // ow_factor,
ow_factor])
return
- raise ValueError("cannot decide default schedule for workload:
{}".format(wkl))
+ raise ValueError(f"cannot decide default schedule for workload: {wkl}")
def schedule_conv_NCHWc_cpu_common_int8(
@@ -257,7 +253,7 @@ def schedule_conv_NCHWc_cpu_common_int8(
oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
else:
- raise ValueError("Unsupported output ndim: %s" % out_ndim)
+ raise ValueError(f"Unsupported output ndim: {out_ndim}")
parallel_axis = s[O].fuse(batch, oc_chunk, oh)
if inline_fused:
s[C].compute_at(s[O], ow_block)
@@ -382,7 +378,7 @@ def schedule_conv_NCHWc_cpu_1x1_int8(
oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
else:
- raise ValueError("Unsupported output ndim: %s" % out_ndim)
+ raise ValueError(f"Unsupported output ndim: {out_ndim}")
s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner,
oc_block)
parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
diff --git a/python/tvm/topi/hexagon/conv2d_alter_op.py
b/python/tvm/topi/hexagon/conv2d_alter_op.py
index 201b6f8043..a4affb8a82 100644
--- a/python/tvm/topi/hexagon/conv2d_alter_op.py
+++ b/python/tvm/topi/hexagon/conv2d_alter_op.py
@@ -51,9 +51,9 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
new_attrs = {k: attrs[k] for k in attrs.keys()}
new_attrs["channels"] = out_channel
- new_attrs["data_layout"] = "NCHW%dc" % ic_bn
- new_attrs["kernel_layout"] = "OIHW{:n}i{:n}o{:n}i".format(ic_bn //
n_elems, oc_bn, n_elems)
- new_attrs["out_layout"] = "NCHW%dc" % oc_bn
+ new_attrs["data_layout"] = f"NCHW{ic_bn}c"
+ new_attrs["kernel_layout"] = f"OIHW{ic_bn //
n_elems:n}i{oc_bn:n}o{n_elems:n}i"
+ new_attrs["out_layout"] = f"NCHW{oc_bn}c"
return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)
diff --git a/python/tvm/topi/hexagon/qnn/conv2d_alter_op.py
b/python/tvm/topi/hexagon/qnn/conv2d_alter_op.py
index 867a477956..b8240dccaf 100644
--- a/python/tvm/topi/hexagon/qnn/conv2d_alter_op.py
+++ b/python/tvm/topi/hexagon/qnn/conv2d_alter_op.py
@@ -44,9 +44,9 @@ def _alter_qnn_conv2d_layout(attrs, inputs, tinfos,
_out_type):
new_attrs = dict(attrs)
new_attrs["channels"] = out_channel
- new_attrs["data_layout"] = "NCHW%dc" % ic_bn
- new_attrs["kernel_layout"] = "OIHW{:n}i{:n}o{:n}i".format(ic_bn //
n_elems, oc_bn, n_elems)
- new_attrs["out_layout"] = "NCHW%dc" % oc_bn
+ new_attrs["data_layout"] = f"NCHW{ic_bn}c"
+ new_attrs["kernel_layout"] = f"OIHW{ic_bn //
n_elems:n}i{oc_bn:n}o{n_elems:n}i"
+ new_attrs["out_layout"] = f"NCHW{oc_bn}c"
return relay.qnn.op.conv2d(*inputs, **new_attrs)
diff --git a/python/tvm/topi/hls/nn.py b/python/tvm/topi/hls/nn.py
index b9053fe530..4d0f1f66d7 100644
--- a/python/tvm/topi/hls/nn.py
+++ b/python/tvm/topi/hls/nn.py
@@ -43,7 +43,7 @@ def _schedule_conv2d(outs):
Out = outs[0].op.output(0)
s[Conv2d].compute_at(s[Out], s[Out].op.axis[1])
else:
- raise RuntimeError("Unsupported operator: %s" % OP.tag)
+ raise RuntimeError(f"Unsupported operator: {OP.tag}")
traverse(outs[0].op)
@@ -223,7 +223,7 @@ def schedule_reduce(outs):
Out = outs[0].op.output(0)
s[Reduce].compute_at(s[Out], s[Out].op.axis[0])
else:
- raise RuntimeError("Unsupported operator: %s" % OP.tag)
+ raise RuntimeError(f"Unsupported operator: {OP.tag}")
traverse(outs[0].op)
@@ -264,10 +264,7 @@ def schedule_softmax(outs):
expsum = softmax.op.input_tensors[2]
else:
raise ValueError(
- "Tag is expected to be softmax_output or log_softmax_output. \
- Got {0}".format(
- op_tag
- )
+ f"Tag is expected to be softmax_output or log_softmax_output. Got
{op_tag}"
)
if exp is not None:
@@ -315,7 +312,7 @@ def schedule_dense(outs):
Out = outs[0].op.output(0)
s[Dense].compute_at(s[Out], s[Out].op.axis[1])
else:
- raise RuntimeError("Unsupported operator: %s" % OP.tag)
+ raise RuntimeError(f"Unsupported operator: {OP.tag}")
traverse(outs[0].op)
@@ -358,7 +355,7 @@ def schedule_pool(outs, layout):
Out = outs[0].op.output(0)
s[Pool].compute_at(s[Out], s[Out].op.axis[1])
else:
- raise RuntimeError("Unsupported operator: %s" % OP.tag)
+ raise RuntimeError(f"Unsupported operator: {OP.tag}")
traverse(outs[0].op)
@@ -401,7 +398,7 @@ def schedule_adaptive_pool(outs):
Out = outs[0].op.output(0)
s[Pool].compute_at(s[Out], s[Out].op.axis[1])
else:
- raise RuntimeError("Unsupported operator: %s" % OP.tag)
+ raise RuntimeError(f"Unsupported operator: {OP.tag}")
traverse(outs[0].op)
diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py
index 1973e0543a..29ed03f62e 100644
--- a/python/tvm/topi/image/resize.py
+++ b/python/tvm/topi/image/resize.py
@@ -168,7 +168,7 @@ def get_inx(
)
else:
raise ValueError(
- "Unsupported coordinate_transformation_mode:
{}".format(coordinate_transformation_mode)
+ f"Unsupported coordinate_transformation_mode:
{coordinate_transformation_mode}"
)
return in_x
@@ -194,7 +194,7 @@ def get_closest_index(in_x, rounding_method, boxes,
use_int_div=False):
epsilon = 1e-5
closest_x_index = te.ceil(in_x - epsilon).astype("int32")
else:
- raise ValueError("Uknown rounding method: {}".format(rounding_method))
+ raise ValueError(f"Unknown rounding method: {rounding_method}")
return closest_x_index
@@ -314,14 +314,7 @@ def _resize_1d(
if boxes is not None:
# TODO(mbrookhart): Find an example of this
raise NotImplementedError("resize1d with image boxes not yet
implemented")
- in_x = get_inx(
- x,
- image_width,
- target_width,
- coordinate_transformation_mode,
- roi[0],
- roi[1],
- )
+ in_x = get_inx(x, image_width, target_width,
coordinate_transformation_mode, roi[0], roi[1])
if method == "nearest_neighbor":
if rounding_method == "":
@@ -332,17 +325,7 @@ def _resize_1d(
closest_x_index = get_closest_index(in_x, rounding_method, boxes)
- value = get_1d_pixel(
- data,
- layout,
- image_width,
- box_idx,
- c,
- closest_x_index,
- cc,
- inum,
- ic,
- )
+ value = get_1d_pixel(data, layout, image_width, box_idx, c,
closest_x_index, cc, inum, ic)
elif method == "linear":
x_int = te.floor(in_x).astype("int32")
@@ -350,17 +333,7 @@ def _resize_1d(
p = [0 for i in range(2)]
for i in range(2):
- p[i] = get_1d_pixel(
- data,
- layout,
- image_width,
- box_idx,
- c,
- x_int + i,
- cc,
- inum,
- ic,
- )
+ p[i] = get_1d_pixel(data, layout, image_width, box_idx, c, x_int +
i, cc, inum, ic)
value = _lerp(*p, x_lerp)
@@ -371,17 +344,7 @@ def _resize_1d(
# Get the surrounding values
p = [0 for i in range(4)]
for i in range(4):
- p[i] = get_1d_pixel(
- data,
- layout,
- image_width,
- box_idx,
- c,
- xint + i - 1,
- cc,
- inum,
- ic,
- )
+ p[i] = get_1d_pixel(data, layout, image_width, box_idx, c, xint +
i - 1, cc, inum, ic)
wx = _cubic_spline_weights(xfract, alpha)
if exclude_outside:
@@ -499,7 +462,7 @@ def resize1d(
if output_shape is None:
output_shape = [in_n, in_c, size[0], in_cc]
else:
- raise ValueError("%s layout is not supported." % layout)
+ raise ValueError(f"{layout} layout is not supported.")
if isinstance(size, tuple):
size = list(size)
@@ -866,7 +829,7 @@ def resize2d(
if output_shape is None:
output_shape = [in_n, in_c, size[0], size[1], in_cc]
else:
- raise ValueError("%s layout is not supported." % layout)
+ raise ValueError(f"{layout} layout is not supported.")
if isinstance(size, tuple):
size = list(size)
@@ -967,7 +930,7 @@ def crop_and_resize(
image_h = data.shape[2].astype("int32")
image_w = data.shape[3].astype("int32")
else:
- raise ValueError("%s layout is not supported." % layout)
+ raise ValueError(f"{layout} layout is not supported.")
if method == "bilinear":
method = "linear"
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index a3afc25902..f70d749e0f 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -172,7 +172,7 @@ def _get_workload(data, kernel, stride, padding, dilation,
out_dtype, data_layou
elif data_layout == "HWCN":
IH, IW, CI, _ = get_const_tuple(data.shape)
else:
- raise ValueError("not support this layout {} yet".format(data_layout))
+ raise ValueError(f"not support this layout {data_layout} yet")
if data_layout == "NCHW":
CO, CIG, KH, KW = get_const_tuple(kernel.shape)
@@ -193,10 +193,7 @@ def _get_workload(data, kernel, stride, padding, dilation,
out_dtype, data_layou
HSTR, WSTR = stride, stride
assert (data.dtype == kernel.dtype) or (
data.dtype == "uint8" and kernel.dtype == "int8"
- ), "Do not support inputs with different data types now. ' \
- '{} vs. {}".format(
- data.dtype, kernel.dtype
- )
+ ), f"Do not support inputs with different data types now. {data.dtype} vs.
{kernel.dtype}"
return Workload(
data.dtype,
out_dtype,
@@ -1227,8 +1224,7 @@ def _conv2d_winograd_nhwc_impl(
kernel_pack = te.compute(
(alpha, alpha, CO, CI),
lambda eps, nu, co, ci: te.sum(
- weight[r_kh, r_kw, ci, co] * G[eps, r_kh] * G[nu, r_kw],
- axis=[r_kh, r_kw],
+ weight[r_kh, r_kw, ci, co] * G[eps, r_kh] * G[nu, r_kw],
axis=[r_kh, r_kw]
),
name="kernel_pack",
)
@@ -1246,10 +1242,7 @@ def _conv2d_winograd_nhwc_impl(
input_tile = te.compute(
(alpha, alpha, P, CI),
lambda eps, nu, p, ci: data_pad[
- p // (nH * nW),
- ((p // nW) % nH) * m + eps,
- (p % nW) * m + nu,
- ci,
+ p // (nH * nW), ((p // nW) % nH) * m + eps, (p % nW) * m + nu, ci
],
name="input_tile",
attrs={"schedule_rule": "None"},
@@ -1261,8 +1254,7 @@ def _conv2d_winograd_nhwc_impl(
data_pack = te.compute(
(alpha, alpha, P, CI),
lambda eps, nu, p, ci: te.sum(
- input_tile[r_a, r_b, p, ci] * B[r_a, eps] * B[r_b, nu],
- axis=[r_a, r_b],
+ input_tile[r_a, r_b, p, ci] * B[r_a, eps] * B[r_b, nu], axis=[r_a,
r_b]
),
name="data_pack",
attrs={
@@ -1276,8 +1268,7 @@ def _conv2d_winograd_nhwc_impl(
bgemm = te.compute(
(alpha, alpha, P, CO),
lambda eps, nu, p, co: te.sum(
- data_pack[eps, nu, p, ci] * kernel_pack[eps, nu, co, ci],
- axis=[ci],
+ data_pack[eps, nu, p, ci] * kernel_pack[eps, nu, co, ci], axis=[ci]
),
name="bgemm",
attrs=bgemm_attrs,
@@ -1293,8 +1284,7 @@ def _conv2d_winograd_nhwc_impl(
inverse = te.compute(
(m, m, P, CO),
lambda vh, vw, p, co: te.sum(
- bgemm[r_a, r_b, p, co] * A[r_a, vh] * A[r_b, vw],
- axis=[r_a, r_b],
+ bgemm[r_a, r_b, p, co] * A[r_a, vh] * A[r_b, vw], axis=[r_a, r_b]
),
name="inverse",
attrs={
@@ -1306,12 +1296,7 @@ def _conv2d_winograd_nhwc_impl(
# output
output = te.compute(
(N, H, W, CO),
- lambda n, h, w, co: inverse[
- h % m,
- w % m,
- n * nH * nW + (h // m) * nW + (w // m),
- co,
- ],
+ lambda n, h, w, co: inverse[h % m, w % m, n * nH * nW + (h // m) * nW
+ (w // m), co],
name="conv2d_winograd",
)
@@ -1361,12 +1346,7 @@ def _conv2d_winograd_nchw_impl(
assert HSTR == 1 and WSTR == 1 and KH == 3 and KW == 3
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
- data_pad = pad(
- data,
- (0, 0, pt, pl),
- (0, 0, pb, pr),
- name="data_pad",
- )
+ data_pad = pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")
r = KW
m = tile_size
@@ -1385,8 +1365,7 @@ def _conv2d_winograd_nchw_impl(
kernel_pack = te.compute(
(alpha, alpha, CI, CO),
lambda eps, nu, ci, co: te.sum(
- weight[co, ci, r_kh, r_kw] * G[eps, r_kh] * G[nu, r_kw],
- axis=[r_kh, r_kw],
+ weight[co, ci, r_kh, r_kw] * G[eps, r_kh] * G[nu, r_kw],
axis=[r_kh, r_kw]
),
name="kernel_pack",
)
@@ -1404,10 +1383,7 @@ def _conv2d_winograd_nchw_impl(
input_tile = te.compute(
(CI, P, alpha, alpha),
lambda ci, p, eps, nu: data_pad[
- p // (nH * nW),
- ci,
- ((p // nW) % nH) * m + eps,
- (p % nW) * m + nu,
+ p // (nH * nW), ci, ((p // nW) % nH) * m + eps, (p % nW) * m + nu
],
name="input_tile",
attrs={"schedule_rule": "None"},
@@ -1419,13 +1395,10 @@ def _conv2d_winograd_nchw_impl(
data_pack = te.compute(
(alpha, alpha, CI, P),
lambda eps, nu, ci, p: te.sum(
- input_tile[ci, p, r_a, r_b] * B[r_a, eps] * B[r_b, nu],
- axis=[r_a, r_b],
+ input_tile[ci, p, r_a, r_b] * B[r_a, eps] * B[r_b, nu], axis=[r_a,
r_b]
),
name="data_pack",
- attrs={
- "schedule_rule": "conv2d_nchw_winograd_data_pack",
- },
+ attrs={"schedule_rule": "conv2d_nchw_winograd_data_pack"},
)
# do batch gemm
@@ -1433,8 +1406,7 @@ def _conv2d_winograd_nchw_impl(
bgemm = te.compute(
(alpha, alpha, CO, P),
lambda eps, nu, co, p: te.sum(
- data_pack[eps, nu, ci, p] * kernel_pack[eps, nu, ci, co],
- axis=[ci],
+ data_pack[eps, nu, ci, p] * kernel_pack[eps, nu, ci, co], axis=[ci]
),
name="bgemm",
attrs=bgemm_attrs,
@@ -1446,24 +1418,16 @@ def _conv2d_winograd_nchw_impl(
inverse = te.compute(
(CO, P, m, m),
lambda co, p, vh, vw: te.sum(
- bgemm[r_a, r_b, co, p] * A[r_a, vh] * A[r_b, vw],
- axis=[r_a, r_b],
+ bgemm[r_a, r_b, co, p] * A[r_a, vh] * A[r_b, vw], axis=[r_a, r_b]
),
name="inverse",
- attrs={
- "schedule_rule": "conv2d_nchw_winograd_inverse",
- },
+ attrs={"schedule_rule": "conv2d_nchw_winograd_inverse"},
)
# output
output = te.compute(
(N, CO, H, W),
- lambda n, co, h, w: inverse[
- co,
- n * nH * nW + (h // m) * nW + (w // m),
- h % m,
- w % m,
- ],
+ lambda n, co, h, w: inverse[co, n * nH * nW + (h // m) * nW + (w //
m), h % m, w % m],
name="conv2d_winograd",
)
diff --git a/python/tvm/topi/nn/depthwise_conv2d.py
b/python/tvm/topi/nn/depthwise_conv2d.py
index c33cf365b5..ad1e4a5517 100644
--- a/python/tvm/topi/nn/depthwise_conv2d.py
+++ b/python/tvm/topi/nn/depthwise_conv2d.py
@@ -65,30 +65,23 @@ def _get_workload(data, kernel, stride, padding, dilation,
out_dtype, data_layou
elif data_layout == "NCHWc":
_, in_channel_chunk, height, width, in_channel_block =
get_const_tuple(data.shape)
in_channel = in_channel_chunk * in_channel_block
- (
- filter_channel_chunk,
- cm_chunk,
- kh,
- kw,
- cm_block,
- filter_channel_block,
- ) = get_const_tuple(kernel.shape)
+ (filter_channel_chunk, cm_chunk, kh, kw, cm_block,
filter_channel_block) = get_const_tuple(
+ kernel.shape
+ )
filter_channel = filter_channel_chunk * filter_channel_block
channel_multiplier = cm_chunk * cm_block
- assert (
- in_channel_block == filter_channel_block
- ), "Incorrect dimensions, data has block size {}, but filter has block
size {}".format(
- in_channel_block, filter_channel_block
+ assert in_channel_block == filter_channel_block, (
+ f"Incorrect dimensions, data has block size {in_channel_block},
but filter has "
+ f"block size {filter_channel_block}"
)
else:
- raise ValueError("Data layout {} not supported".format(data_layout))
+ raise ValueError(f"Data layout {data_layout} not supported")
- assert (
- in_channel == filter_channel
- ), "Incorrect dimensions, data has {} channels but filter expects {}
channels".format(
- in_channel, filter_channel
+ assert in_channel == filter_channel, (
+ f"Incorrect dimensions, data has {in_channel} channels but filter
expects "
+ f"{filter_channel} channels"
)
out_channel = filter_channel * channel_multiplier
@@ -101,10 +94,7 @@ def _get_workload(data, kernel, stride, padding, dilation,
out_dtype, data_layou
HSTR, WSTR = stride, stride
assert (data.dtype == kernel.dtype) or (
data.dtype == "uint8" and kernel.dtype == "int8"
- ), "Do not support inputs with different data types now. ' \
- '{} vs. {}".format(
- data.dtype, kernel.dtype
- )
+ ), f"Do not support inputs with different data types now. {data.dtype} vs.
{kernel.dtype}"
dilated_kernel_h = (kh - 1) * dilation_h + 1
dilated_kernel_w = (kw - 1) * dilation_w + 1
pt, pl, pb, pr = get_pad_tuple(padding, (dilated_kernel_h,
dilated_kernel_w))
diff --git a/python/tvm/topi/nn/fifo_buffer.py
b/python/tvm/topi/nn/fifo_buffer.py
index 0f12d9faf1..301662fd03 100644
--- a/python/tvm/topi/nn/fifo_buffer.py
+++ b/python/tvm/topi/nn/fifo_buffer.py
@@ -57,8 +57,8 @@ def fifo_buffer(data, buffer, axis):
Updated value for the buffer
"""
assert len(data.shape) == len(buffer.shape), (
- "buffer and data must have same number of dimensions, "
- + "buffer.shape = {}, data.shape = {}".format(buffer.shape, data.shape)
+ f"buffer and data must have same number of dimensions, "
+ f"buffer.shape = {buffer.shape}, data.shape = {data.shape}"
)
assert len(buffer.shape) >= 1, "Zero-dimension tensor not supported"
assert 0 <= axis < len(buffer.shape), "buffer axis out of range"
@@ -101,7 +101,7 @@ def fifo_buffer(data, buffer, axis):
),
name="new_buffer",
)
- assert False, "Invalid value for axis; it should be at most
{}".format(len(buffer.shape))
+ assert False, f"Invalid value for axis; it should be at most
{len(buffer.shape)}"
elif len(buffer.shape) == 3:
if axis == 0:
return te.compute(
@@ -133,7 +133,7 @@ def fifo_buffer(data, buffer, axis):
),
name="new_buffer",
)
- assert False, "Invalid value for axis; it should be at most
{}".format(len(buffer.shape))
+ assert False, f"Invalid value for axis; it should be at most
{len(buffer.shape)}"
elif len(buffer.shape) == 4:
if axis == 0:
return te.compute(
@@ -175,7 +175,7 @@ def fifo_buffer(data, buffer, axis):
),
name="new_buffer",
)
- assert False, "Invalid value for axis; it should be at most {}".format(len(buffer.shape))
+ assert False, f"Invalid value for axis; it should be at most {len(buffer.shape)}"
else:
# Implement FIFO buffer as combination of concat and slice
begin = [0] * len(buffer.shape)
diff --git a/python/tvm/topi/nn/upsampling.py b/python/tvm/topi/nn/upsampling.py
index e9c810cd5e..e1661bd9be 100644
--- a/python/tvm/topi/nn/upsampling.py
+++ b/python/tvm/topi/nn/upsampling.py
@@ -90,7 +90,7 @@ def upsampling(
)
else:
- raise ValueError("not support this layout {} yet".format(layout))
+ raise ValueError(f"not support this layout {layout} yet")
coord_trans = "align_corners" if align_corners else "asymmetric"
if method[0:2] == "bi":
method = method[2:]
@@ -190,7 +190,7 @@ def upsampling3d(
simplify(topi.cast(te.round(output_shape[3]), data.shape[3].dtype)),
)
else:
- raise ValueError("not support this layout {} yet".format(layout))
+ raise ValueError(f"not support this layout {layout} yet")
if method[0:3] == "tri":
method = method[3:]
return topi.image.resize3d(
diff --git a/python/tvm/topi/nn/utils.py b/python/tvm/topi/nn/utils.py
index 01e1c1ab54..ce4038ccb6 100644
--- a/python/tvm/topi/nn/utils.py
+++ b/python/tvm/topi/nn/utils.py
@@ -79,7 +79,7 @@ def infer_pad3d(data, data_pad, layout):
_, _, ID, IH, IW = data.shape
_, _, TD, TH, TW = data_pad.shape
else:
- raise ValueError("Layout {} is not supported".format(layout))
+ raise ValueError(f"Layout {layout} is not supported")
dpad = TD - ID
hpad = TH - IH
wpad = TW - IW
@@ -158,7 +158,7 @@ def get_pad_tuple(padding, kernel):
pad_h = kernel[0] - 1
pad_w = kernel[1] - 1
else:
- raise ValueError("Unknown padding option %s" % padding)
+ raise ValueError(f"Unknown padding option {padding}")
pad_top = (pad_h + 1) // 2
pad_left = (pad_w + 1) // 2
return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left
@@ -194,9 +194,10 @@ def get_pad_tuple_generic(padding, kernel):
if len(padding) == len(kernel):
pad_dimensions = [p * 2 for p in padding]
elif len(padding) == len(kernel) * 2:
- return [padding[i] for i in range(len(kernel))], [
- padding[len(kernel) + i] for i in range(len(kernel))
- ]
+ return (
+ [padding[i] for i in range(len(kernel))],
+ [padding[len(kernel) + i] for i in range(len(kernel))],
+ )
else:
raise ValueError("Size of padding can only be len(kernel) or len(kernel) * 2")
elif isinstance(padding, int):
@@ -206,7 +207,7 @@ def get_pad_tuple_generic(padding, kernel):
elif padding == "SAME":
pad_dimensions = [k - 1 for k in kernel]
else:
- raise ValueError("Unknown padding option %s" % padding)
+ raise ValueError(f"Unknown padding option {padding}")
pad_begin = [(p + 1) // 2 for p in pad_dimensions]
return [pad_begin, [pd - pb for pb, pd in zip(pad_begin, pad_dimensions)]]
@@ -263,7 +264,7 @@ def get_pad_tuple3d(padding, kernel):
pad_h = kernel[1] - 1
pad_w = kernel[2] - 1
else:
- raise ValueError("Unknown padding option %s" % padding)
+ raise ValueError(f"Unknown padding option {padding}")
pad_top = (pad_h + 1) // 2
pad_left = (pad_w + 1) // 2
pad_front = (pad_d + 1) // 2
@@ -304,6 +305,6 @@ def get_pad_tuple1d(padding, kernel):
elif padding == "SAME":
pad_w = kernel[0] - 1
else:
- raise ValueError("Unknown padding option %s" % padding)
+ raise ValueError(f"Unknown padding option {padding}")
pad_left = (pad_w + 1) // 2
return pad_left, pad_w - pad_left
diff --git a/python/tvm/topi/nn/winograd_util.py
b/python/tvm/topi/nn/winograd_util.py
index 4bee06fcfa..8c2f50d7f8 100644
--- a/python/tvm/topi/nn/winograd_util.py
+++ b/python/tvm/topi/nn/winograd_util.py
@@ -160,9 +160,9 @@ def _interpolation_points(degree):
def winograd_transform_matrices(tile_size, kernel_size, out_dtype):
"""Compute the A, B, and G transform matrices for `tile_size` as a
`tvm.Expr`."""
if not 1 < tile_size < 9:
- raise ValueError("Unsupported tile size for Winograd: {}".format(tile_size))
+ raise ValueError(f"Unsupported tile size for Winograd: {tile_size}")
if not 2 < kernel_size < 8:
- raise ValueError("Unsupported kernel size for Winograd: {}".format(kernel_size))
+ raise ValueError(f"Unsupported kernel size for Winograd: {kernel_size}")
degree = tile_size + kernel_size - 2
diff --git a/python/tvm/topi/reduction.py b/python/tvm/topi/reduction.py
index 5045cb8174..cbd75c75ef 100644
--- a/python/tvm/topi/reduction.py
+++ b/python/tvm/topi/reduction.py
@@ -34,7 +34,7 @@ def _get_real_axis(ndim, axis):
ele += ndim
if ele >= ndim:
raise ValueError(
- "{} exceeds the maximum dimension {}. Received axis={}".format(ele, ndim, axis)
+ f"{ele} exceeds the maximum dimension {ndim}. Received axis={axis}"
)
real_axis.append(ele)
real_axis.sort()
diff --git a/python/tvm/topi/testing/poolnd_python.py
b/python/tvm/topi/testing/poolnd_python.py
index 29d34c36a4..486c265a02 100644
--- a/python/tvm/topi/testing/poolnd_python.py
+++ b/python/tvm/topi/testing/poolnd_python.py
@@ -38,10 +38,7 @@ def _get_supported_layout(dims: int):
return "NCDHW"
-def _convert_to_layout(
- input_tensor: np.ndarray,
- layout: str,
-) -> np.ndarray:
+def _convert_to_layout(input_tensor: np.ndarray, layout: str) -> np.ndarray:
"""
Converts back to original layout after the algorithm is finished
"""
@@ -55,10 +52,7 @@ def _convert_to_layout(
return input_tensor
-def _convert_from_layout(
- input_tensor: np.ndarray,
- layout: str,
-) -> np.ndarray:
+def _convert_from_layout(input_tensor: np.ndarray, layout: str) -> np.ndarray:
"""
Converts tensor to one of suppored layouts
"""
@@ -208,6 +202,6 @@ def poolnd_python(
# All padded values, default to 0
ret_np[output_slice] = np.max(pad_data[np_index],
axis=reduction_axis)
else:
- raise ValueError("Pool type {} is not supported".format(pool_type))
+ raise ValueError(f"Pool type {pool_type} is not supported")
return _convert_to_layout(ret_np, layout)
diff --git a/python/tvm/topi/testing/resize_python.py
b/python/tvm/topi/testing/resize_python.py
index 13b460f07e..7d2cce8bd2 100644
--- a/python/tvm/topi/testing/resize_python.py
+++ b/python/tvm/topi/testing/resize_python.py
@@ -32,7 +32,7 @@ def get_inx(x, image_width, target_width,
coordinate_transformation_mode):
in_x = scale * x
else:
raise ValueError(
- "Unsupported coordinate_transformation_mode: {}".format(coordinate_transformation_mode)
+ f"Unsupported coordinate_transformation_mode: {coordinate_transformation_mode}"
)
return in_x
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index e4fe3c5839..934470fe23 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -85,9 +85,8 @@ def expand_like(a, shape_like, axis):
# A special case: `a` is a scalar represented as a 1-dim tensor
return te.compute(shape_like.shape, lambda *idxs: a(0))
raise ValueError(
- "shape inconsistent when expand_like ({}, {}, {})".format(
- len(axis), len(a.shape), len(shape_like.shape)
- )
+ f"shape inconsistent when expand_like ({len(axis)}, "
+ f"{len(a.shape)}, {len(shape_like.shape)})"
)
real_axis = topi.reduction._get_real_axis(len(shape_like.shape), axis)
@@ -710,10 +709,8 @@ def sequence_mask(data, valid_length, mask_value=0,
axis=0):
depending on the value of `axis`.
"""
- assert len(data.shape) >= 2, "only support data.ndim >= 2, received data.shape = {}".format(
- data.shape
- )
- assert axis in (0, 1), "only support axis = 0, 1, received axis = {}".format(axis)
+ assert len(data.shape) >= 2, f"only support data.ndim >= 2, received data.shape = {data.shape}"
+ assert axis in (0, 1), f"only support axis = 0, 1, received axis = {axis}"
return cpp.sequence_mask(data, valid_length, mask_value, axis)
diff --git a/python/tvm/topi/x86/conv2d_alter_op.py
b/python/tvm/topi/x86/conv2d_alter_op.py
index 032c0e2e23..3772aaec04 100644
--- a/python/tvm/topi/x86/conv2d_alter_op.py
+++ b/python/tvm/topi/x86/conv2d_alter_op.py
@@ -113,10 +113,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
# update new attrs
new_attrs["channels"] = out_channel
- new_attrs["data_layout"] = "NCHW%dc" % ic_bn
+ new_attrs["data_layout"] = f"NCHW{ic_bn}c"
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
- new_attrs["kernel_layout"] = "OIHW%di%do" % (ic_bn, oc_bn)
- new_attrs["out_layout"] = "NCHW%dc" % oc_bn
+ new_attrs["kernel_layout"] = f"OIHW{ic_bn}i{oc_bn}o"
+ new_attrs["out_layout"] = f"NCHW{oc_bn}c"
# Store altered operator's config
new_data = te.placeholder(
@@ -169,9 +169,9 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
# update new attrs
n_elems = 4
new_attrs["channels"] = out_channel
- new_attrs["data_layout"] = "NCHW%dc" % ic_bn
- new_attrs["kernel_layout"] = "OIHW{:n}i{:n}o{:n}i".format(ic_bn // n_elems, oc_bn, n_elems)
- new_attrs["out_layout"] = "NCHW%dc" % oc_bn
+ new_attrs["data_layout"] = f"NCHW{ic_bn}c"
+ new_attrs["kernel_layout"] = f"OIHW{ic_bn // n_elems:n}i{oc_bn:n}o{n_elems:n}i"
+ new_attrs["out_layout"] = f"NCHW{oc_bn}c"
# Store altered operator's config.
new_data = te.placeholder(
@@ -220,9 +220,9 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
# update new attrs
new_attrs["channels"] = out_channel
- new_attrs["data_layout"] = "NCHW%dc" % ic_bn
- new_attrs["kernel_layout"] = "OIHW1i%do" % oc_bn
- new_attrs["out_layout"] = "NCHW%dc" % oc_bn
+ new_attrs["data_layout"] = f"NCHW{ic_bn}c"
+ new_attrs["kernel_layout"] = f"OIHW1i{oc_bn}o"
+ new_attrs["out_layout"] = f"NCHW{oc_bn}c"
# Store altered operator's config.
new_data = te.placeholder(
diff --git a/python/tvm/topi/x86/conv2d_avx_1x1.py
b/python/tvm/topi/x86/conv2d_avx_1x1.py
index 47a6016c52..047377f83e 100644
--- a/python/tvm/topi/x86/conv2d_avx_1x1.py
+++ b/python/tvm/topi/x86/conv2d_avx_1x1.py
@@ -61,7 +61,7 @@ def _fallback_schedule(cfg, wkl):
cfg["tile_oh"] = OtherOptionEntity(oh_factor)
cfg["tile_ow"] = SplitEntity([out_width // ow_factor,
ow_factor])
return
- raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
+ raise ValueError(f"cannot decide default schedule for workload: {wkl}")
def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec, conv_out, last):
@@ -145,7 +145,7 @@ def _schedule_conv_NCHWc(s, cfg, data_vec, kernel_vec,
conv_out, last):
s[O].vectorize(oc_block)
s[O].parallel(parallel_axis)
else:
- raise ValueError("Unsupported output ndim: %s" % out_ndim)
+ raise ValueError(f"Unsupported output ndim: {out_ndim}")
return s
diff --git a/python/tvm/topi/x86/conv2d_int8.py
b/python/tvm/topi/x86/conv2d_int8.py
index 005d374b8a..9d32534352 100644
--- a/python/tvm/topi/x86/conv2d_int8.py
+++ b/python/tvm/topi/x86/conv2d_int8.py
@@ -250,11 +250,11 @@ def schedule_conv2d_nhwc_pack_int8(cfg, outs):
if kh == 1 and kw == 1:
conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args)
else:
- raise ValueError("Only support 1x1 kernel with " "schedule_conv2d_nhwc_pack.")
+ raise ValueError("Only support 1x1 kernel with schedule_conv2d_nhwc_pack.")
else:
raise ValueError(
- "Not support this data type {} with "
- "schedule_conv2d_nhwc_pack. Only support int8".format(data.dtype)
+ f"Not support this data type {data.dtype} with "
+ f"schedule_conv2d_nhwc_pack. Only support int8"
)
scheduled_ops.append(op)
diff --git a/python/tvm/topi/x86/conv3d.py b/python/tvm/topi/x86/conv3d.py
index 5fc6963fc8..20f2c4ac12 100644
--- a/python/tvm/topi/x86/conv3d.py
+++ b/python/tvm/topi/x86/conv3d.py
@@ -277,15 +277,7 @@ def _conv3d_ndhwc(cfg, data, kernel, strides, padding,
dilation, groups, out_dty
ci_tile += 1
# pack kernel
- shape = (
- num_filter // oc_bn,
- ci_tile,
- kernel_depth,
- kernel_height,
- kernel_width,
- ic_bn,
- oc_bn,
- )
+ shape = (num_filter // oc_bn, ci_tile, kernel_depth, kernel_height, kernel_width, ic_bn, oc_bn)
kernel_vec = te.compute(
shape,
lambda CO, CI, d, h, w, ci, co: kernel[d, h, w, CI * ic_bn + ci, CO * oc_bn + co],
@@ -398,15 +390,7 @@ def _conv3d_ncdhw(cfg, data, kernel, strides, padding,
dilation, layout, groups,
ci_tile += 1
# pack kernel
- shape = (
- num_filter // oc_bn,
- ci_tile,
- kernel_depth,
- kernel_height,
- kernel_width,
- ic_bn,
- oc_bn,
- )
+ shape = (num_filter // oc_bn, ci_tile, kernel_depth, kernel_height, kernel_width, ic_bn, oc_bn)
kernel_vec = te.compute(
shape,
lambda CO, CI, d, h, w, ci, co: kernel[CO * oc_bn + co, CI * ic_bn + ci, d, h, w],
@@ -472,7 +456,7 @@ def _create_tuning_space(cfg, data, kernel, strides,
padding, dilation, groups,
n, ic, d, h, w = dshape
oc, _, kd, kh, kw = kshape
else:
- raise ValueError("Not support this layout {} with " "schedule template.".format(layout))
+ raise ValueError(f"Not support this layout {layout} with schedule template.")
# pad_front, pad_top, pad_left, pad_back, pad_down(bottom), pad_right
pf, pt, pl, pb, pd, pr = get_pad_tuple3d(padding, (kd, kh, kw))
@@ -493,7 +477,7 @@ def _get_default_config(cfg, data, kernel, strides,
padding, groups, out_dtype,
Get default schedule config for the workload
"""
if layout not in ["NDHWC", "NCDHW"]:
- raise ValueError("Layout {} is not supported".format(layout))
+ raise ValueError(f"Layout {layout} is not supported")
static_data_shape = []
for dim in get_const_tuple(data.shape):
@@ -515,7 +499,7 @@ def _get_conv3d_workload(data, kernel, stride, padding,
groups, out_dtype, data_
_, ID, IH, IW, CI = get_const_tuple(data.shape)
KD, KH, KW, CIG, CO = get_const_tuple(kernel.shape)
else:
- raise ValueError("not support this layout {} yet".format(data_layout))
+ raise ValueError(f"not support this layout {data_layout} yet")
pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
padding, (get_const_int(KD), get_const_int(KH), get_const_int(KW))
@@ -529,10 +513,7 @@ def _get_conv3d_workload(data, kernel, stride, padding,
groups, out_dtype, data_
DSTR, HSTR, WSTR = stride, stride, stride
assert (data.dtype == kernel.dtype) or (
data.dtype == "uint8" and kernel.dtype == "int8"
- ), "Do not support inputs with different data types now. ' \
- '{} vs. {}".format(
- data.dtype, kernel.dtype
- )
+ ), f"Do not support inputs with different data types now. {data.dtype} vs. {kernel.dtype}"
return Workload3D(
data.dtype,
out_dtype,
diff --git a/python/tvm/topi/x86/nn.py b/python/tvm/topi/x86/nn.py
index 7c3942fabc..734c9f6e70 100644
--- a/python/tvm/topi/x86/nn.py
+++ b/python/tvm/topi/x86/nn.py
@@ -43,10 +43,7 @@ def _schedule_softmax(softmax_op, s, outs):
axis = int(softmax_op.attrs["axis"])
else:
raise ValueError(
- "Tag is expected to be softmax_output or log_softmax_output. \
- Got {0}".format(
- op_tag
- )
+ f"Tag is expected to be softmax_output or log_softmax_output. Got {op_tag}"
)
output = outs[0]