huajsj commented on a change in pull request #10543:
URL: https://github.com/apache/tvm/pull/10543#discussion_r825381354
##########
File path: src/runtime/pipeline/pipeline_struct.h
##########
@@ -567,50 +686,82 @@ class BackendRuntime {
// Breaking the loop when the notification is in the exit state.
if ((exit_notify = notify->second->GetExitState())) break;
// Getting the source which sends this notification.
- auto notify_source = notify->second->GetNotifySource();
+ auto target_input_interface_index = notify->first;
+ auto source_interface_id = notify->second->GetNotifySource();
// Loading the binding data.
- while (!this->LoadBindingData(notify->first, notify_source.runtime_idx,
- notify_source.runtime_output_idx)) {
+ while (!this->LoadBindingData(target_input_interface_index)) {
// Waiting for the notification.
if (!notify->second->Wait()) {
VLOG(1) << "runtime index:" << runtime_idx_ << " receive exit
notify.";
exit_notify = true;
break;
}
- // TODO(huajsj): removing this 'break' after finishing the
'LoadBindingData'.
- break;
}
- VLOG(1) << "runtime_index.input_index:" << runtime_idx_ << "." <<
notify->first
- << "from runtime_index.output_index:" <<
notify_source.runtime_idx << "."
- << notify_source.runtime_output_idx;
+ VLOG(1) << "Data forwardin from runtime(" <<
source_interface_id.runtime_idx << ").output("
+ << source_interface_id.runtime_interface_idx << ") to runtime("
<< runtime_idx_
+ << ").input(" << target_input_interface_index << ")";
notifys.erase(notify);
}
return exit_notify;
}
/*!
* \brief Loading the binding data.
- * \param parent_idx The index of runtime which forwards data to current
runtime.
- * \param parent_output_idx The index of output where the forwarding data is
coming from.
- * \param input_idx The index of input where the data will be forwarding to.
+ * \param input_index The index of the interface which will receive the
forwarding data.
* \return Returning 'true' when data is loaded successfully, otherwise
returning 'false'.
*/
- bool LoadBindingData(int parent_idx, int parent_output_idx, int input_idx) {
- // TODO(huajsj): Loading data.
- return false;
+ bool LoadBindingData(int input_index) {
+ if (input_queue_.find(input_index) == input_queue_.end()) {
+ LOG(FATAL) << "Not finding the associated input queue of the input " <<
input_index << " !";
+ return false;
+ }
+ auto queue = input_queue_[input_index];
+ QueueData data;
+ // TODO(huajsj): Doing the 'SetInput' inside the poll function to avoid
one time data copy.
+ if (!queue->Poll<QueueData>(&data)) {
+ return false;
+ }
+ SetInput(input_index, data.GetDLData());
+ return true;
}
/*!
* \brief Forwarding the output data into the child runtimes.
+ * \return bool Return false when the "PipelineIsStop" function returns true
or this function
+ * reaches some errors. Otherwise, return true.
*/
- void ForwardingOutputDataToChildren(void) {
+ bool ForwardingOutputDataToChildren(void) {
for (auto child : children_) {
- // TODO(huajsj): Getting the output data from the current runtime in
order to forward
- // data to the child.
-
+ auto output_idx = child.first;
+ if (output_queue_.find(output_idx) == output_queue_.end()) {
+ LOG(FATAL) << "Not find the forwarding queue map for output(" <<
output_idx << ")!";
+ return false;
+ }
+ NDArray output = GetOutput(output_idx);
+ auto forward_queue_map = output_queue_[output_idx];
// Notifying the 'children runtime' that the forwarding data are ready.
for (auto module_pair : child.second) {
- module_pair.first->ParentNotify(module_pair.second);
+ auto child_runtime = module_pair.first;
+ auto child_runtime_index = child_runtime->GetModuleIndex();
+ auto child_input_index = module_pair.second;
+ auto queue_id = GenerateQueueID(child_runtime_index,
child_input_index, INPUT);
+ if (forward_queue_map.find(queue_id) == forward_queue_map.end()) {
+ LOG(FATAL) << "Not find the associated queue of the runtime(" <<
child_runtime_index
+ << ").input(" << child_input_index << ") which is
connected with runtime("
+ << runtime_idx_ << ").output(" << output_idx << ")";
+ }
+ auto forward_queue = forward_queue_map[queue_id];
+ // If the queue is full, keep try until the push get success or the
pipeline run into
+ // a STOP state.
+ while (!forward_queue->Push<NDArray>(output)) {
+ if (PipelineIsStop()) {
+ LOG(INFO) << "The forwarding queue pushing is stopped due to the
pipeline state "
+ << "is changed into stop.";
Review comment:
Fixed. Updated log message:
"The forwarding process is stopped after the pipeline status is changed into
stop."
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]