mbs-octoml commented on a change in pull request #10388:
URL: https://github.com/apache/tvm/pull/10388#discussion_r821022656



##########
File path: src/runtime/contrib/tensorrt/tensorrt_builder.cc
##########
@@ -85,8 +85,10 @@ void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, 
const JSONGraphNode&
       shape.erase(shape.begin());
     }
     nvinfer1::Dims dims = VectorToTrtDims(shape);
-    ICHECK(TypeMatch(dtypes[i], kDLFloat, 32)) << "Only FP32 inputs are 
supported.";
-    auto input_tensor = network_->addInput(name.c_str(), 
nvinfer1::DataType::kFLOAT, dims);
+    auto tensor_dtype =
+        (dtypes[i].bits == 16) ? nvinfer1::DataType::kHALF : 
nvinfer1::DataType::kFLOAT;

Review comment:
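       I'd suggest an ICHECK that fails if the type is unsupported (i.e. anything other than float32/float16), rather than silently mapping every non-16-bit dtype to kFLOAT.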

##########
File path: src/runtime/contrib/tensorrt/tensorrt_builder.cc
##########
@@ -205,18 +210,16 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
 nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr,
                                                         DLDeviceType 
src_device) {
   ICHECK_EQ(dptr->device.device_type, src_device);
-  ICHECK(static_cast<int>(dptr->dtype.code) == kDLFloat ||
-         static_cast<int>(dptr->dtype.code) == kDLInt);
-  const auto trt_dtype = static_cast<int>(dptr->dtype.code) == kDLFloat
-                             ? nvinfer1::DataType::kFLOAT
-                             : nvinfer1::DataType::kINT32;
+
+  const auto trt_dtype = (static_cast<int>(dptr->dtype.bits) == 16) ? 
nvinfer1::DataType::kHALF

Review comment:
       Another ICHECK is in order here to make sure we're not silently 
generating bad code.

##########
File path: src/runtime/contrib/tensorrt/tensorrt_builder.cc
##########
@@ -250,7 +253,7 @@ void TensorRTBuilder::CleanUp() {
 #endif
   builder_->destroy();
   for (auto weight : trt_weights_) {
-    if (weight.type == nvinfer1::DataType::kFLOAT) {
+    if (static_cast<int>(weight.type) <= 1) {

Review comment:
       Can we avoid hard-coding the enum constants (e.g. compare against nvinfer1::DataType::kFLOAT and nvinfer1::DataType::kHALF by name)?

##########
File path: src/runtime/contrib/tensorrt/tensorrt_builder.cc
##########
@@ -141,15 +143,18 @@ void TensorRTBuilder::AddLayer(int nid, const 
JSONGraphNode& node) {
     }
     params.inputs.push_back(input);
   }
-  ICHECK(converter->variable_input_count || converter->input_types.size() == 
params.inputs.size())
-      << "Op expected a different number of inputs.";
 
   // Convert op to TRT.
   converter->Convert(&params);
 
   // Get outputs.
   node_output_map_[nid] = {};
   for (auto out : params.outputs) {
+    auto out_type = params.inputs.at(1).weight.type == 
params.inputs.at(0).tensor->getType()

Review comment:
       Can you explain this? It seems very specific, yet AddLayer is used for 
all of the supported ops. 

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -202,9 +211,6 @@ def _func_wrapper(expr):
         # ops with dynamic shapes are offloaded to VM
         if check_dynamism(args, op_name):
             return False
-        if any([x.checked_type.dtype != "float32" for x in args]):

Review comment:
       I'm not seeing where the type check (which must now be generalized to 
float32/float16) has gone to. If we remove it altogether, then I think we'll 
either generate bad code or fail at TRT build time, which from the TVM user's 
point of view is runtime and too late. We also need the check in the predicate 
to prevent Collage from exploring invalid candidate kernels.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to