masahi commented on a change in pull request #4964: [Torch] Add initial control flow support
URL: https://github.com/apache/incubator-tvm/pull/4964#discussion_r386070089
 
 

 ##########
 File path: python/tvm/relay/frontend/pytorch.py
 ##########
 @@ -955,7 +1025,100 @@ def parse_params(graph, state_dict):
     return params, param_tensors
 
 
-def parse_operators(operators, outputs, output_index_map, ret_name):
+def convert_block(block, outputs, output_index_map):
+    """ Translate Torch "Block", used for prim::If and prim::Loop """
+    ops = _get_operator_nodes(block.nodes())
+    ret_names = _get_input_names(block.returnNode())
+    return convert_operators(ops, outputs, output_index_map, ret_names)
+
+
+def convert_if(if_node, outputs, output_index_map):
+    """ Translate Torch prim::If to Relay If """
+    cond = outputs[output_index_map[if_node.inputsAt(0).debugName()]]
+    blocks = list(if_node.blocks())
+    true_branch = convert_block(blocks[0], outputs, output_index_map)
+    false_branch = convert_block(blocks[1], outputs, output_index_map)
+    assert len(true_branch) == 1 and len(false_branch) == 1
+    return _expr.If(cond, true_branch[0], false_branch[0])
+
+
+def convert_loop(loop_node, outputs, output_index_map):
+    """ Translate Torch prim::Loop to Relay while_loop """
+    def get_input(index):
+        ivalue = loop_node.inputsAt(index)
+        inode = ivalue.node()
+        if inode.kind() == "prim::Constant":
+            return _expr.const(_get_constant(inode))
+        var_name = ivalue.debugName()
+        assert var_name in output_index_map
+        return _wrap_const(outputs[output_index_map[var_name]])
+
+    # Refer to the spec for prim::Loop below
 +    # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
+    # The first input: %max_trip_count
+    # The second input: %initial_condition
 +    # The remaining inputs: loop variables
+    max_loop_count = get_input(0)
+    init_cond = get_input(1)
+    num_loop_var = len(list(loop_node.inputs())) - 2
+    init_vals = [get_input(i + 2) for i in range(num_loop_var)]
+
 +    # A for loop (as opposed to a while loop) always has %initial_condition equal to 1
+    is_for_loop = isinstance(init_cond, _expr.Constant)
 
 Review comment:
   I think the fact that `%i` is not used in the while case is an implementation 
detail of the torch interpreter. Since the iteration-variable increment is not 
explicit in torch, `%i` has no users even in the `for` loop case if we only look 
at the graph. See below. 
   
   ```Python
   class SimpleLoop(torch.nn.Module):
       def forward(self, inp):
           a = inp
           for i in range(10):
               a += float(1)
           return a
   
   ```
   
   ```
   graph(%self : __torch__.SimpleLoop,
         %inp.1 : Tensor):
     %2 : bool = prim::Constant[value=1]() # test.py:9:8
     %3 : int = prim::Constant[value=10]() # test.py:9:23
     %4 : int = prim::Constant[value=1]() # test.py:10:23
     %a : Tensor = prim::Loop(%3, %2, %inp.1) # test.py:9:8
       block0(%i : int, %a.5 : Tensor):
         %8 : float = prim::Constant[value=1.]()
         %a.2 : Tensor = aten::add_(%a.5, %8, %4) # test.py:10:12
         -> (%2, %a.2)
     return (%a)
   
   ``` 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to