This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 38e726aab1 [Relax][PyTorch] Cleanup unary op converters (#17356)
38e726aab1 is described below
commit 38e726aab191d5c16a7d98b2191a5f97f7fef410
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Thu Sep 12 04:18:07 2024 +0900
[Relax][PyTorch] Cleanup unary op converters (#17356)
* classify into 9 types of ops
* introduce `_unary_op()` (sketched below)
* cleanup `_clamp()`
* cleanup `_gelu()`
* cleanup `_hardsigmoid()` and `_hardswish()`
* cleanup `_leakyrelu()`
* cleanup `_log_softmax()`
* cleanup `_round()`
* cleanup `_softmax()`
* cleanup `_tril_triu()`
* replace `fx.node.Node` with `fx.Node`
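For context, the core of this cleanup is a small closure factory: instead of
one hand-written converter per elementwise op, `_unary_op()` builds the
converter from the Relax op. A minimal sketch of the pattern as introduced in
the diff below (the surrounding TorchFXImporter class and imports are elided):

    def _unary_op(self, op: Callable) -> Callable:
        from torch import fx

        # self.env maps fx.Node -> relax.Expr; emit `op` applied to the
        # expression bound to the node's first argument.
        def convert(node: fx.Node) -> relax.Var:
            return self.block_builder.emit(op(self.env[node.args[0]]))

        return convert

With this, each convert_map entry collapses to one line, e.g.
`"exp": self._unary_op(relax.op.exp)` or `nn.ReLU: self._unary_op(relax.op.nn.relu)`.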
---
python/tvm/relax/frontend/torch/fx_translator.py | 566 ++++++++++++-----------
1 file changed, 288 insertions(+), 278 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index aed38d7c49..8d66343254 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -35,7 +35,7 @@ class TorchFXImporter:
import torch # type: ignore
from torch import fx
- self.env: Dict[fx.node.Node, relax.Expr] = {}
+ self.env: Dict[fx.Node, relax.Expr] = {}
self.params: Dict[torch.Tensor, relax.Expr] = {}
self.named_modules: Dict[str, torch.Module] = None
self.block_builder: relax.BlockBuilder = None
@@ -108,7 +108,7 @@ class TorchFXImporter:
def _retrieve_args(self, node):
from torch import fx
- if isinstance(node, fx.node.Node):
+ if isinstance(node, fx.Node):
return self.env[node]
elif isinstance(node, tuple):
return tuple(self._retrieve_args(x) for x in node)
@@ -136,33 +136,113 @@ class TorchFXImporter:
lhs, rhs = TorchFXImporter._promote_binary_op_args(lhs, rhs)
return self.block_builder.emit(op(lhs, rhs))
- ########## Arithmetic ##########
+ ########## Unary Ops ##########
- def _exp(self, node: fx.node.Node) -> relax.Var:
- return self.block_builder.emit(relax.op.exp(self.env[node.args[0]]))
+ def _unary_op(self, op: Callable) -> Callable:
+ from torch import fx
- def _sigmoid(self, node: fx.node.Node) -> relax.Var:
- return self.block_builder.emit(relax.op.sigmoid(self.env[node.args[0]]))
+ def convert(node: fx.Node) -> relax.Var:
+ return self.block_builder.emit(op(self.env[node.args[0]]))
- def _sqrt(self, node: fx.node.Node) -> relax.Expr:
- arg = self.env[node.args[0]]
- if isinstance(arg, (int, float)):
- arg = relax.const(arg, "float32")
- return self.block_builder.emit(relax.op.sqrt(arg))
+ return convert
- def _rsqrt(self, node: fx.node.Node) -> relax.Expr:
- arg = self.env[node.args[0]]
- if isinstance(arg, (int, float)):
- arg = relax.const(arg, "float32")
- return self.block_builder.emit(relax.op.rsqrt(arg))
+ def _clamp(self, node: fx.Node) -> relax.Expr:
+ args = self.retrieve_args(node)
+ a_min = args[1] if len(args) > 1 else node.kwargs["min"]
+ a_max = args[2] if len(args) > 2 else node.kwargs["max"]
+ if not isinstance(a_min, (int, float)):
+ raise ValueError(
+ f"TVM only supports constant min value for torch.clamp/clip, "
+ f"but got {a_min} with type {type(a_min)}"
+ )
+ if not isinstance(a_max, (int, float)):
+ raise ValueError(
+ f"TVM only supports constant max value for torch.clamp/clip, "
+ f"but got {a_max} with type {type(a_max)}"
+ )
+ return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
+
+ def _gelu(self, node: fx.Node) -> relax.Expr:
+ approximate = node.kwargs.get("approximate", "none")
+ if approximate == "none":
+ return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]]))
+ elif approximate == "tanh":
+ return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]]))
+ else:
+ raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate))
+
+ def _hardsigmoid(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ dtype = x.struct_info.dtype
+ x0 = relax.op.add(x, relax.const(3, dtype))
+ x1 = relax.op.clip(x0, 0, 6)
+ return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype)))
+
+ def _hardswish(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ dtype = x.struct_info.dtype
+ x0 = relax.op.add(x, relax.const(3, dtype))
+ x1 = relax.op.clip(x0, 0, 6)
+ x2 = relax.op.divide(x1, relax.const(6, dtype))
+ return self.block_builder.emit(relax.op.multiply(x, x2))
+
+ def _leakyrelu(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01)
+ return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
+
+ def _leakyrelu_module(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ alpha = module.negative_slope
+ return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
+
+ def _log_softmax(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
+ return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
+
+ def _log_softmax_module(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ dim = module.dim
+ assert dim is not None
+ return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
- def _round(self, node: fx.node.Node) -> relax.Expr:
- if "decimals" in node.kwargs and node.kwargs["decimals"] != 0:
+ def _round(self, node: fx.Node) -> relax.Expr:
+ if node.kwargs.get("decimals", 0) != 0:
raise ValueError("specifying decimals for round is not supported
yet")
arg = self.env[node.args[0]]
return self.block_builder.emit(relax.op.round(arg))
- def _add(self, node: fx.node.Node) -> relax.Expr:
+ def _softmax(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
+ return self.block_builder.emit(relax.op.nn.softmax(x, dim))
+
+ def _softmax_module(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ dim = module.dim
+ assert dim is not None
+ return self.block_builder.emit(relax.op.nn.softmax(x, dim))
+
+ def _tril_triu(self, op: Callable) -> Callable:
+ from torch import fx
+
+ def convert(node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ k = node.args[1] if len(node.args) > 1 else node.kwargs.get("diagonal", 0)
+ assert isinstance(k, int)
+ return self.block_builder.emit(op(x, k))
+
+ return convert
+
+ ########## Arithmetic ##########
+
+ def _add(self, node: fx.Node) -> relax.Expr:
lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
return self._call_binary_op(relax.op.add, lhs, rhs)
@@ -176,103 +256,54 @@ class TorchFXImporter:
)
return lhs + rhs
- def _max(self, node: fx.node.Node) -> relax.Expr:
+ def _max(self, node: fx.Node) -> relax.Expr:
lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
return self._call_binary_op(relax.op.maximum, lhs, rhs)
- def _floordiv(self, node: fx.node.Node) -> relax.Expr:
+ def _floordiv(self, node: fx.Node) -> relax.Expr:
lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
return self._call_binary_op(relax.op.floor_divide, lhs, rhs)
return lhs // rhs
- def _mul(self, node: fx.node.Node) -> relax.Expr:
+ def _mul(self, node: fx.Node) -> relax.Expr:
lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
return self._call_binary_op(relax.op.multiply, lhs, rhs)
return lhs * rhs
- def _pow(self, node: fx.node.Node) -> relax.Expr:
+ def _pow(self, node: fx.Node) -> relax.Expr:
lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
return self._call_binary_op(relax.op.power, lhs, rhs)
return lhs**rhs
- def _neg(self, node: fx.node.Node) -> relax.Expr:
- x = self.env[node.args[0]]
- return self.block_builder.emit(relax.op.negative(x))
-
- def _sub(self, node: fx.node.Node) -> relax.Expr:
+ def _sub(self, node: fx.Node) -> relax.Expr:
lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
return self._call_binary_op(relax.op.subtract, lhs, rhs)
return lhs - rhs
- def _truediv(self, node: fx.node.Node) -> relax.Expr:
+ def _truediv(self, node: fx.Node) -> relax.Expr:
lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
return self._call_binary_op(relax.op.divide, lhs, rhs)
return lhs / rhs
- def _clamp(self, node: fx.node.Node) -> relax.Expr:
- args = self.retrieve_args(node)
- a_min = node.kwargs["min"]
- a_max = node.kwargs["max"]
- if not isinstance(a_min, (int, float)):
- raise ValueError(
- f"TVM only supports constant min value for torch.clamp/clip, "
- f"but got {a_min} with type {type(a_min)}"
- )
- if not isinstance(a_max, (int, float)):
- raise ValueError(
- f"TVM only supports constant max value for torch.clamp/clip, "
- f"but got {a_max} with type {type(a_max)}"
- )
- return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
-
- def _gelu(self, node: fx.node.Node) -> relax.Expr:
- if "approximate" not in node.kwargs:
- approximate = "none"
- else:
- approximate = node.kwargs["approximate"]
- if approximate == "none":
- return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]]))
- elif approximate == "tanh":
- return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]]))
- else:
- raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate))
-
- def _hardsigmoid(self, node: fx.node.Node) -> relax.Var:
- args = self.retrieve_args(node)
- x = args[0]
- dtype = x.struct_info.dtype
- x0 = relax.op.add(x, relax.const(3, dtype))
- x1 = relax.op.clip(x0, 0, 6)
- return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype)))
-
- def _hardswish(self, node: fx.node.Node) -> relax.Var:
- args = self.retrieve_args(node)
- x = args[0]
- dtype = x.struct_info.dtype
- x0 = relax.op.add(x, relax.const(3, dtype))
- x1 = relax.op.clip(x0, 0, 6)
- x2 = relax.op.divide(x1, relax.const(6, dtype))
- return self.block_builder.emit(relax.op.multiply(x, x2))
-
########## Compare ##########
- def _lt(self, node: fx.node.Node) -> relax.Expr:
+ def _lt(self, node: fx.Node) -> relax.Expr:
lhs, rhs = self.retrieve_args(node)
return self._call_binary_op(relax.op.less, lhs, rhs)
- def _eq(self, node: fx.node.Node) -> relax.Expr:
+ def _eq(self, node: fx.Node) -> relax.Expr:
lhs, rhs = self.retrieve_args(node)
return self._call_binary_op(relax.op.equal, lhs, rhs)
########## Creation ##########
- def _arange(self, node: fx.node.Node) -> relax.Var:
+ def _arange(self, node: fx.Node) -> relax.Var:
import torch
start_end_step = [None, None, None]
@@ -311,15 +342,15 @@ class TorchFXImporter:
else:
dtype = "int64"
start_end_step = [
- self.env[x] if isinstance(x, torch.fx.node.Node) else x for x in start_end_step
+ self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step
]
return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype))
- def _empty(self, node: fx.node.Node) -> relax.Var:
+ def _empty(self, node: fx.Node) -> relax.Var:
dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
return self.block_builder.emit(relax.op.zeros(node.args, dtype))
- def _inplace_fill(self, node: fx.node.Node) -> relax.Var:
+ def _inplace_fill(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
dtype = x.struct_info.dtype
@@ -328,7 +359,7 @@ class TorchFXImporter:
self.env[node.args[0]] = filled
return filled
- def _tensor(self, node: fx.node.Node) -> relax.Var:
+ def _tensor(self, node: fx.Node) -> relax.Var:
dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None
if isinstance(node.args[0], float):
return relax.const(node.args[0], dtype if dtype is not None else "float32")
@@ -336,21 +367,10 @@ class TorchFXImporter:
return relax.const(node.args[0], dtype if dtype is not None else "int64")
raise ValueError("torch.tensor with value not a float or int is not accepted")
- def _tril_triu(self, op: Callable) -> Callable:
- from torch import fx
-
- def convert(node: fx.node.Node) -> relax.Var:
- x = self.env[node.args[0]]
- k = node.args[1] if len(node.args) > 1 else 0
- assert isinstance(k, int)
- return self.block_builder.emit(op(x, k))
-
- return convert
-
def _inplace_tril_triu(self, op: Callable) -> Callable:
from torch import fx
- def convert(node: fx.node.Node) -> relax.Var:
+ def convert(node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
k = node.args[1] if len(node.args) > 1 else 0
assert isinstance(k, int)
@@ -361,7 +381,7 @@ class TorchFXImporter:
return convert
- def _new_ones(self, node: fx.node.Node) -> relax.Var:
+ def _new_ones(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
self_var = args[0]
size = args[1:]
@@ -376,7 +396,7 @@ class TorchFXImporter:
)
)
- def _ones(self, node: fx.node.Node) -> relax.Var:
+ def _ones(self, node: fx.Node) -> relax.Var:
import torch
args = self.retrieve_args(node)
@@ -397,7 +417,7 @@ class TorchFXImporter:
)
)
- def _full(self, node: fx.node.Node) -> relax.Var:
+ def _full(self, node: fx.Node) -> relax.Var:
import torch
args = self.retrieve_args(node)
@@ -421,14 +441,14 @@ class TorchFXImporter:
########## Statistical ##########
- def _sum(self, node: fx.node.Node) -> relax.Var:
+ def _sum(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
if len(args) == 1:
return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
return self.block_builder.emit(relax.op.sum(args[0], args[1]))
- def _mean(self, node: fx.node.Node) -> relax.Var:
+ def _mean(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
if len(args) == 1:
@@ -437,18 +457,18 @@ class TorchFXImporter:
########## DataType ##########
- def _float(self, node: fx.node.Node) -> relax.Var:
+ def _float(self, node: fx.Node) -> relax.Var:
return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32"))
- def _half(self, node: fx.node.Node) -> relax.Var:
+ def _half(self, node: fx.Node) -> relax.Var:
return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16"))
- def _type(self, node: fx.node.Node) -> relax.Var:
+ def _type(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
return self.block_builder.emit(relax.op.astype(x, dtype))
- def _to(self, node: fx.node.Node) -> relax.Var:
+ def _to(self, node: fx.Node) -> relax.Var:
import torch
x = self.env[node.args[0]]
@@ -466,7 +486,7 @@ class TorchFXImporter:
def _matmul_impl(self, a: relax.Expr, b: relax.Expr):
return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32"))
- def _matmul(self, node: fx.node.Node) -> relax.Var:
+ def _matmul(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
res = self._matmul_impl(
args[0],
@@ -474,7 +494,7 @@ class TorchFXImporter:
)
return res
- def _addmm(self, node: fx.node.Node) -> relax.Var:
+ def _addmm(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
y = self.env[node.args[1]]
z = self.env[node.args[2]]
@@ -496,7 +516,7 @@ class TorchFXImporter:
res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res))
return res
- def _baddbmm(self, node: fx.node.Node) -> relax.Var:
+ def _baddbmm(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
a = self.env[node.args[1]]
b = self.env[node.args[2]]
@@ -518,7 +538,7 @@ class TorchFXImporter:
res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias))
return res
- def _einsum(self, node: fx.node.Node) -> relax.Var:
+ def _einsum(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
args = self.retrieve_args(node)
@@ -526,7 +546,7 @@ class TorchFXImporter:
return self.block_builder.emit(relax.op.einsum(tuple(args[1]), args[0]))
return self.block_builder.emit(relax.op.einsum(args[1:], args[0]))
- def _unbind(self, node: fx.node.Node) -> relax.Var:
+ def _unbind(self, node: fx.Node) -> relax.Var:
if len(node.args) == 2:
assert isinstance(node.args[1], int), "Expected 2nd argument of unbind as int"
dim = node.args[1]
@@ -544,12 +564,12 @@ class TorchFXImporter:
########## Manipulation ##########
- def _cat(self, node: fx.node.Node) -> relax.Var:
+ def _cat(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
- def _expand(self, node: fx.node.Node) -> relax.Var:
+ def _expand(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
broadcast_shape, in_shape = [], self.shape_of(args[0])
for idx, i in enumerate(args[1:]):
@@ -559,7 +579,7 @@ class TorchFXImporter:
broadcast_shape.append(i)
return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
- def _flatten(self, node: fx.node.Node) -> relax.Var:
+ def _flatten(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
if node.target in self.named_modules:
module = self.named_modules[node.target]
@@ -579,7 +599,7 @@ class TorchFXImporter:
)
return self.block_builder.emit(relax.op.reshape(x, new_shape))
- def _permute(self, node: fx.node.Node) -> relax.Var:
+ def _permute(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
args = self.retrieve_args(node)
@@ -587,7 +607,7 @@ class TorchFXImporter:
return self.block_builder.emit(relax.op.permute_dims(args[0], tuple(args[1])))
return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:]))
- def _reshape(self, node: fx.node.Node) -> relax.Var:
+ def _reshape(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
args = self.retrieve_args(node)
@@ -595,7 +615,7 @@ class TorchFXImporter:
return self.block_builder.emit(relax.op.reshape(args[0], tuple(args[1])))
return self.block_builder.emit(relax.op.reshape(args[0], args[1:]))
- def _split(self, node: fx.node.Node) -> relax.Var:
+ def _split(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
split_size = node.args[1]
if "dim" in node.kwargs:
@@ -611,7 +631,7 @@ class TorchFXImporter:
n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
return self.block_builder.emit(relax.op.split(x, n_section, dim))
- def _chunk(self, node: fx.node.Node) -> relax.Var:
+ def _chunk(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
chunks = node.args[1]
@@ -623,13 +643,13 @@ class TorchFXImporter:
dim = 0
return self.block_builder.emit(relax.op.split(x, chunks, dim))
- def _transpose(self, node: fx.node.Node) -> relax.Var:
+ def _transpose(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
full_idx = list(range(len(self.shape_of(args[0]))))
full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))
- def _squeeze(self, node: fx.node.Node) -> relax.Var:
+ def _squeeze(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
if "dim" in node.kwargs:
@@ -640,7 +660,7 @@ class TorchFXImporter:
dim = None
return self.block_builder.emit(relax.op.squeeze(x, dim))
- def _repeat(self, node: fx.node.Node) -> relax.Var:
+ def _repeat(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
args = self.retrieve_args(node)
@@ -648,7 +668,7 @@ class TorchFXImporter:
return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1])))
return self.block_builder.emit(relax.op.tile(args[0], args[1:]))
- def _tile(self, node: fx.node.Node) -> relax.Var:
+ def _tile(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
args = self.retrieve_args(node)
@@ -656,7 +676,7 @@ class TorchFXImporter:
return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1])))
return self.block_builder.emit(relax.op.tile(args[0], args[1:]))
- def _cumsum(self, node: fx.node.Node) -> relax.Var:
+ def _cumsum(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
if "dim" in node.kwargs:
@@ -674,13 +694,13 @@ class TorchFXImporter:
return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
- def _index_select(self, node: fx.node.Node) -> relax.Var:
+ def _index_select(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dim = node.args[1]
index = self.env[node.args[2]]
return self.block_builder.emit(relax.op.take(x, index, dim))
- def _masked_fill(self, node: fx.node.Node) -> relax.Var:
+ def _masked_fill(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
mask = self.env[node.args[1]]
value = node.args[2]
@@ -688,7 +708,7 @@ class TorchFXImporter:
values = self.block_builder.emit(relax.op.full_like(x, rx_value))
return self.block_builder.emit(relax.op.where(mask, values, x))
- def _inplace_masked_fill(self, node: fx.node.Node) -> relax.Var:
+ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
mask = self.env[node.args[1]]
value = node.args[2]
@@ -703,7 +723,7 @@ class TorchFXImporter:
def _argmax_argmin(self, op: Callable) -> Callable:
from torch import fx
- def convert(node: fx.node.Node):
+ def convert(node: fx.Node):
x = self.env[node.args[0]]
dim = None
keepdims = False
@@ -726,14 +746,14 @@ class TorchFXImporter:
########## Neural Network ##########
- def _linear(self, node: fx.node.Node) -> relax.Var:
+ def _linear(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
bias = None if module.bias is None else self.params[module.bias]
return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32"))
- def _linear_functional(self, node: fx.node.Node) -> relax.Var:
+ def _linear_functional(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
@@ -770,7 +790,7 @@ class TorchFXImporter:
bias = relax.op.reshape(bias, (1, -1, 1))
return self.block_builder.emit(relax.op.add(conv1d, bias))
- def _conv1d(self, node: fx.node.Node) -> relax.Var:
+ def _conv1d(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
@@ -788,7 +808,7 @@ class TorchFXImporter:
groups=module.groups,
)
- def _conv1d_functional(self, node: fx.node.Node) -> relax.Var:
+ def _conv1d_functional(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
@@ -838,7 +858,7 @@ class TorchFXImporter:
bias = relax.op.reshape(bias, (1, -1, 1))
return self.block_builder.emit(relax.op.add(conv1d_transpose, bias))
- def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var:
+ def _conv1d_transpose(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
@@ -856,7 +876,7 @@ class TorchFXImporter:
groups=module.groups,
)
- def _conv1d_transpose_functional(self, node: fx.node.Node) -> relax.Var:
+ def _conv1d_transpose_functional(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
@@ -905,7 +925,7 @@ class TorchFXImporter:
bias = relax.op.reshape(bias, (1, -1, 1, 1))
return self.block_builder.emit(relax.op.add(conv2d, bias))
- def _conv2d(self, node: fx.node.Node) -> relax.Var:
+ def _conv2d(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
@@ -923,7 +943,7 @@ class TorchFXImporter:
groups=module.groups,
)
- def _conv2d_functional(self, node: fx.node.Node) -> relax.Var:
+ def _conv2d_functional(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
@@ -973,7 +993,7 @@ class TorchFXImporter:
bias = relax.op.reshape(bias, (1, -1, 1, 1))
return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))
- def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var:
+ def _conv2d_transpose(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
@@ -991,7 +1011,7 @@ class TorchFXImporter:
groups=module.groups,
)
- def _conv2d_transpose_functional(self, node: fx.node.Node) -> relax.Var:
+ def _conv2d_transpose_functional(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
@@ -1040,7 +1060,7 @@ class TorchFXImporter:
bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
return self.block_builder.emit(relax.op.add(conv3d, bias))
- def _conv3d(self, node: fx.node.Node) -> relax.Var:
+ def _conv3d(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
@@ -1058,7 +1078,7 @@ class TorchFXImporter:
groups=module.groups,
)
- def _conv3d_functional(self, node: fx.node.Node) -> relax.Var:
+ def _conv3d_functional(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
weight = args[1]
@@ -1077,7 +1097,7 @@ class TorchFXImporter:
groups=groups,
)
- def _max_pool2d(self, node: fx.node.Node) -> relax.Var:
+ def _max_pool2d(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
if node.target in self.named_modules:
module = self.named_modules[node.target]
@@ -1108,7 +1128,7 @@ class TorchFXImporter:
)
)
- def _avg_pool2d(self, node: fx.node.Node) -> relax.Var:
+ def _avg_pool2d(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
if node.target in self.named_modules:
module = self.named_modules[node.target]
@@ -1154,7 +1174,7 @@ class TorchFXImporter:
def _adaptive_avg_pool2d(self, is_module: bool) -> Callable:
from torch import fx
- def _impl(node: fx.node.Node) -> relax.Var:
+ def _impl(node: fx.Node) -> relax.Var:
if is_module:
module = self.named_modules[node.target]
x = self.env[node.args[0]]
@@ -1168,7 +1188,7 @@ class TorchFXImporter:
return _impl
- def _softmax(self, node: fx.node.Node) -> relax.Var:
+ def _softmax(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
if node.target in self.named_modules:
module = self.named_modules[node.target]
@@ -1179,29 +1199,7 @@ class TorchFXImporter:
assert dim is not None
return self.block_builder.emit(relax.op.nn.softmax(x, dim))
- def _log_softmax(self, node: fx.node.Node) -> relax.Var:
- x = self.env[node.args[0]]
- if node.target in self.named_modules:
- module = self.named_modules[node.target]
- dim = module.dim
- else:
- nargs = len(node.args)
- dim = node.args[1] if nargs > 1 else node.kwargs["dim"]
- assert dim is not None
- return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
-
- def _leakyrelu(self, node: fx.node.Node) -> relax.Var:
- x = self.env[node.args[0]]
- if node.target in self.named_modules:
- module = self.named_modules[node.target]
- alpha = module.negative_slope
- else:
- nargs = len(node.args)
- alpha = node.args[1] if nargs > 1 else node.kwargs["negative_slope"]
- assert alpha is not None
- return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
-
- def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var:
+ def _batch_norm_2d(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
@@ -1224,7 +1222,7 @@ class TorchFXImporter:
return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0))
- def _layer_norm(self, node: fx.node.Node) -> relax.Var:
+ def _layer_norm(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
from torch.fx.immutable_collections import immutable_list
import numpy as np # type: ignore
@@ -1291,7 +1289,7 @@ class TorchFXImporter:
)
)
- def _group_norm(self, node: fx.node.Node) -> relax.Var:
+ def _group_norm(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
x = self.env[node.args[0]]
@@ -1317,7 +1315,7 @@ class TorchFXImporter:
)
)
- def _embedding(self, node: fx.node.Node) -> relax.Var:
+ def _embedding(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
@@ -1333,7 +1331,7 @@ class TorchFXImporter:
embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0))
return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size]))
- def _interpolate(self, node: fx.node.Node) -> relax.Var:
+ def _interpolate(self, node: fx.Node) -> relax.Var:
# torch.nn.functional.interpolate(
# input, size=None, scale_factor=None, mode='nearest', align_corners=None,
# recompute_scale_factor=None, antialias=False)
@@ -1407,7 +1405,7 @@ class TorchFXImporter:
)
)
- def _cross_entropy(self, node: fx.node.Node) -> relax.Expr:
+ def _cross_entropy(self, node: fx.Node) -> relax.Expr:
preds = self.env[node.args[0]]
targets = self.env[node.args[1]]
@@ -1442,7 +1440,7 @@ class TorchFXImporter:
)
)
- def _scaled_dot_product_attention(self, node: fx.node.Node) -> relax.Var:
+ def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
assert (
len(node.args) <= 4
), "Dropout is not supported, and is_causal should be called by
kwargs."
@@ -1464,13 +1462,13 @@ class TorchFXImporter:
########## Others ##########
- def _sym_size_int(self, node: fx.node.Node) -> relax.Expr:
+ def _sym_size_int(self, node: fx.Node) -> relax.Expr:
x = self.env[node.args[0]]
shape = self.shape_of(x)
idx = node.args[1]
return self.block_builder.emit(relax.const(shape[idx].value, "int32"))
- def _size(self, node: fx.node.Node) -> relax.Expr:
+ def _size(self, node: fx.Node) -> relax.Expr:
x = self.env[node.args[0]]
shape = self.shape_of(x)
if len(node.args) == 1:
@@ -1480,7 +1478,7 @@ class TorchFXImporter:
idx = node.args[1]
return self.shape_of(x)[idx].value
- def _getattr(self, node: fx.node.Node) -> relax.Var:
+ def _getattr(self, node: fx.Node) -> relax.Var:
if isinstance(self.env[node.args[0]], relax.Expr):
if node.args[1] == "dtype":
return self.env[node.args[0]].struct_info.dtype
@@ -1488,7 +1486,7 @@ class TorchFXImporter:
return self.shape_of(self.env[node.args[0]])
return getattr(self.env[node.args[0]], node.args[1])
- def _getitem(self, node: fx.node.Node) -> relax.Var:
+ def _getitem(self, node: fx.Node) -> relax.Var:
import torch
x = self.env[node.args[0]]
@@ -1510,7 +1508,7 @@ class TorchFXImporter:
shape = self.shape_of(x)
non_ellipsis_cnt = 0
for index in node.args[1]:
- if isinstance(index, (int, slice, torch.fx.node.Node)):
+ if isinstance(index, (int, slice, torch.fx.Node)):
non_ellipsis_cnt += 1
for index in node.args[1]:
if isinstance(index, int):
@@ -1534,7 +1532,7 @@ class TorchFXImporter:
stride.append(1)
stride_axes.append(i)
i += 1
- elif isinstance(index, torch.fx.node.Node):
+ elif isinstance(index, torch.fx.Node):
node_index = self.env[index]
if not isinstance(node_index, relax.Expr):
raise ValueError(
@@ -1573,142 +1571,154 @@ class TorchFXImporter:
from torch import nn
from torch import fx
- self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.node.Node], relax.Var]] = {
- # call_module
- nn.Linear: self._linear,
+ self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.Node], relax.Var]] = {
+ ## call_module
+ # unary
+ nn.Dropout: lambda node: self.env[node.args[0]],
+ nn.GELU: self._gelu,
+ nn.Hardsigmoid: self._hardsigmoid,
+ nn.Hardswish: self._hardswish,
+ nn.Identity: lambda node: self.env[node.args[0]],
+ nn.LeakyReLU: self._leakyrelu_module,
+ nn.LogSoftmax: self._log_softmax_module,
+ nn.ReLU: self._unary_op(relax.op.nn.relu),
+ nn.ReLU6: lambda node: self.block_builder.emit(
+ relax.op.clip(self.env[node.args[0]], 0, 6)
+ ),
+ nn.Sigmoid: self._unary_op(relax.op.sigmoid),
+ nn.SiLU: self._unary_op(relax.op.nn.silu),
+ nn.Softmax: self._softmax_module,
+ nn.Tanh: self._unary_op(relax.op.tanh),
+ # neural network
+ nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True),
+ nn.AvgPool2d: self._avg_pool2d,
+ nn.BatchNorm2d: self._batch_norm_2d,
nn.Conv1d: self._conv1d,
nn.Conv2d: self._conv2d,
nn.Conv3d: self._conv3d,
nn.ConvTranspose1d: self._conv1d_transpose,
nn.ConvTranspose2d: self._conv2d_transpose,
- nn.MaxPool2d: self._max_pool2d,
- nn.AvgPool2d: self._avg_pool2d,
- nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True),
- nn.Softmax: self._softmax,
- nn.LogSoftmax: self._log_softmax,
- nn.ReLU: lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
- nn.LeakyReLU: self._leakyrelu,
- nn.ReLU6: lambda node: self.block_builder.emit(
- relax.op.clip(self.env[node.args[0]], 0, 6)
- ),
- nn.GELU: self._gelu,
- nn.Sigmoid: self._sigmoid,
- nn.Tanh: lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
- nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
- nn.Hardsigmoid: self._hardsigmoid,
- nn.Hardswish: self._hardswish,
- nn.Flatten: self._flatten,
- nn.BatchNorm2d: self._batch_norm_2d,
- nn.LayerNorm: self._layer_norm,
+ nn.CrossEntropyLoss: self._cross_entropy,
nn.GroupNorm: self._group_norm,
- nn.Dropout: lambda node: self.env[node.args[0]],
- nn.Identity: lambda node: self.env[node.args[0]],
+ nn.LayerNorm: self._layer_norm,
+ nn.Linear: self._linear,
+ nn.MaxPool2d: self._max_pool2d,
nn.modules.sparse.Embedding: self._embedding,
- nn.CrossEntropyLoss: self._cross_entropy,
- # call_function and call_method
- "sin": lambda node:
self.block_builder.emit(relax.op.sin(self.env[node.args[0]])),
- "cos": lambda node:
self.block_builder.emit(relax.op.cos(self.env[node.args[0]])),
- "tan": lambda node:
self.block_builder.emit(relax.op.tan(self.env[node.args[0]])),
- "asin": lambda node:
self.block_builder.emit(relax.op.asin(self.env[node.args[0]])),
- "acos": lambda node:
self.block_builder.emit(relax.op.acos(self.env[node.args[0]])),
- "atan": lambda node:
self.block_builder.emit(relax.op.atan(self.env[node.args[0]])),
- "sinh": lambda node:
self.block_builder.emit(relax.op.sinh(self.env[node.args[0]])),
- "cosh": lambda node:
self.block_builder.emit(relax.op.cosh(self.env[node.args[0]])),
- "tanh": lambda node:
self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
- "asinh": lambda node:
self.block_builder.emit(relax.op.asinh(self.env[node.args[0]])),
- "acosh": lambda node:
self.block_builder.emit(relax.op.acosh(self.env[node.args[0]])),
- "atanh": lambda node:
self.block_builder.emit(relax.op.atanh(self.env[node.args[0]])),
- "exp": self._exp,
- "iadd": self._add,
+ # tensor manipulation
+ nn.Flatten: self._flatten,
+ ## call_function and call_method
+ # unary
+ "acos": self._unary_op(relax.op.acos),
+ "acosh": self._unary_op(relax.op.acosh),
+ "asin": self._unary_op(relax.op.asin),
+ "asinh": self._unary_op(relax.op.asinh),
+ "atan": self._unary_op(relax.op.atan),
+ "atanh": self._unary_op(relax.op.atanh),
+ "clamp": self._clamp,
+ "cos": self._unary_op(relax.op.cos),
+ "cosh": self._unary_op(relax.op.cosh),
+ "dropout": lambda node: self.env[node.args[0]],
+ "exp": self._unary_op(relax.op.exp),
+ "gelu": self._gelu,
+ "hardsigmoid": self._hardsigmoid,
+ "hardswish": self._hardswish,
+ "leaky_relu": self._leakyrelu,
+ "log_softmax": self._log_softmax,
+ "neg": self._unary_op(relax.op.negative),
+ "relu": self._unary_op(relax.op.nn.relu),
+ "round": self._round,
+ "rsqrt": self._unary_op(relax.op.rsqrt),
+ "sigmoid": self._unary_op(relax.op.sigmoid),
+ "silu": self._unary_op(relax.op.nn.silu),
+ "sin": self._unary_op(relax.op.sin),
+ "sinh": self._unary_op(relax.op.sinh),
+ "softmax": self._softmax,
+ "sqrt": self._unary_op(relax.op.sqrt),
+ "tan": self._unary_op(relax.op.tan),
+ "tanh": self._unary_op(relax.op.tanh),
+ "tril_": self._inplace_tril_triu(relax.op.tril),
+ "tril": self._tril_triu(relax.op.tril),
+ "triu_": self._inplace_tril_triu(relax.op.triu),
+ "triu": self._tril_triu(relax.op.triu),
+ # binary
"add": self._add,
+ "eq": self._eq,
"floordiv": self._floordiv,
+ "iadd": self._add,
+ "lt": self._lt,
+ "matmul": self._matmul,
+ "max": self._max,
"mul": self._mul,
- "sub": self._sub,
"pow": self._pow,
- "sigmoid": self._sigmoid,
- "sqrt": self._sqrt,
- "round": self._round,
- "lt": self._lt,
- "eq": self._eq,
+ "sub": self._sub,
"truediv": self._truediv,
- "fill_": self._inplace_fill,
- "new_ones": self._new_ones,
- "arange": self._arange,
- "empty": self._empty,
- "tensor": self._tensor,
- "tril": self._tril_triu(relax.op.tril),
- "triu": self._tril_triu(relax.op.triu),
- "tril_": self._inplace_tril_triu(relax.op.tril),
- "triu_": self._inplace_tril_triu(relax.op.triu),
- "sum": self._sum,
- "float": self._float,
- "half": self._half,
- "type": self._type,
- "astype": self._type,
- "matmul": self._matmul,
- "conv1d": self._conv1d_functional,
+ # neural network
+ "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
+ "addmm": self._addmm,
+ "avg_pool2d": self._avg_pool2d,
+ "baddbmm": self._baddbmm,
+ "bmm": self._matmul,
"conv_transpose1d": self._conv1d_transpose_functional,
- "conv2d": self._conv2d_functional,
"conv_transpose2d": self._conv2d_transpose_functional,
+ "conv1d": self._conv1d_functional,
+ "conv2d": self._conv2d_functional,
"conv3d": self._conv3d_functional,
+ "cross_entropy": self._cross_entropy,
+ "einsum": self._einsum,
+ "interpolate": self._interpolate,
+ "layer_norm": self._layer_norm,
"linear": self._linear_functional,
- "addmm": self._addmm,
- "baddbmm": self._baddbmm,
- "bmm": self._matmul,
+ "max_pool2d": self._max_pool2d,
+ "scaled_dot_product_attention": self._scaled_dot_product_attention,
+ "stochastic_depth": lambda node: self.env[node.args[0]],
+ "unbind": self._unbind,
+ # statistical
+ "mean": self._mean,
+ "sum": self._sum,
+ # search
+ "argmax": self._argmax_argmin(relax.op.argmax),
+ "argmin": self._argmax_argmin(relax.op.argmin),
+ # tensor manipulation
"cat": self._cat,
"concat": self._cat,
+ "contiguous": lambda node: self.env[node.args[0]],
+ "cumsum": self._cumsum,
"expand": self._expand,
"flatten": self._flatten,
"permute": self._permute,
"repeat": self._repeat,
"reshape": self._reshape,
+ "size": self._size,
"split": self._split,
+ "squeeze": self._squeeze,
"tile": self._tile,
- "cumsum": self._cumsum,
- "chunk": self._chunk,
"transpose": self._transpose,
- "squeeze": self._squeeze,
"unsqueeze": lambda node: self.block_builder.emit(
relax.op.expand_dims(self.env[node.args[0]], node.args[1])
),
"view": self._reshape,
- "argmax": self._argmax_argmin(relax.op.argmax),
- "argmin": self._argmax_argmin(relax.op.argmin),
- "softmax": self._softmax,
- "log_softmax": self._log_softmax,
- "dropout": lambda node: self.env[node.args[0]],
- "stochastic_depth": lambda node: self.env[node.args[0]],
- "clamp": self._clamp,
- "relu": lambda node:
self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
- "leaky_relu": self._leakyrelu,
- "gelu": self._gelu,
- "silu": lambda node:
self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
- "hardsigmoid": self._hardsigmoid,
- "hardswish": self._hardswish,
- "interpolate": self._interpolate,
- "sym_size.int": self._sym_size_int,
- "size": self._size,
- "getattr": self._getattr,
- "getitem": self._getitem,
- "contiguous": lambda node: self.env[node.args[0]],
- "to": self._to,
- "max_pool2d": self._max_pool2d,
- "avg_pool2d": self._avg_pool2d,
- "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
- "layer_norm": self._layer_norm,
+ # tensor creation
+ "arange": self._arange,
+ "chunk": self._chunk,
+ "empty": self._empty,
+ "fill_": self._inplace_fill,
+ "full": self._full,
"index_select": self._index_select,
+ "masked_fill_": self._inplace_masked_fill,
"masked_fill": self._masked_fill,
+ "new_ones": self._new_ones,
"ones": self._ones,
- "full": self._full,
- "masked_fill_": self._inplace_masked_fill,
- "mean": self._mean,
- "rsqrt": self._rsqrt,
- "neg": self._neg,
- "max": self._max,
- "cross_entropy": self._cross_entropy,
- "scaled_dot_product_attention": self._scaled_dot_product_attention,
- "einsum": self._einsum,
- "unbind": self._unbind,
+ "tensor": self._tensor,
+ "to": self._to,
+ # datatype
+ "astype": self._type,
+ "float": self._float,
+ "half": self._half,
+ "type": self._type,
+ # other
+ "getattr": self._getattr,
+ "getitem": self._getitem,
+ "sym_size.int": self._sym_size_int,
}
def update_convert_map(self, custom_convert_map: dict):