trevor-m commented on a change in pull request #8808:
URL: https://github.com/apache/tvm/pull/8808#discussion_r695107832
##########
File path: src/runtime/contrib/tensorrt/tensorrt_builder.cc
##########
@@ -238,14 +253,15 @@ void TensorRTBuilder::CleanUp() {
#endif
builder_->destroy();
for (auto weight : trt_weights_) {
- if (weight.type == nvinfer1::DataType::kFLOAT) {
+ if (weight.type == nvinfer1::DataType::kFLOAT)
Review comment:
Please use clang-format to fix formatting
##########
File path: src/runtime/contrib/tensorrt/tensorrt_builder.cc
##########
@@ -40,30 +40,30 @@ namespace contrib {
TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
const std::vector<const DLTensor*>&
data_entry,
size_t max_workspace_size, bool
use_implicit_batch, bool use_fp16,
- int batch_size)
+ int batch_size, nvinfer1::IInt8Calibrator*
calibrator)
: data_entry_(data_entry),
max_workspace_size_(max_workspace_size),
use_implicit_batch_(use_implicit_batch),
use_fp16_(use_fp16),
batch_size_(batch_size) {
// Create TRT builder and network.
builder_ = nvinfer1::createInferBuilder(*logger);
-#if TRT_VERSION_GE(6, 0, 1)
- // Use INetworkV2.
- auto flags =
- 1U <<
static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
- if (use_implicit_batch_) {
- flags = 0U;
- builder_->setMaxBatchSize(batch_size_);
- }
- network_ = builder_->createNetworkV2(flags);
-#else
+ LOG(INFO) << "create a builder_ ";
+ use_int8_ = false;
// Use INetwork with implicit batch.
builder_->setMaxBatchSize(batch_size_);
builder_->setMaxWorkspaceSize(max_workspace_size_);
builder_->setFp16Mode(use_fp16_);
+
+ this->calibrator_ = calibrator;
+ if (calibrator != nullptr) {
+ LOG(INFO) << "calibrator is not null, and setting up int8 mode ... ";
+ set_use_int8();
Review comment:
I don't think set_use_int8 needs to be a method; we can just set use_int8_ to
true here.
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -125,17 +146,24 @@ class TensorRTRuntime : public JSONRuntimeBase {
/*! \brief Run inference using built engine. */
void Run() override {
+
auto& engine_and_context = GetOrBuildEngine();
int batch_size = GetBatchSize();
if (batch_size == 0) return;
auto engine = engine_and_context.engine;
auto context = engine_and_context.context;
- std::vector<void*> bindings(engine->getNbBindings(), nullptr);
+ const int num_bindings = engine->getNbBindings();
+ std::vector<void*> bindings(num_bindings, nullptr);
+ std::vector<size_t> binding_sizes(num_bindings, 0);
// Setup input bindings.
+ const size_t num_inputs = input_nodes_.size();
+ int count_inputs = 0;
Review comment:
We don't need count_inputs; we can just use num_bindings.
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -267,13 +320,68 @@ class TensorRTRuntime : public JSONRuntimeBase {
}
// Build engine.
- trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] =
builder.BuildEngine();
- DLOG(INFO) << "Finished building TensorRT engine for subgraph " <<
symbol_name_
+ // trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] =
builder.BuildEngine();
+ const bool use_int8 = (dmlc::GetEnv("TVM_TENSORRT_USE_INT8", 0) != 0);
+ TensorRTEngineAndContext engine_and_context = builder.BuildEngine();
+ trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] =
engine_and_context;
+ if(use_int8 == true){
+ if(calibrator_ == nullptr){
+ this->CreateCalibratorIfUsingInt8(engine_and_context);
+ }
+
+ if(num_calibration_batches_remaining_ == 0){
+ engine_and_context.context->destroy();
+ engine_and_context.engine->destroy();
+
+ LOG(INFO)<<"rebuild builder using int8 mode";
+ TensorRTBuilder builder2(&logger_, data_entry_, max_workspace_size_,
use_implicit_batch_,
Review comment:
If we are building the final int8 engine, we don't need to build the
first engine in this same function call.
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -125,17 +147,26 @@ class TensorRTRuntime : public JSONRuntimeBase {
/*! \brief Run inference using built engine. */
void Run() override {
+
Review comment:
We have to build the engine first to know the input binding order that
TensorRT assigns to the inputs. It might not match the TVM input signature directly.
This is the same strategy used by TF-TRT.
##########
File path: src/runtime/contrib/tensorrt/tensorrt_builder.cc
##########
@@ -40,30 +40,30 @@ namespace contrib {
TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
const std::vector<const DLTensor*>&
data_entry,
size_t max_workspace_size, bool
use_implicit_batch, bool use_fp16,
- int batch_size)
+ int batch_size, nvinfer1::IInt8Calibrator*
calibrator)
: data_entry_(data_entry),
max_workspace_size_(max_workspace_size),
use_implicit_batch_(use_implicit_batch),
use_fp16_(use_fp16),
batch_size_(batch_size) {
// Create TRT builder and network.
builder_ = nvinfer1::createInferBuilder(*logger);
-#if TRT_VERSION_GE(6, 0, 1)
- // Use INetworkV2.
- auto flags =
- 1U <<
static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
- if (use_implicit_batch_) {
- flags = 0U;
- builder_->setMaxBatchSize(batch_size_);
- }
- network_ = builder_->createNetworkV2(flags);
-#else
+ LOG(INFO) << "create a builder_ ";
+ use_int8_ = false;
Review comment:
Please keep the `#if TRT_VERSION_GE(6, 0, 1)` macros. You will need to
use a different API to enable int8 depending on the TRT version
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -153,9 +181,31 @@ class TensorRTRuntime : public JSONRuntimeBase {
device_buffer.CopyFrom(data_entry_[eid]);
bindings[binding_index] = device_buffer->data;
}
+
+ auto dims = engine->getBindingDimensions(binding_index);
+ int num_elements = 1;
+ for (int i = 0; i < dims.nbDims; ++i) num_elements *= dims.d[i];
+ binding_sizes[binding_index] = num_elements;
+
}
}
}
+
+ // add batch data to calibrator
+ if(num_calibration_batches_remaining_ != 0){
+ if(calibrator_ != nullptr){
Review comment:
Can calibrator be null and num_calibration_batches be > 0? I think this
if can be replaced by
`ICHECK(calibrator_ != nullptr)`
##########
File path: src/runtime/contrib/tensorrt/tensorrt_builder.h
##########
@@ -153,6 +159,8 @@ class TensorRTBuilder {
/*! \brief Whether to automatically convert model to 16-bit floating point
precision. */
bool use_fp16_;
+ bool use_int8_;
Review comment:
Add docstring comments to new class variables
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -153,9 +181,31 @@ class TensorRTRuntime : public JSONRuntimeBase {
device_buffer.CopyFrom(data_entry_[eid]);
bindings[binding_index] = device_buffer->data;
}
+
+ auto dims = engine->getBindingDimensions(binding_index);
+ int num_elements = 1;
+ for (int i = 0; i < dims.nbDims; ++i) num_elements *= dims.d[i];
+ binding_sizes[binding_index] = num_elements;
+
}
}
}
+
+ // add batch data to calibrator
+ if(num_calibration_batches_remaining_ != 0){
Review comment:
I prefer checking `num_calibration_batches_remaining_ > 0`
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -153,9 +181,31 @@ class TensorRTRuntime : public JSONRuntimeBase {
device_buffer.CopyFrom(data_entry_[eid]);
bindings[binding_index] = device_buffer->data;
}
+
+ auto dims = engine->getBindingDimensions(binding_index);
+ int num_elements = 1;
+ for (int i = 0; i < dims.nbDims; ++i) num_elements *= dims.d[i];
+ binding_sizes[binding_index] = num_elements;
+
}
}
}
+
+ // add batch data to calibrator
+ if(num_calibration_batches_remaining_ != 0){
+ if(calibrator_ != nullptr){
+ LOG(INFO) << "starting adding last " <<
num_calibration_batches_remaining_ << "-th batch data to calibrator";
+ std::vector<void*> input_bindings(bindings.begin(),
+ bindings.begin() + count_inputs);
+ std::vector<size_t> input_sizes(binding_sizes.begin(),
+ binding_sizes.begin() + count_inputs);
+ calibrator_->AddBatchData(input_bindings, input_sizes);
Review comment:
Just pass `bindings, binding_sizes` directly to AddBatchData
##########
File path: src/runtime/contrib/tensorrt/tensorrt_builder.cc
##########
@@ -40,30 +40,30 @@ namespace contrib {
TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
const std::vector<const DLTensor*>&
data_entry,
size_t max_workspace_size, bool
use_implicit_batch, bool use_fp16,
- int batch_size)
+ int batch_size, nvinfer1::IInt8Calibrator*
calibrator)
: data_entry_(data_entry),
max_workspace_size_(max_workspace_size),
use_implicit_batch_(use_implicit_batch),
use_fp16_(use_fp16),
batch_size_(batch_size) {
// Create TRT builder and network.
builder_ = nvinfer1::createInferBuilder(*logger);
-#if TRT_VERSION_GE(6, 0, 1)
Review comment:
Yes, we need to keep these version macros. TRT6+ has a different API
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -267,13 +320,68 @@ class TensorRTRuntime : public JSONRuntimeBase {
}
// Build engine.
- trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] =
builder.BuildEngine();
- DLOG(INFO) << "Finished building TensorRT engine for subgraph " <<
symbol_name_
+ // trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] =
builder.BuildEngine();
+ const bool use_int8 = (dmlc::GetEnv("TVM_TENSORRT_USE_INT8", 0) != 0);
+ TensorRTEngineAndContext engine_and_context = builder.BuildEngine();
+ trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] =
engine_and_context;
+ if(use_int8 == true){
+ if(calibrator_ == nullptr){
+ this->CreateCalibratorIfUsingInt8(engine_and_context);
+ }
+
+ if(num_calibration_batches_remaining_ == 0){
+ engine_and_context.context->destroy();
+ engine_and_context.engine->destroy();
+
+ LOG(INFO)<<"rebuild builder using int8 mode";
+ TensorRTBuilder builder2(&logger_, data_entry_, max_workspace_size_,
use_implicit_batch_,
+ use_fp16, batch_size, calibrator_.get());
+ set_up_input_output(builder2);
+ TensorRTEngineAndContext new_engine_and_context =
builder2.BuildEngine();
+ trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] =
new_engine_and_context;
+ calibrator_.reset(nullptr);
+ LOG(INFO) <<"finished rebuilding using int8 mode ... ";
+ }
+
+ }
+
+ LOG(INFO) << "Finished building TensorRT engine for subgraph " <<
symbol_name_
<< " with batch size " << batch_size;
CacheEngineToDisk();
return trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size));
}
+
+ void set_up_input_output(TensorRTBuilder& builder){
Review comment:
Let's move all of the build-engine functionality here and rename it to
BuildEngineFromJSON(). Currently this code duplicates what is already in
`GetOrBuildEngine`. We can pass a flag if we want `BuildEngineFromJSON()` to
consume the calibrator and build the int8 engine.
Then we can just call `BuildEngineFromJSON()` from `GetOrBuildEngine` when
we need to build a new engine.
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -66,7 +78,16 @@ class TensorRTRuntime : public JSONRuntimeBase {
use_implicit_batch_(true),
max_workspace_size_(size_t(1) << 30),
max_batch_size_(-1),
- multi_engine_mode_(false) {}
+ multi_engine_mode_(false) {
+ const bool use_int8 = dmlc::GetEnv("TVM_TENSORRT_USE_INT8", false);
Review comment:
This will be too restrictive, because the precision and the number of
calibration images could not be configured at runtime.
##########
File path: src/runtime/contrib/tensorrt/tensorrt_builder.cc
##########
@@ -40,30 +40,30 @@ namespace contrib {
TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
const std::vector<const DLTensor*>&
data_entry,
size_t max_workspace_size, bool
use_implicit_batch, bool use_fp16,
- int batch_size)
+ int batch_size, nvinfer1::IInt8Calibrator*
calibrator)
: data_entry_(data_entry),
max_workspace_size_(max_workspace_size),
use_implicit_batch_(use_implicit_batch),
use_fp16_(use_fp16),
batch_size_(batch_size) {
// Create TRT builder and network.
builder_ = nvinfer1::createInferBuilder(*logger);
-#if TRT_VERSION_GE(6, 0, 1)
- // Use INetworkV2.
- auto flags =
- 1U <<
static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
- if (use_implicit_batch_) {
- flags = 0U;
- builder_->setMaxBatchSize(batch_size_);
- }
- network_ = builder_->createNetworkV2(flags);
-#else
+ LOG(INFO) << "create a builder_ ";
Review comment:
Could you please remove all of the debug logs?
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -48,6 +54,12 @@ struct PairHash {
}
};
+std::string getEnvVar( std::string const & key )
Review comment:
Don't think this is needed since we have `dmlc::GetEnv`
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -66,7 +78,16 @@ class TensorRTRuntime : public JSONRuntimeBase {
use_implicit_batch_(true),
max_workspace_size_(size_t(1) << 30),
max_batch_size_(-1),
- multi_engine_mode_(false) {}
+ multi_engine_mode_(false) {
+ const bool use_int8 = dmlc::GetEnv("TVM_TENSORRT_USE_INT8", false);
Review comment:
Hi @FrozenGene
We already have environment variables for runtime options such as
`TVM_TENSORRT_USE_FP16`. This calibration is done purely at runtime, so it is
too restrictive to require the user to specify this during compilation. I think
an environment variable is the only clean way to add new runtime functionality.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]