Lunderberg commented on a change in pull request #8161:
URL: https://github.com/apache/tvm/pull/8161#discussion_r641832303



##########
File path: python/tvm/relay/backend/vm.py
##########
@@ -198,20 +198,29 @@ def _update_target(self, target):
         target = target if target else tvm.target.Target.current()
         if target is None:
             raise ValueError("Target is not set in env or passed as argument.")
-        tgts = {}
-        if isinstance(target, (str, tvm.target.Target)):
-            dev_type = tvm.tir.IntImm("int32", tvm.nd.device(str(target)).device_type)
-            tgts[dev_type] = tvm.target.Target(target)
-        elif isinstance(target, dict):
-            for dev, tgt in target.items():
-                dev_type = tvm.tir.IntImm("int32", tvm.nd.device(dev).device_type)
-                tgts[dev_type] = tvm.target.Target(tgt)
-        else:
+
+        elif isinstance(target, str):
+            target = {target: target}
+

Review comment:
       I think this line should be kept in order to preserve the previous
behavior. The `target` variable comes from either `VMCompiler.lower` or
`VMCompiler.optimize`, both of which document that the `target` parameter can be
a string, a `tvm.target.Target`, or a dictionary. If this line were removed and
a `str` argument were passed in, a type error would occur later, when `target`
is accessed as a dictionary.

   The modified version moves all of the input handling into a single location
and produces a dictionary regardless of the input type, so that the remainder of
the function doesn't need to repeat the logic for filling in the output `tgts`
dictionary.
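   A minimal, dependency-free sketch of this pattern, assuming a hypothetical
`DEVICE_TYPES` table in place of `tvm.nd.device(...).device_type`.
`normalize_targets` and `build_tgts` are illustrative names, not TVM functions,
and the `tvm.target.Target` branch is left out here. The point is only that the
`str` branch turns a string into a dictionary before the single loop that fills
`tgts`.

```python
# Hypothetical stand-in for `tvm.nd.device(name).device_type`; the values
# are illustrative (CPU = 1, CUDA = 2) and only cover two names.
DEVICE_TYPES = {"llvm": 1, "cuda": 2}


def normalize_targets(target):
    """Convert every accepted `target` form into a {device: target} dict."""
    if isinstance(target, dict):
        return dict(target)
    if isinstance(target, str):
        # Without this branch, a plain string would reach the loop in
        # build_tgts and fail there, because `str` has no `.items()` method.
        return {target: target}
    # The Target-object branch is intentionally omitted from this sketch.
    raise TypeError(f"unsupported target type: {type(target).__name__}")


def build_tgts(target):
    """Build the output mapping in one place, from the normalized dict."""
    tgts = {}
    for dev, tgt in normalize_targets(target).items():
        tgts[DEVICE_TYPES[dev]] = tgt
    return tgts


print(build_tgts("llvm"))                            # {1: 'llvm'}
print(build_tgts({"cuda": "cuda", "llvm": "llvm"}))  # {2: 'cuda', 1: 'llvm'}
```

   Keeping the normalization separate from the loop means that supporting a new
input form only takes one more branch, not another copy of the `tgts` logic.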

##########
File path: python/tvm/relay/backend/vm.py
##########
@@ -198,20 +198,29 @@ def _update_target(self, target):
         target = target if target else tvm.target.Target.current()
         if target is None:
             raise ValueError("Target is not set in env or passed as argument.")
-        tgts = {}
-        if isinstance(target, (str, tvm.target.Target)):
-            dev_type = tvm.tir.IntImm("int32", tvm.nd.device(str(target)).device_type)
-            tgts[dev_type] = tvm.target.Target(target)
-        elif isinstance(target, dict):
-            for dev, tgt in target.items():
-                dev_type = tvm.tir.IntImm("int32", tvm.nd.device(dev).device_type)
-                tgts[dev_type] = tvm.target.Target(tgt)
-        else:
+
+        elif isinstance(target, str):
+            target = {target: target}
+
+        elif isinstance(target, tvm.target.Target):
+            target = {target.kind.name: target}
+

Review comment:
       Same comment as for the previous one, but for the case where a
`tvm.target.Target` is passed to `VMCompiler.lower`.
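   The `tvm.target.Target` branch can be sketched in the same spirit, assuming a
hypothetical `FakeTarget` class that models only the `kind.name` attribute used
in the diff (the real `tvm.target.Target` is not imported here, and
`normalize_targets` is an illustrative name). The object is stored under its
kind name, so later code can treat it like any other dictionary entry.

```python
from types import SimpleNamespace


class FakeTarget:
    """Hypothetical stand-in for `tvm.target.Target`.

    Only the `kind.name` attribute used by the reviewed diff is modeled.
    """

    def __init__(self, kind_name):
        self.kind = SimpleNamespace(name=kind_name)


def normalize_targets(target):
    """Convert a str, Target-like object, or dict into a dict."""
    if isinstance(target, dict):
        return dict(target)
    if isinstance(target, str):
        return {target: target}
    if isinstance(target, FakeTarget):
        # Key the object by its kind name, mirroring
        # `{target.kind.name: target}` in the diff; the object itself
        # is kept as the value.
        return {target.kind.name: target}
    raise TypeError(f"unsupported target type: {type(target).__name__}")


cuda = FakeTarget("cuda")
normalized = normalize_targets(cuda)
print(list(normalized))            # ['cuda']
print(normalized["cuda"] is cuda)  # True
```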




-- 
This is an automated message from the Apache Git Service.