masahi commented on a change in pull request #7023:
URL: https://github.com/apache/tvm/pull/7023#discussion_r535916137



##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -2377,18 +2023,428 @@ def _impl(inputs, input_types):
         counts = _op.zeros(_op.reshape(dim, [1]), out_dtype)
         return _op.scatter_add(counts, data, updates, axis=0)
 
-    return _impl
-
-
-def _scatter_add():
-    def _impl(inputs, input_types):
+    @staticmethod
+    def scatter_add(inputs, input_types):
         data = inputs[0]
         axis = inputs[1]
         index = inputs[2]
         src = inputs[3]
         return _op.scatter_add(data, index, src, axis=axis)
 
-    return _impl
+    # Operator mappings
+    def create_convert_map(self):
+        self.convert_map = {
+            "aten::pixel_shuffle": self.pixel_shuffle,
+            "aten::device": self.none,
+            "prim::device": self.none,
+            "aten::sub": self.make_elemwise("subtract"),
+            "aten::sub_": self.make_elemwise("subtract"),
+            "aten::max": self.max,
+            "aten::min": self.min,
+            "aten::mul": self.make_elemwise("multiply"),
+            "aten::mul_": self.make_elemwise("multiply"),
+            "aten::pow": self.make_elemwise("power"),
+            "aten::arange": self.arange,
+            "aten::meshgrid": self.meshgrid,
+            "aten::div": self.make_elemwise("divide"),
+            "aten::div_": self.make_elemwise("divide"),
+            "aten::floor_divide": self.make_elemwise("floor_divide"),
+            "aten::true_divide": self.make_elemwise("divide"),
+            "aten::addcdiv": self.addcdiv,
+            "aten::addcmul": self.addcmul,
+            "aten::ones": self.ones,
+            "aten::ones_like": self.ones_like,
+            "aten::zeros": self.zeros,
+            "aten::zeros_like": self.zeros_like,
+            "aten::full": self.full,
+            "aten::full_like": self.full_like,
+            "aten::linspace": self.linspace,
+            "aten::reciprocal": self.reciprocal,
+            "aten::repeat": self.repeat,
+            "aten::repeat_interleave": self.repeat_interleave,
+            "aten::to": self.to,
+            "aten::squeeze": self.squeeze,
+            "aten::unsqueeze": self.unsqueeze,
+            "aten::cat": self.concatenate,
+            "aten::slice": self.slice,
+            "aten::split": self.split,
+            "aten::split_with_sizes": self.split_with_sizes,
+            "aten::select": self.select,
+            "aten::take": self.take,
+            "aten::where": self.where,
+            "aten::topk": self.topk,
+            "aten::relu": self.relu,
+            "aten::relu_": self.relu,
+            "aten::prelu": self.prelu,
+            "aten::leaky_relu": self.leaky_relu,
+            "aten::leaky_relu_": self.leaky_relu,
+            "aten::elu": self.elu,
+            "aten::elu_": self.elu,
+            "aten::celu": self.celu,
+            "aten::gelu": self.gelu,
+            "aten::selu": self.selu,
+            "aten::log_sigmoid": self.log_sigmoid,
+            "aten::adaptive_avg_pool2d": self.adaptive_avg_pool_2d,
+            "aten::adaptive_max_pool2d": self.adaptive_max_pool_2d,
+            "aten::max_pool2d": self.maxpool_2d,
+            "aten::max_pool2d_with_indices": self.maxpool_2d_with_indices,
+            "aten::max_pool1d": self.maxpool_1d,
+            "aten::max_pool3d": self.maxpool_3d,
+            "aten::hardtanh": self.hardtanh,
+            "aten::hardtanh_": self.hardtanh,
+            "aten::_convolution": self.convolution,
+            "aten::softmax": self.softmax,
+            "aten::threshold": self.threshold,
+            "aten::threshold_": self.threshold,
+            "aten::contiguous": self.contiguous,
+            "aten::batch_norm": self.batch_norm,
+            "aten::instance_norm": self.instance_norm,
+            "aten::layer_norm": self.layer_norm,
+            "aten::group_norm": self.group_norm,
+            "aten::transpose": self.transpose,
+            "aten::transpose_": self.transpose,
+            "aten::t": self.transpose,
+            "aten::flatten": self.flatten,
+            "aten::addmm": self.addmm,
+            "aten::size": self.size,
+            "aten::view": self.view,
+            "aten::reshape": self.reshape,
+            "aten::clone": self.clone,
+            "aten::log_softmax": self.log_softmax,
+            "aten::sigmoid": self.sigmoid,
+            "aten::softplus": self.softplus,
+            "aten::avg_pool2d": self.avg_pool2d,
+            "aten::avg_pool3d": self.avg_pool3d,
+            "aten::dropout": self.dropout,
+            "aten::dropout_": self.dropout,
+            "aten::feature_dropout": self.dropout,
+            "aten::alpha_dropout": self.dropout,
+            "aten::mean": self.mean,
+            "aten::chunk": self.chunk,
+            "aten::matmul": self.matmul,
+            "aten::bmm": self.matmul,
+            "aten::expand": self.expand,
+            "aten::Int": self.int,
+            "prim::NumToTensor": self.numtotensor,
+            "prim::ImplicitTensorToNum": self.tensortonum,
+            "aten::ScalarImplicit": self.tensortonum,
+            "aten::constant_pad_nd": self.make_pad("constant"),
+            "aten::reflection_pad1d": self.make_pad("reflect"),
+            "aten::reflection_pad2d": self.make_pad("reflect"),
+            "aten::replication_pad1d": self.make_pad("edge"),
+            "aten::replication_pad2d": self.make_pad("edge"),
+            "aten::replication_pad3d": self.make_pad("edge"),
+            "aten::permute": self.transpose,
+            "aten::sum": self.make_reduce("sum"),
+            "aten::prod": self.make_reduce("prod"),
+            "aten::argmin": self.make_reduce("argmin"),
+            "aten::argmax": self.make_reduce("argmax"),
+            "aten::norm": self.norm,
+            "aten::frobenius_norm": self.frobenius_norm,
+            "aten::std": self.std,
+            "aten::var": self.variance,
+            "aten::abs": self.make_unary("abs"),
+            "aten::neg": self.make_unary("negative"),
+            "aten::cos": self.make_unary("cos"),
+            "aten::cosh": self.make_unary("cosh"),
+            "aten::sin": self.make_unary("sin"),
+            "aten::sinh": self.make_unary("sinh"),
+            "aten::tan": self.make_unary("tan"),
+            "aten::tanh": self.make_unary("tanh"),
+            "aten::acos": self.make_unary("acos"),
+            "aten::asin": self.make_unary("asin"),
+            "aten::atan": self.make_unary("atan"),
+            "aten::log": self.make_unary("log"),
+            "aten::log2": self.make_unary("log2"),
+            "aten::log10": self.make_unary("log10"),
+            "aten::log1p": self.log1p,
+            "aten::exp": self.make_unary("exp"),
+            "aten::erf": self.make_unary("erf"),
+            "aten::trunc": self.make_unary("trunc"),
+            "aten::sign": self.make_unary("sign"),
+            "aten::sqrt": self.make_unary("sqrt"),
+            "aten::rsqrt": self.make_unary("rsqrt"),
+            "aten::ceil": self.make_unary("ceil"),
+            "aten::floor": self.make_unary("floor"),
+            "aten::round": self.make_unary("round"),
+            "aten::isfinite": self.make_unary("isfinite"),
+            "aten::isinf": self.make_unary("isinf"),
+            "aten::isnan": self.make_unary("isnan"),
+            "aten::clamp": self.clamp,
+            "aten::clamp_": self.clamp,
+            "aten::detach": self.identity,
+            "aten::upsample_bilinear2d": self.make_upsample("bilinear"),
+            "aten::upsample_nearest2d": self.make_upsample("nearest_neighbor"),
+            "aten::upsample_trilinear3d": self.make_upsample3d("trilinear"),
+            "aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"),
+            "aten::expand_as": self.expand_as,
+            "aten::lt": self.make_elemwise("less"),
+            "aten::gt": self.make_elemwise("greater"),
+            "aten::le": self.make_elemwise("less_equal"),
+            "aten::ge": self.make_elemwise("greater_equal"),
+            "aten::ne": self.make_elemwise("not_equal"),
+            "aten::eq": self.make_elemwise("equal"),
+            "aten::logical_not": self.logical_not,
+            "aten::logical_xor": self.logical_xor,
+            "aten::bitwise_not": self.bitwise_not,
+            "aten::bitwise_xor": self.bitwise_xor,
+            "aten::Bool": self.Bool,
+            "aten::Float": self.Float,
+            "aten::adaptive_avg_pool3d": self.adaptive_avg_pool_3d,
+            "aten::adaptive_max_pool3d": self.adaptive_max_pool_3d,
+            "aten::rsub": self.rsub,
+            "aten::embedding": self.embedding,
+            "aten::one_hot": self.one_hot,
+            "aten::mm": self.matmul,
+            "aten::add": self.add,
+            "aten::add_": self.add,
+            "aten::stack": self.stack,
+            "aten::__getitem__": self.list_getitem,
+            "aten::len": self.list_len,
+            "aten::type_as": self.type_as,
+            "aten::gather": self.gather,
+            "aten::index_select": self.select,
+            "aten::index": self.index,
+            "torchvision::nms": self.nms,
+            "aten::logsumexp": self.logsumexp,
+            "torchvision::roi_align": self.roi_align,
+            "aten::unbind": self.unbind,
+            "aten::__and__": self.logical_and,
+            "aten::_shape_as_tensor": self.shape_as_tensor,
+            "aten::nonzero": self.nonzero,
+            "aten::nonzero_numpy": self.nonzero_numpy,
+            "aten::scatter": self.scatter,
+            "aten::scalar_tensor": self.scalar_tensor,
+            "aten::__interpolate": self.interpolate,
+            "aten::IntImplicit": self.identity,
+            "aten::tensor": self.identity,  # used for example in tensor(1.0)
+            "aten::numel": self.numel,
+            "aten::empty": self.empty,
+            "aten::bincount": self.bincount,
+            "aten::scatter_add": self.scatter_add,
+            "aten::__not__": self.logical_not,
+        }
+
+    def update_convert_map(self, custom_map):
+        self.convert_map.update(custom_map)
+
+    def report_missing_conversion(self, op_names):
+        """ Check if all ops in an input graph are supported by TVM """
+        known_ops = [
+            "prim::Constant",
+            "prim::GetAttr",
+            "prim::ListConstruct",
+            "prim::ListUnpack",
+            "prim::TupleConstruct",
+            "prim::TupleUnpack",
+            "prim::RaiseException",
+            "prim::If",
+            "prim::Loop",
+        ]
+        known_ops += list(self.convert_map.keys())
+        known_ops += list(qnn_torch.convert_map.keys())
+
+        missing = [op_name for op_name in op_names if op_name not in known_ops]
+
+        if missing:
+            msg = "The following operators are not implemented: {}".format(missing)
+            raise NotImplementedError(msg)
+
+    def convert_block(self, block, outputs):

Review comment:
       Sorry, that would make recursive conversion difficult, so this is good.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to