corehalt edited a comment on issue #4919: [Relay][Pass] Don't consider constants as free vars in MergeComposite
URL: https://github.com/apache/incubator-tvm/pull/4919#issuecomment-591286696
 
 
   > I have a feeling you may be able to get around this problem by using bind_params_by_name. I have to do this when using merge composite because weight parameters are still treated as variables until this pass is called, at which time they are replaced with constants.
   > 
   > Could you try running:
   > 
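   > ```
   > import tvm
   > from tvm import relay
   > # Replace the named weight parameters of main with relay.Constant nodes
   > f = relay.build_module.bind_params_by_name(mod["main"], params)
   > mod = tvm.IRModule()
   > mod["main"] = f
   > ```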
   > 
   > before the merge composite?
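The suggestion above can be pictured with a toy model. This sketch is not TVM code: `Var`, `Const`, `Call` and `bind_params` are hypothetical stand-ins. It assumes only that binding replaces each variable whose name appears in `params` with a constant holding that value. In real TVM, `bind_params_by_name` works on the parameters of a Relay function, and the resulting constant is what prints as `meta[relay.Constant][...]`.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Tuple, Union


# Hypothetical toy IR nodes; these are NOT TVM's relay classes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: object


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[Expr, ...]


Expr = Union[Var, Const, Call]


def bind_params(expr: Expr, params: Dict[str, object]) -> Expr:
    """Replace every Var whose name is a key of params with a Const (toy model)."""
    if isinstance(expr, Var):
        return Const(params[expr.name]) if expr.name in params else expr
    if isinstance(expr, Call):
        return Call(expr.op, tuple(bind_params(a, params) for a in expr.args))
    return expr  # constants are left untouched


# Mirrors the IR below: the conv weight starts out as a free variable.
body = Call("nn.conv2d", (Var("data"), Var("separable_conv_block_13_conv2_weight")))
bound = bind_params(body, {"separable_conv_block_13_conv2_weight": "W"})
print(bound)
# prints: Call(op='nn.conv2d', args=(Var(name='data'), Const(value='W')))
```

The input `data` is not a key of `params`, so it stays a variable. Only the weight becomes a constant.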
   
   @mbaret @soiferj I tried calling `bind_params_by_name()` before merge composite.
   If I don't call `bind_params_by_name()`, I get something like:
   ```
     %155 = nn.relu(%154) /* ty=Tensor[(1, 1024, 7, 7), float32] */;
     %157 = fn (%scompiler_input52: Tensor[(1, 1024, 7, 7), float32], %scompiler_input53: Tensor[(1024, 1024, 1, 1), float32], Compiler="scompiler", ExternalSymbol="scompiler_0", Primitive=1) -> Tensor[(1, 1024, 7, 7), float32] {
       %156 = fn (%x26: Tensor[(1, 1024, 7, 7), float32], %y26: Tensor[(1024, 1024, 1, 1), float32], Primitive=1, Composite="conv") -> Tensor[(1, 1024, 7, 7), float32] {
         nn.conv2d(%x26, %y26, padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]) /* ty=Tensor[(1, 1024, 7, 7), float32] */
       };
       %156(%scompiler_input52, %scompiler_input53) /* ty=Tensor[(1, 1024, 7, 7), float32] */
     };
     %158 = %157(%155, %separable_conv_block_13_conv2_weight) /* ty=Tensor[(1, 1024, 7, 7), float32] */;
   ```
   But if I do call `bind_params_by_name()`, I get something different:
   ```
     %129 = nn.relu(%128) /* ty=Tensor[(1, 1024, 7, 7), float32] */;
     %131 = fn (%scompiler_input52: Tensor[(1, 1024, 7, 7), float32], %scompiler_input53: Tensor[(1024, 1024, 1, 1), float32], Compiler="scompiler", ExternalSymbol="scompiler_0", Primitive=1) -> Tensor[(1, 1024, 7, 7), float32] {
       %130 = fn (%x26: Tensor[(1, 1024, 7, 7), float32], %y26: Tensor[(1024, 1024, 1, 1), float32], Primitive=1, Composite="conv") -> Tensor[(1, 1024, 7, 7), float32] {
         nn.conv2d(%x26, %y26, padding=[0, 0, 0, 0], channels=1024, kernel_size=[1, 1]) /* ty=Tensor[(1, 1024, 7, 7), float32] */
       };
       %130(%scompiler_input52, %scompiler_input53) /* ty=Tensor[(1, 1024, 7, 7), float32] */
     };
     %132 = %131(%129, meta[relay.Constant][52] /* ty=Tensor[(1024, 1024, 1, 1), float32] */ /* ty=Tensor[(1024, 1024, 1, 1), float32] */) /* ty=Tensor[(1, 1024, 7, 7), float32] */;
   ```
   The problem is that `meta[relay.Constant][52]` is later exposed to the codegen as a Relay `VarNode` anyway. Is there any way to access these constants from a custom codegen?
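The second dump shows why the constant still reaches the codegen as a variable. `meta[relay.Constant][52]` is an argument of the call to `%131`. Inside that external function it is therefore the parameter `%scompiler_input53`, which a visitor walking the function body sees as a `VarNode`.

One conceivable workaround is to inline constant call arguments into the partitioned function before codegen, so the visitor meets the constant directly. This is an assumption, not necessarily what #4919 or TVM does. The sketch below uses a hypothetical toy IR: `Func`, `Apply`, `substitute`, `inline_constant_args` and `collect_constants` are made-up names, not TVM's API, and the string value stands in for the weight tensor. In a real C++ codegen built on Relay's `ExprVisitor`, reaching the constant would correspond to handling `ConstantNode`, whose `data` field holds the tensor.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Tuple, Union


# Hypothetical toy IR; NOT TVM's relay classes.
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Const:
    value: object

@dataclass(frozen=True)
class Op:
    name: str
    args: Tuple[Expr, ...]

@dataclass(frozen=True)
class Func:
    params: Tuple[Var, ...]
    body: Expr

@dataclass(frozen=True)
class Apply:
    func: Func
    args: Tuple[Expr, ...]

Expr = Union[Var, Const, Op, Func, Apply]


def substitute(expr: Expr, mapping: Dict[str, Const]) -> Expr:
    """Replace free Vars named in mapping; a Func's own params shadow the mapping."""
    if isinstance(expr, Var):
        return mapping.get(expr.name, expr)
    if isinstance(expr, Op):
        return Op(expr.name, tuple(substitute(a, mapping) for a in expr.args))
    if isinstance(expr, Func):
        bound = {p.name for p in expr.params}
        inner = {k: v for k, v in mapping.items() if k not in bound}
        return Func(expr.params, substitute(expr.body, inner))
    if isinstance(expr, Apply):
        return Apply(substitute(expr.func, mapping),
                     tuple(substitute(a, mapping) for a in expr.args))
    return expr


def inline_constant_args(call: Apply) -> Apply:
    """Move Const arguments into the callee's body and drop the matching params."""
    keep_params: List[Var] = []
    keep_args: List[Expr] = []
    mapping: Dict[str, Const] = {}
    for param, arg in zip(call.func.params, call.args):
        if isinstance(arg, Const):
            mapping[param.name] = arg
        else:
            keep_params.append(param)
            keep_args.append(arg)
    new_func = Func(tuple(keep_params), substitute(call.func.body, mapping))
    return Apply(new_func, tuple(keep_args))


def collect_constants(expr: Expr) -> List[Const]:
    """Every Const a codegen-style visitor would reach when walking expr."""
    if isinstance(expr, Const):
        return [expr]
    if isinstance(expr, Op):
        return [c for a in expr.args for c in collect_constants(a)]
    if isinstance(expr, Func):
        return collect_constants(expr.body)
    if isinstance(expr, Apply):
        return collect_constants(expr.func) + [
            c for a in expr.args for c in collect_constants(a)]
    return []


# Shape of the second dump: %132 = %131(%129, meta[relay.Constant][52])
composite = Func((Var("x26"), Var("y26")), Op("nn.conv2d", (Var("x26"), Var("y26"))))
partitioned = Func((Var("scompiler_input52"), Var("scompiler_input53")),
                   Apply(composite, (Var("scompiler_input52"), Var("scompiler_input53"))))
call = Apply(partitioned, (Var("%129"), Const("meta[relay.Constant][52]")))

print(collect_constants(call.func))     # [] : the weight is only the param %scompiler_input53
inlined = inline_constant_args(call)
print(collect_constants(inlined.func))  # [Const(value='meta[relay.Constant][52]')]
```

Inlining shrinks the call signature, but the function body then carries the weight data. The sketch does not settle whether that trade-off suits a given external compiler.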
