trevor-m commented on a change in pull request #8808:
URL: https://github.com/apache/tvm/pull/8808#discussion_r700349594
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -66,7 +73,17 @@ class TensorRTRuntime : public JSONRuntimeBase {
use_implicit_batch_(true),
max_workspace_size_(size_t(1) << 30),
max_batch_size_(-1),
- multi_engine_mode_(false) {}
+ multi_engine_mode_(false) {
+ const bool use_int8 = dmlc::GetEnv("TVM_TENSORRT_USE_INT8", false);
+ if (use_int8) {
+ const int extract_cali_num =
dmlc::GetEnv("TENSORRT_NUM_CALI_INT8", 0);
+ ICHECK(extract_cali_num != 0);
Review comment:
Let's add a message here:
```
ICHECK(extract_cali_num != 0) << "When using INT8 mode, environment variable
TENSORRT_NUM_CALI_INT8 must also be set to specify the number of inputs which
will be used for calibration.";
```
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -153,9 +172,30 @@ class TensorRTRuntime : public JSONRuntimeBase {
device_buffer.CopyFrom(data_entry_[eid]);
bindings[binding_index] = device_buffer->data;
}
+
+ auto dims = engine->getBindingDimensions(binding_index);
+ int num_elements = 1;
+ for (int i = 0; i < dims.nbDims; ++i) num_elements *= dims.d[i];
+ binding_sizes[binding_index] = num_elements;
}
}
}
+
+ // add batch data to calibrator
+ if (num_calibration_batches_remaining_ > 0) {
+ if (calibrator_ != nullptr) {
+ LOG(INFO) << "Starting adding last " <<
+ num_calibration_batches_remaining_ <<
+ "-th batch data to the calibrator";
+ std::vector<size_t> input_sizes(binding_sizes.begin(),
+ binding_sizes.begin() + num_bindings);
+ calibrator_->AddBatchData(bindings, input_sizes);
Review comment:
Can just use `binding_sizes` directly instead of copying to `input_sizes`
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -225,10 +266,16 @@ class TensorRTRuntime : public JSONRuntimeBase {
TensorRTEngineAndContext& GetOrBuildEngine() {
int batch_size = GetBatchSize();
int compatible_engine_batch_size = -1;
- if (FindCompatibleEngine(batch_size, &compatible_engine_batch_size)) {
+ bool find_engine_flag = FindCompatibleEngine(batch_size,
&compatible_engine_batch_size);
+ const bool use_int8 = (dmlc::GetEnv("TVM_TENSORRT_USE_INT8", 0) != 0);
+ if (find_engine_flag &&
Review comment:
I think we can simplify the condition and also make it clearer this way:
```
const bool int8_calibration_not_used_or_not_complete = calibrator_ ==
nullptr || num_calibration_batches_remaining_ != 0;
if (find_engine_flag && int8_calibration_not_used_or_not_complete) {
return ...
}
```
If there is a compatible engine already built, we can always use it UNLESS
we have finished gathering calibration data and need to build final INT8 engine.
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -240,40 +287,65 @@ class TensorRTRuntime : public JSONRuntimeBase {
TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_,
use_implicit_batch_,
use_fp16, batch_size);
- // Add inputs and constants.
- for (size_t i = 0; i < input_nodes_.size(); ++i) {
- auto nid = input_nodes_[i];
- const auto& node = nodes_[nid];
- std::string name = node.GetOpName();
- if (node.GetOpType() == "input") {
- builder.AddInput(nid, EntryID(nid, 0), node);
- } else {
- ICHECK_EQ(node.GetOpType(), "const");
- uint32_t eid = EntryID(nid, 0);
- builder.AddConstant(nid, data_entry_[eid]);
- }
- }
-
- // Add layers.
- for (size_t nid = 0; nid < nodes_.size(); ++nid) {
- const auto& node = nodes_[nid];
- if (node.GetOpType() != "kernel") continue;
- builder.AddLayer(nid, node);
+ // Build engine.
+ if (trt_engine_cache_.find(std::make_pair(symbol_name_, batch_size)) ==
+ trt_engine_cache_.end()) {
+ BuildEngineFromJson(use_fp16, batch_size);
}
-
- // Add outputs.
- for (size_t i = 0; i < outputs_.size(); ++i) {
- builder.AddOutput(outputs_[i], EntryID(outputs_[i]));
+ if (use_int8) {
+ TensorRTEngineAndContext& engine_and_context =
+ trt_engine_cache_[std::make_pair(symbol_name_,
batch_size)];
+ if (calibrator_ == nullptr) {
+ this->CreateCalibratorIfUsingInt8(engine_and_context);
+ }
+ if (num_calibration_batches_remaining_ == 0) {
+ engine_and_context.context->destroy();
+ engine_and_context.engine->destroy();
+ LOG(INFO) << "rebuild builder using INT8 mode";
+ BuildEngineFromJson(use_fp16, batch_size);
+ calibrator_.reset(nullptr);
+ LOG(INFO) << "finished rebuilding using INT8 mode ... ";
+ }
}
- // Build engine.
- trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] =
builder.BuildEngine();
- DLOG(INFO) << "Finished building TensorRT engine for subgraph " <<
symbol_name_
+ LOG(INFO) << "Finished building TensorRT engine for subgraph " <<
symbol_name_
<< " with batch size " << batch_size;
CacheEngineToDisk();
return trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size));
}
+ void BuildEngineFromJson(bool use_fp16, int batch_size) {
Review comment:
We can move `const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16",
false);` in here so we don't need to pass the flag around.
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -240,40 +287,65 @@ class TensorRTRuntime : public JSONRuntimeBase {
TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_,
use_implicit_batch_,
use_fp16, batch_size);
- // Add inputs and constants.
- for (size_t i = 0; i < input_nodes_.size(); ++i) {
- auto nid = input_nodes_[i];
- const auto& node = nodes_[nid];
- std::string name = node.GetOpName();
- if (node.GetOpType() == "input") {
- builder.AddInput(nid, EntryID(nid, 0), node);
- } else {
- ICHECK_EQ(node.GetOpType(), "const");
- uint32_t eid = EntryID(nid, 0);
- builder.AddConstant(nid, data_entry_[eid]);
- }
- }
-
- // Add layers.
- for (size_t nid = 0; nid < nodes_.size(); ++nid) {
- const auto& node = nodes_[nid];
- if (node.GetOpType() != "kernel") continue;
- builder.AddLayer(nid, node);
+ // Build engine.
+ if (trt_engine_cache_.find(std::make_pair(symbol_name_, batch_size)) ==
Review comment:
From here on, the logic should look like this:
```
if (calibrator_ != nullptr && num_calibration_batches_remaining_ == 0) {
// Calibration complete. Delete fp32 engine and build int8 engine
TensorRTEngineAndContext& engine_and_context =
trt_engine_cache_[std::make_pair(symbol_name_, compatible_batch_size)];
engine_and_context.context->destroy();
engine_and_context.engine->destroy();
BuildEngineFromJson(use_fp16, batch_size);
calibrator_.reset(nullptr);
} else {
// Build new engine
BuildEngineFromJson(use_fp16, batch_size);
if (use_int8) {
this->CreateCalibratorIfUsingInt8(engine_and_context);
}
}
CacheEngineToDisk();
return trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size));
```
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -240,40 +287,65 @@ class TensorRTRuntime : public JSONRuntimeBase {
TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_,
use_implicit_batch_,
Review comment:
We don't need this builder anymore.
##########
File path: src/runtime/contrib/tensorrt/tensorrt_builder.h
##########
@@ -161,6 +164,9 @@ class TensorRTBuilder {
/*! \brief Output names. */
std::vector<std::string> network_output_names_;
+
+ // calibrator pointer
Review comment:
Can you fix the docstring and add some more description? Maybe mention
the different states when calibrator can be nullptr, etc.
##########
File path: src/runtime/contrib/tensorrt/tensorrt_runtime.cc
##########
@@ -369,10 +442,26 @@ class TensorRTRuntime : public JSONRuntimeBase {
return device_buffers_.at(binding_index);
}
+ void CreateCalibratorIfUsingInt8(const TensorRTEngineAndContext&
engine_and_context) {
Review comment:
Probably can rename to just `CreateInt8Calibrator`
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]