guberti opened a new issue, #13330: URL: https://github.com/apache/tvm/issues/13330
When using `te.extern_primfunc`, we need to ensure that the tensors we pass in have the same shapes as the buffers they are loaded into. This check is done in `python/tvm/te/operation.py`, which has [the following code](https://github.com/apache/tvm/blob/0e395c389ccd173cf6c1f254b47a81e715762626/python/tvm/te/operation.py#L364):

```python
def extern_primfunc(input_tensors: List[_tensor.Tensor], primfunc: tvm.tir.PrimFunc, **kwargs):
    access_map = {
        k: tuple(v) for k, v in tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc).items()
    }
    in_buffers = [buf for buf, access in access_map.items() if len(access[0])]
    input_buffers = in_buffers
    # [some lines omitted for brevity]
    for obuf in out_buffers:
        ...  # loop body omitted for brevity

    assert len(input_buffers) == len(input_tensors), (
        "The number of provided input input_tensors does not match the number of ",
        "input buffers in the primfunc",
    )
    for tensor, buffer in zip(input_tensors, input_buffers):
        assert len(tensor.shape) == len(buffer.shape)
        for d1, d2 in zip(tensor.shape, buffer.shape):
            assert d1 == d2, (
                "The input input_tensors provided do not match the input buffers in the ",
                "primfunc. Please check that the order of input te.Input_Tensors and the ",
                "order of the primfunc variables in the params list agree.",
            )
```

Specifically, because this code pairs tensors and buffers with `zip`, it requires that the order of the tensors in the list `input_tensors` matches the order of the buffers in `input_buffers`, which is built by iterating over the **dictionary** `access_map`. It's easy to see how this might pose a problem.

However, the issue is slightly more nuanced. Since Python 3.7, [dictionaries preserve insertion order](https://www.blog.pythonlibrary.org/2018/02/27/python-3-7-dictionaries-now-ordered) (CPython 3.6 already did so as an implementation detail). Furthermore, [there are unit tests](https://github.com/apache/tvm/blob/main/tests/python/unittest/test_tir_te_extern_primfunc.py) for the code in question, so why doesn't it always break? While `extern_primfunc` probably shouldn't store buffers in a dictionary when their order matters, that isn't where the bug is.
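To make the pitfall concrete, here is a minimal pure-Python sketch that does not import TVM. The `Tensor` and `Buffer` namedtuples and the `shapes_match` helper are hypothetical stand-ins; only the positional `zip` pairing mirrors the code above. Swapping the buffer order, as a container without an ordering guarantee might, is enough to make the shape check fail.

```python
from collections import namedtuple

# Hypothetical stand-ins for te.Tensor and tir.Buffer: only the shape matters here.
Tensor = namedtuple("Tensor", ["name", "shape"])
Buffer = namedtuple("Buffer", ["name", "shape"])


def shapes_match(tensors, buffers):
    """Mimic the zip-based check in extern_primfunc: pair purely by position."""
    if len(tensors) != len(buffers):
        return False
    for tensor, buffer in zip(tensors, buffers):
        # Comparing tuples checks both the rank and every dimension.
        if tuple(tensor.shape) != tuple(buffer.shape):
            return False
    return True


a_buf = Buffer("A", (16,))
b_buf = Buffer("B", (4, 4))
tensors = [Tensor("a", (16,)), Tensor("b", (4, 4))]

# Buffers in the PrimFunc's parameter order: the positional pairing is correct.
declared_order = [a_buf, b_buf]
# The same buffers as an unordered container might yield them (simulated by a swap).
unordered = [b_buf, a_buf]

print(shapes_match(tensors, declared_order))  # True
print(shapes_match(tensors, unordered))       # False
```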
Instead, the problem is that `tvm.arith._ffi_api.DomainTouchedAccessMap` returns a `tvm.ir.container.Map`. With a small number of elements (e.g. 3), `tvm.ir.container.Map` happens to preserve insertion order, but it does not guarantee this, which is why the unit tests don't catch the bug. With a larger number of elements (e.g. 5), the order is no longer preserved and the shape check fails:

```
E   tvm._ffi.base.TVMError: Traceback (most recent call last):
E     22: TVMFuncCall
[some lines omitted for brevity]
E     File "/home/guberti/tvm/python/tvm/te/operation.py", line 440, in extern_primfunc
E       assert len(tensor.shape) == len(buffer.shape)
E   TVMError: AssertionError
```

### Suggested fix

As far as I can tell, `extern_primfunc` needs to receive the buffers in the same order as the tensors. Thus, `DomainTouchedAccessMap` should be changed to return a list of tuples instead of a map.

I'm blocked by this bug, so I'm just gonna fix it.

* tir:arith

CC @areusch, @Lunderberg
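A minimal sketch of the idea behind the suggested fix, assuming a simplified model where buffers are plain strings: take the order from the PrimFunc's parameter list and look each buffer up in the unordered access map, producing a list of `(buffer, access)` tuples. This is a Python illustration only; `ordered_accesses` and its arguments are hypothetical and are not the actual change to `DomainTouchedAccessMap`, which is implemented in C++ and exposed through `_ffi_api`. The `params` and `buffer_map` inputs loosely mirror `PrimFunc.params` and `PrimFunc.buffer_map`.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical simplified model: buffers are plain strings, and each access
# entry is (read_regions, write_regions), matching how extern_primfunc treats
# access[0] as reads and access[1] as writes.
Access = Tuple[List[str], List[str]]


def ordered_accesses(
    params: Sequence[str],
    buffer_map: Dict[str, str],
    access_map: Dict[str, Access],
) -> List[Tuple[str, Access]]:
    """Return (buffer, access) pairs in PrimFunc parameter order.

    The order comes from `params`, never from iterating `access_map`, so the
    result does not depend on how the map happens to order its keys.
    """
    result = []
    for param in params:
        buf = buffer_map.get(param)
        # Skip params without a buffer, and buffers that are never touched.
        if buf is not None and buf in access_map:
            result.append((buf, access_map[buf]))
    return result


# Usage: the access map is deliberately built in a scrambled order.
params = ["a", "b", "c"]
buffer_map = {"a": "A", "b": "B", "c": "C"}
access_map = {"C": ([], ["C[0:16]"]), "A": (["A[0:16]"], []), "B": (["B[0:16]"], [])}

ordered = ordered_accesses(params, buffer_map, access_map)
in_buffers = [buf for buf, access in ordered if len(access[0])]
out_buffers = [buf for buf, access in ordered if len(access[1])]
print(in_buffers, out_buffers)  # ['A', 'B'] ['C']
```

Because the caller consumes an ordered list rather than a map, the `zip` in `extern_primfunc` would pair tensors with buffers in parameter order no matter how many buffers the PrimFunc touches.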
