vikingMei commented on issue #9180: fix parameter overwrite in _generate_symbol_function_code
URL: https://github.com/apache/incubator-mxnet/pull/9180#issuecomment-353770914
 
 
   @szha 
   
   Below is how to reproduce this bug:
   ```python
   import mxnet as mx
   
   def network():
       data = mx.sym.Variable(name="data")
       # any extra keyword argument (here: bug_cause) triggers the bug
       topk = mx.sym.topk(data, ret_typ='both', k=3, bug_cause='abc')
   
       return topk
   
   ctx = mx.gpu(0)
   net = network()
   
   ex = net.bind(ctx, args={"data":mx.nd.array([1,2,3,4,5], ctx=ctx)})
   res = ex.forward()
   print(res)
   ```
   
   Here is the runtime error:
   ```
   Traceback (most recent call last):
     File "./test.py", line 16, in <module>
       net = network()
     File "./test.py", line 11, in network
       topk = mx.sym.topk(data, ret_typ='both', k=3, bug_cause='abc')
     File "<string>", line 85, in topk
     File "./mxnet/debug/lib/python3.6/site-packages/mxnet-1.0.1-py3.6.egg/mxnet/_ctypes/symbol.py", line 125, in _symbol_creator
       ctypes.byref(sym_handle)))
     File "./mxnet/debug/lib/python3.6/site-packages/mxnet-1.0.1-py3.6.egg/mxnet/base.py", line 146, in check_call
       raise MXNetError(py_str(_LIB.MXGetLastError()))
   mxnet.base.MXNetError: Invalid Parameter format for k expect int but value='bug_cause', in operator topk(name="", ret_typ="both", k="bug_cause", bug_cause="abc")
   ```
   
   **Passing any extra keyword argument to `mx.sym.topk` triggers this bug.**
   
   Below are the details:
   
   If we print the code generated by `_generate_symbol_function_code`, we see something like this:
   
   ```python
   def topk(data=None, axis=_Null, k=_Null, ret_typ=_Null, is_ascend=_Null, name=None, attr=None, out=None, **kwargs):
       r"""Returns the top *k* elements in an input array along the given axis.
   
       ...    # many lines of documentation omitted
   
       Returns
       -------
       Symbol
           The result symbol.
       """   
       kwargs.update(AttrScope.current.get(attr))
       sym_kwargs = dict()
   
       keys = []
       vals = []
   
       for k, v in kwargs.items():
           if isinstance(v, SymbolBase):
               sym_kwargs[k] = v
           else: 
               keys.append(k)
               vals.append(v)
   
       if data is not None: 
           assert isinstance(data, SymbolBase), \
               "Argument data must be Symbol instances, but got %s"%str(data)
           sym_kwargs['data'] = data
   
       if axis is not _Null:
           keys.append('axis')
           vals.append(axis)
   
       if k is not _Null:
           keys.append('k')
           vals.append(k)
   
       if ret_typ is not _Null:
           keys.append('ret_typ')
           vals.append(ret_typ)
   
       if is_ascend is not _Null:
           keys.append('is_ascend')
           vals.append(is_ascend)
   
       name = NameManager.current.get(name, 'topk')
       return _symbol_creator(17208384, None, sym_kwargs, keys, vals, name) 
   ```
   
   **`k` is a parameter of the generated function, but it is also the loop variable in `for k, v in kwargs.items()`. So whenever `kwargs` is non-empty, the loop rebinds `k` to a key of `kwargs` (the last one iterated), overwriting the value passed by the caller and causing the error above.**
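
The shadowing can be reproduced without MXNet. The sketch below is a hypothetical, stripped-down stand-in for the generated `topk` above: `fake_topk` and the `_Null` sentinel are illustrative names, not MXNet code, and only the `k`/`ret_typ` branches are kept.

```python
_Null = object()  # sentinel meaning "argument not supplied" (stand-in for MXNet's _Null)


def fake_topk(k=_Null, ret_typ=_Null, **kwargs):
    """Mimics the structure of the generated topk(); returns the collected params."""
    keys = []
    vals = []
    # BUG: the loop variable `k` shadows the parameter `k`.
    for k, v in kwargs.items():
        keys.append(k)
        vals.append(v)
    if k is not _Null:
        keys.append('k')
        vals.append(k)
    if ret_typ is not _Null:
        keys.append('ret_typ')
        vals.append(ret_typ)
    return dict(zip(keys, vals))


print(fake_topk(k=3, ret_typ='both'))                   # {'k': 3, 'ret_typ': 'both'}
print(fake_topk(k=3, ret_typ='both', bug_cause='abc'))  # {'bug_cause': 'abc', 'k': 'bug_cause', 'ret_typ': 'both'}
```

With no extra kwargs the loop body never runs and `k` keeps the caller's value; with one extra kwarg, `k` ends up as `'bug_cause'`, matching the `k="bug_cause"` in the error message.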
   
   Likewise, if an operator has parameters named `keys` or `vals`, they will be overwritten by the local lists of the same name (`keys = []`, `vals = []`).
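
One way to avoid both kinds of collision, sketched here as an assumption rather than the exact change made in this PR, is to give every local variable in the generated body a name that no operator parameter uses, for example an underscore prefix (this assumes no operator parameter name starts with `_`). `fixed_topk` and its `keys`/`vals` parameters are hypothetical, included only to show that those names no longer collide.

```python
_Null = object()  # sentinel meaning "argument not supplied" (stand-in for MXNet's _Null)


def fixed_topk(k=_Null, keys=_Null, vals=_Null, **kwargs):
    """Same pattern as the generated code, but all locals are underscore-prefixed."""
    _keys = []
    _vals = []
    # Loop variables no longer shadow the parameters `k`, `keys` or `vals`.
    for _k, _v in kwargs.items():
        _keys.append(_k)
        _vals.append(_v)
    if k is not _Null:
        _keys.append('k')
        _vals.append(k)
    if keys is not _Null:
        _keys.append('keys')
        _vals.append(keys)
    if vals is not _Null:
        _keys.append('vals')
        _vals.append(vals)
    return dict(zip(_keys, _vals))


print(fixed_topk(k=3, bug_cause='abc'))  # {'bug_cause': 'abc', 'k': 3}
```

Renaming the generator's own locals is less invasive than renaming operator parameters, since parameter names are part of the public Python API.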
   
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.