Caenorst commented on a change in pull request #15399: Add unit tests for 
TensorRT integration and fix some bugs
URL: https://github.com/apache/incubator-mxnet/pull/15399#discussion_r311259674
 
 

 ##########
 File path: src/operator/subgraph/tensorrt/tensorrt-inl.h
 ##########
 @@ -109,13 +111,70 @@ class TensorrtSelector : public SubgraphSelector {
 
   bool isTRTCompatible(const nnvm::Node &n) {
     const std::string op_name = n.op()->name;
+    if (op_name == "FullyConnected") {
+      const auto& param = nnvm::get<FullyConnectedParam>(n.attrs.parsed);
+      return !param.no_bias;
+    }
+
     if (op_name == "Pooling") {
-      return (n.attrs.dict.at("pool_type") == "avg" ||
-          n.attrs.dict.at("pool_type") == "max");
+      const auto& param = nnvm::get<PoolingParam>(n.attrs.parsed);
+      if (param.layout.has_value()) {
+        if (param.layout.value() == mshadow::kNHWC) {
+          LOG(INFO) << "Warning: NHWC layout (node: " << n.attrs.name
+                    << ") is not supported by TensorRT";
+          return false;
+        } else if (param.layout.value() == mshadow::kNDHWC) {
+          LOG(INFO) << "Warning: NDHWC layout (node: " << n.attrs.name
+                    << ") is not supported by TensorRT";
+          return false;
+        }
+      }
+      if (param.pooling_convention != pool_enum::kValid && !param.global_pool)
+        return false;
+      if (param.pool_type == pool_enum::kAvgPooling) {
+        if ((!param.global_pool) &&
+            (!param.count_include_pad.has_value() || 
param.count_include_pad.value()))
+          return false;
+        return true;
+      } else if (param.pool_type == pool_enum::kMaxPooling) {
+        return true;
+      } else {
+        return false;
+      }
     }
 
-    if (unconditionalTRTops.count(op_name)) {
-      return true;
+    if (op_name == "Convolution") {
+      const auto& param = nnvm::get<ConvolutionParam>(n.attrs.parsed);
+      if (!param.layout.has_value())
+        return true;
+      switch (param.layout.value()) {
+        case mshadow::kNCHW:
+        case mshadow::kNCW:
+        case mshadow::kNCDHW:
+          return true;
+        case mshadow::kNHWC:
+          LOG(INFO) << "Warning: NHWC layout (node: " << n.attrs.name
+                    << ") is not supported by TensorRT";
+          return false;
+        case mshadow::kNDHWC:
+          LOG(INFO) << "Warning: NDHWC layout (node: " << n.attrs.name
+                    << ") is not supported by TensorRT";
+          return false;
+        default:
+          LOG(INFO) << "Warning: Layout (node: " << n.attrs.name
+                    << ") is unknown (so unsupported by TensorRT)";
+          return false;
+      }
+    }
+
+    if (op_name == "Concat") {
+      const auto& param = nnvm::get<ConcatParam>(n.attrs.parsed);
+      return (param.dim != 0);
+    }
+
+    if (op_name == "Dropout") {
 
 Review comment:
  Dropout has always been treated as an identity function in the MXNet-TensorRT
integration, so I don't see any change in behavior here. As for whether the
identity layer actually performs a copy, I'm not quite sure; here is the
onnx-tensorrt conversion for reference:
https://github.com/onnx/onnx-tensorrt/blob/0ab159579551cabfa05fd66f338357f116e96835/trt_utils.hpp#L169-L180

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to