comaniac commented on a change in pull request #8633:
URL: https://github.com/apache/tvm/pull/8633#discussion_r681925000



##########
File path: python/tvm/contrib/graph_executor.py
##########
@@ -242,6 +243,19 @@ def get_input(self, index, out=None):
 
         return self._get_input(index)
 
+    def get_input_index(self, name):
+        """Get inputs index via input name

Review comment:
       ```suggestion
           """Get inputs index via input name.
   
   ```

##########
File path: src/runtime/graph_executor/graph_executor.cc
##########
@@ -502,6 +502,14 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name,
       dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
       this->ShareParams(dynamic_cast<const GraphExecutor&>(*module.operator->()), &strm);
     });
+  } else if (name == "get_input_index") {
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      if (String::CanConvertFrom(args[0])) {
+        *rv = this->GetInputIndex(args[0].operator String());
+      } else {
+        *rv = args[0];
+      }

Review comment:
       ```suggestion
         CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string";
         *rv = this->GetInputIndex(args[0].operator String());
   ```
   
   IIUC, this means we simply return the argument unchanged when it is not a string? That seems a bit odd.
   `get_input` uses similar logic because it lets you fetch the input tensor by either name or index, but that does not apply here: this function only maps a name to an index.
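To make the difference concrete, here is a minimal pure-Python model of the two dispatch behaviors. It does not call TVM. `INPUT_NAMES`, `get_input_index_impl`, `dispatch_as_written` and `dispatch_as_suggested` are hypothetical names used only for illustration. The -1 for an unknown name follows the docstring suggestion in this review, and `TypeError` stands in for the error a failed `CHECK` would surface.

```python
# Hypothetical model of the "get_input_index" PackedFunc dispatch; no TVM calls.
# INPUT_NAMES stands in for the graph's input names (an assumption).
INPUT_NAMES = ["data", "weight"]


def get_input_index_impl(name: str) -> int:
    """Return the position of `name` among the inputs, or -1 if it is absent."""
    try:
        return INPUT_NAMES.index(name)
    except ValueError:
        return -1


def dispatch_as_written(arg):
    # Behavior in the PR: a non-string argument is echoed back unchanged,
    # so a caller passing 7 gets 7 back even though no such input exists.
    if isinstance(arg, str):
        return get_input_index_impl(arg)
    return arg


def dispatch_as_suggested(arg):
    # Behavior in the suggestion: reject anything that is not a string,
    # mirroring CHECK(String::CanConvertFrom(args[0])). TypeError is a stand-in.
    if not isinstance(arg, str):
        raise TypeError("Input key is not a string")
    return get_input_index_impl(arg)
```

With the suggested check, a wrong argument type fails loudly instead of producing a value that looks like a valid index.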

##########
File path: python/tvm/contrib/graph_executor.py
##########
@@ -242,6 +243,19 @@ def get_input(self, index, out=None):
 
         return self._get_input(index)
 
+    def get_input_index(self, name):
+        """Get inputs index via input name
+        Parameters
+        ----------
+        name : str
+           The input key name
+        Returns

Review comment:
       ```suggestion
   
           Returns
   ```

##########
File path: python/tvm/contrib/graph_executor.py
##########
@@ -242,6 +243,19 @@ def get_input(self, index, out=None):
 
         return self._get_input(index)
 
+    def get_input_index(self, name):
+        """Get inputs index via input name
+        Parameters
+        ----------
+        name : str
+           The input key name
+        Returns
+        -------
+        index: int
+            The input index

Review comment:
       ```suggestion
               The input index. -1 will be returned if the given input name is not found.
   ```
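For reference, this sketch shows the Python wrapper with all three docstring suggestions applied (trailing period, blank line before `Returns`, and the -1 note), written in numpydoc style. `GraphModuleSketch` and its `index_fn` argument are hypothetical stand-ins for the real `GraphModule` and the packed function it would look up. The sketch does not depend on TVM.

```python
class GraphModuleSketch:
    """Hypothetical stand-in for tvm.contrib.graph_executor.GraphModule.

    `index_fn` plays the role of the packed function the real module would
    look up (an assumption); here it is any callable mapping a name to an int.
    """

    def __init__(self, index_fn):
        self._get_input_index = index_fn

    def get_input_index(self, name):
        """Get inputs index via input name.

        Parameters
        ----------
        name : str
           The input key name

        Returns
        -------
        index: int
            The input index. -1 will be returned if the given input name is not found.
        """
        return self._get_input_index(name)


# Usage with a dictionary-backed lookup that returns -1 for unknown names.
mod = GraphModuleSketch(lambda n: {"data": 0, "weight": 1}.get(n, -1))
```

Here `mod.get_input_index("weight")` gives 1, and an unknown name gives -1, as the suggested docstring describes.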



