masahi commented on a change in pull request #7023:
URL: https://github.com/apache/tvm/pull/7023#discussion_r535916137
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -2377,18 +2023,428 @@ def _impl(inputs, input_types):
counts = _op.zeros(_op.reshape(dim, [1]), out_dtype)
return _op.scatter_add(counts, data, updates, axis=0)
- return _impl
-
-
-def _scatter_add():
- def _impl(inputs, input_types):
+ @staticmethod
+ def scatter_add(inputs, input_types):
data = inputs[0]
axis = inputs[1]
index = inputs[2]
src = inputs[3]
return _op.scatter_add(data, index, src, axis=axis)
- return _impl
+ # Operator mappings
+ def create_convert_map(self):
+ self.convert_map = {
+ "aten::pixel_shuffle": self.pixel_shuffle,
+ "aten::device": self.none,
+ "prim::device": self.none,
+ "aten::sub": self.make_elemwise("subtract"),
+ "aten::sub_": self.make_elemwise("subtract"),
+ "aten::max": self.max,
+ "aten::min": self.min,
+ "aten::mul": self.make_elemwise("multiply"),
+ "aten::mul_": self.make_elemwise("multiply"),
+ "aten::pow": self.make_elemwise("power"),
+ "aten::arange": self.arange,
+ "aten::meshgrid": self.meshgrid,
+ "aten::div": self.make_elemwise("divide"),
+ "aten::div_": self.make_elemwise("divide"),
+ "aten::floor_divide": self.make_elemwise("floor_divide"),
+ "aten::true_divide": self.make_elemwise("divide"),
+ "aten::addcdiv": self.addcdiv,
+ "aten::addcmul": self.addcmul,
+ "aten::ones": self.ones,
+ "aten::ones_like": self.ones_like,
+ "aten::zeros": self.zeros,
+ "aten::zeros_like": self.zeros_like,
+ "aten::full": self.full,
+ "aten::full_like": self.full_like,
+ "aten::linspace": self.linspace,
+ "aten::reciprocal": self.reciprocal,
+ "aten::repeat": self.repeat,
+ "aten::repeat_interleave": self.repeat_interleave,
+ "aten::to": self.to,
+ "aten::squeeze": self.squeeze,
+ "aten::unsqueeze": self.unsqueeze,
+ "aten::cat": self.concatenate,
+ "aten::slice": self.slice,
+ "aten::split": self.split,
+ "aten::split_with_sizes": self.split_with_sizes,
+ "aten::select": self.select,
+ "aten::take": self.take,
+ "aten::where": self.where,
+ "aten::topk": self.topk,
+ "aten::relu": self.relu,
+ "aten::relu_": self.relu,
+ "aten::prelu": self.prelu,
+ "aten::leaky_relu": self.leaky_relu,
+ "aten::leaky_relu_": self.leaky_relu,
+ "aten::elu": self.elu,
+ "aten::elu_": self.elu,
+ "aten::celu": self.celu,
+ "aten::gelu": self.gelu,
+ "aten::selu": self.selu,
+ "aten::log_sigmoid": self.log_sigmoid,
+ "aten::adaptive_avg_pool2d": self.adaptive_avg_pool_2d,
+ "aten::adaptive_max_pool2d": self.adaptive_max_pool_2d,
+ "aten::max_pool2d": self.maxpool_2d,
+ "aten::max_pool2d_with_indices": self.maxpool_2d_with_indices,
+ "aten::max_pool1d": self.maxpool_1d,
+ "aten::max_pool3d": self.maxpool_3d,
+ "aten::hardtanh": self.hardtanh,
+ "aten::hardtanh_": self.hardtanh,
+ "aten::_convolution": self.convolution,
+ "aten::softmax": self.softmax,
+ "aten::threshold": self.threshold,
+ "aten::threshold_": self.threshold,
+ "aten::contiguous": self.contiguous,
+ "aten::batch_norm": self.batch_norm,
+ "aten::instance_norm": self.instance_norm,
+ "aten::layer_norm": self.layer_norm,
+ "aten::group_norm": self.group_norm,
+ "aten::transpose": self.transpose,
+ "aten::transpose_": self.transpose,
+ "aten::t": self.transpose,
+ "aten::flatten": self.flatten,
+ "aten::addmm": self.addmm,
+ "aten::size": self.size,
+ "aten::view": self.view,
+ "aten::reshape": self.reshape,
+ "aten::clone": self.clone,
+ "aten::log_softmax": self.log_softmax,
+ "aten::sigmoid": self.sigmoid,
+ "aten::softplus": self.softplus,
+ "aten::avg_pool2d": self.avg_pool2d,
+ "aten::avg_pool3d": self.avg_pool3d,
+ "aten::dropout": self.dropout,
+ "aten::dropout_": self.dropout,
+ "aten::feature_dropout": self.dropout,
+ "aten::alpha_dropout": self.dropout,
+ "aten::mean": self.mean,
+ "aten::chunk": self.chunk,
+ "aten::matmul": self.matmul,
+ "aten::bmm": self.matmul,
+ "aten::expand": self.expand,
+ "aten::Int": self.int,
+ "prim::NumToTensor": self.numtotensor,
+ "prim::ImplicitTensorToNum": self.tensortonum,
+ "aten::ScalarImplicit": self.tensortonum,
+ "aten::constant_pad_nd": self.make_pad("constant"),
+ "aten::reflection_pad1d": self.make_pad("reflect"),
+ "aten::reflection_pad2d": self.make_pad("reflect"),
+ "aten::replication_pad1d": self.make_pad("edge"),
+ "aten::replication_pad2d": self.make_pad("edge"),
+ "aten::replication_pad3d": self.make_pad("edge"),
+ "aten::permute": self.transpose,
+ "aten::sum": self.make_reduce("sum"),
+ "aten::prod": self.make_reduce("prod"),
+ "aten::argmin": self.make_reduce("argmin"),
+ "aten::argmax": self.make_reduce("argmax"),
+ "aten::norm": self.norm,
+ "aten::frobenius_norm": self.frobenius_norm,
+ "aten::std": self.std,
+ "aten::var": self.variance,
+ "aten::abs": self.make_unary("abs"),
+ "aten::neg": self.make_unary("negative"),
+ "aten::cos": self.make_unary("cos"),
+ "aten::cosh": self.make_unary("cosh"),
+ "aten::sin": self.make_unary("sin"),
+ "aten::sinh": self.make_unary("sinh"),
+ "aten::tan": self.make_unary("tan"),
+ "aten::tanh": self.make_unary("tanh"),
+ "aten::acos": self.make_unary("acos"),
+ "aten::asin": self.make_unary("asin"),
+ "aten::atan": self.make_unary("atan"),
+ "aten::log": self.make_unary("log"),
+ "aten::log2": self.make_unary("log2"),
+ "aten::log10": self.make_unary("log10"),
+ "aten::log1p": self.log1p,
+ "aten::exp": self.make_unary("exp"),
+ "aten::erf": self.make_unary("erf"),
+ "aten::trunc": self.make_unary("trunc"),
+ "aten::sign": self.make_unary("sign"),
+ "aten::sqrt": self.make_unary("sqrt"),
+ "aten::rsqrt": self.make_unary("rsqrt"),
+ "aten::ceil": self.make_unary("ceil"),
+ "aten::floor": self.make_unary("floor"),
+ "aten::round": self.make_unary("round"),
+ "aten::isfinite": self.make_unary("isfinite"),
+ "aten::isinf": self.make_unary("isinf"),
+ "aten::isnan": self.make_unary("isnan"),
+ "aten::clamp": self.clamp,
+ "aten::clamp_": self.clamp,
+ "aten::detach": self.identity,
+ "aten::upsample_bilinear2d": self.make_upsample("bilinear"),
+ "aten::upsample_nearest2d": self.make_upsample("nearest_neighbor"),
+ "aten::upsample_trilinear3d": self.make_upsample3d("trilinear"),
+ "aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"),
+ "aten::expand_as": self.expand_as,
+ "aten::lt": self.make_elemwise("less"),
+ "aten::gt": self.make_elemwise("greater"),
+ "aten::le": self.make_elemwise("less_equal"),
+ "aten::ge": self.make_elemwise("greater_equal"),
+ "aten::ne": self.make_elemwise("not_equal"),
+ "aten::eq": self.make_elemwise("equal"),
+ "aten::logical_not": self.logical_not,
+ "aten::logical_xor": self.logical_xor,
+ "aten::bitwise_not": self.bitwise_not,
+ "aten::bitwise_xor": self.bitwise_xor,
+ "aten::Bool": self.Bool,
+ "aten::Float": self.Float,
+ "aten::adaptive_avg_pool3d": self.adaptive_avg_pool_3d,
+ "aten::adaptive_max_pool3d": self.adaptive_max_pool_3d,
+ "aten::rsub": self.rsub,
+ "aten::embedding": self.embedding,
+ "aten::one_hot": self.one_hot,
+ "aten::mm": self.matmul,
+ "aten::add": self.add,
+ "aten::add_": self.add,
+ "aten::stack": self.stack,
+ "aten::__getitem__": self.list_getitem,
+ "aten::len": self.list_len,
+ "aten::type_as": self.type_as,
+ "aten::gather": self.gather,
+ "aten::index_select": self.select,
+ "aten::index": self.index,
+ "torchvision::nms": self.nms,
+ "aten::logsumexp": self.logsumexp,
+ "torchvision::roi_align": self.roi_align,
+ "aten::unbind": self.unbind,
+ "aten::__and__": self.logical_and,
+ "aten::_shape_as_tensor": self.shape_as_tensor,
+ "aten::nonzero": self.nonzero,
+ "aten::nonzero_numpy": self.nonzero_numpy,
+ "aten::scatter": self.scatter,
+ "aten::scalar_tensor": self.scalar_tensor,
+ "aten::__interpolate": self.interpolate,
+ "aten::IntImplicit": self.identity,
+ "aten::tensor": self.identity, # used for example in tensor(1.0)
+ "aten::numel": self.numel,
+ "aten::empty": self.empty,
+ "aten::bincount": self.bincount,
+ "aten::scatter_add": self.scatter_add,
+ "aten::__not__": self.logical_not,
+ }
+
+ def update_convert_map(self, custom_map):
+ self.convert_map.update(custom_map)
+
+ def report_missing_conversion(self, op_names):
+ """ Check if all ops in an input graph are supported by TVM """
+ known_ops = [
+ "prim::Constant",
+ "prim::GetAttr",
+ "prim::ListConstruct",
+ "prim::ListUnpack",
+ "prim::TupleConstruct",
+ "prim::TupleUnpack",
+ "prim::RaiseException",
+ "prim::If",
+ "prim::Loop",
+ ]
+ known_ops += list(self.convert_map.keys())
+ known_ops += list(qnn_torch.convert_map.keys())
+
+ missing = [op_name for op_name in op_names if op_name not in known_ops]
+
+ if missing:
+ msg = "The following operators are not implemented: {}".format(missing)
+ raise NotImplementedError(msg)
+
+ def convert_block(self, block, outputs):
Review comment:
Sorry, that would make recursive conversion difficult, so this is good.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org