MasterJH5574 opened a new pull request #9358:
URL: https://github.com/apache/tvm/pull/9358
This PR updates the TVMScript parser and printer in the following ways:
* support parsing Python lambdas: a lambda can take any number of parameters,
and its body can be an arbitrary expression;
* support parsing `tir.CommReducer`:
  * for a single-group reducer, the syntax is:
    ```python
    tir.comm_reducer(lambda x, y: x + y, [tir.float32(0)])
    ```
  * for a multiple-group reducer (here, an argmax over `(index, value)`
    pairs), the syntax is:
    ```python
    tir.comm_reducer(
        lambda x0, x1, y0, y1: (tir.Select((x1 >= y1), x0, y0),
                                tir.Select((x1 >= y1), x1, y1)),
        [tir.int32(-1), tir.min_value("float32")]
    )
    ```
* support printing `tir.CommReducer` in the syntax above (the old printing
logic for `tir.CommReducer` and the code related to it are removed because
they are outdated).
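To turn the lambda syntax above into a reducer, a parser has to split the lambda into its parts. The sketch below, which uses only Python's standard `ast` module (Python 3.9+ for `ast.unparse`), shows one way this could be done. It assumes, as the examples suggest, that the first half of the lambda parameters are the `lhs` variables and the second half are the `rhs` variables, one pair per identity element. `split_comm_reducer` is a hypothetical helper for illustration, not TVM's actual parser code.

```python
import ast


def split_comm_reducer(src):
    """Split a `tir.comm_reducer(lambda ..., [...])` call into its parts.

    Hypothetical helper for illustration only. Assumes the first half of the
    lambda's positional parameters are lhs variables and the second half are
    rhs variables. Returns (lhs, rhs, results, identities), with expressions
    given as source strings.
    """
    call = ast.parse(src.strip(), mode="eval").body
    if not isinstance(call, ast.Call) or len(call.args) != 2:
        raise ValueError("expected tir.comm_reducer(<lambda>, [<identities>])")
    fn, ident = call.args
    if not isinstance(fn, ast.Lambda) or not isinstance(ident, ast.List):
        raise ValueError("expected a lambda followed by a list of identities")
    params = [a.arg for a in fn.args.args]
    # A tuple body gives one result per group; any other expression is one result.
    body = fn.body.elts if isinstance(fn.body, ast.Tuple) else [fn.body]
    n = len(ident.elts)
    # lhs, rhs, results and identities must all end up with length n.
    if len(params) != 2 * n or len(body) != n:
        raise ValueError("lambda arity does not match the number of identities")
    results = [ast.unparse(e) for e in body]
    identities = [ast.unparse(e) for e in ident.elts]
    return params[:n], params[n:], results, identities


MULTI = """
tir.comm_reducer(
    lambda x0, x1, y0, y1: (tir.Select((x1 >= y1), x0, y0),
                            tir.Select((x1 >= y1), x1, y1)),
    [tir.int32(-1), tir.min_value("float32")]
)
"""

if __name__ == "__main__":
    print(split_comm_reducer("tir.comm_reducer(lambda x, y: x + y, [tir.float32(0)])"))
    # (['x'], ['y'], ['x + y'], ['tir.float32(0)'])
    lhs, rhs, results, _ = split_comm_reducer(MULTI)
    print(lhs, rhs)  # ['x0', 'x1'] ['y0', 'y1']
```

Working on the syntax tree rather than evaluating the lambda means the parser never has to call the lambda with placeholder values; whether TVM's parser takes this route is not stated here.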
In addition, this PR updates the `CommReducer` constructor to
1. check that the inputs `lhs`, `rhs`, `result` and `identity_element` all
have the same length, and
2. convert the data types of the variables in `lhs` and `rhs` so that they
match the data types of the corresponding elements in `identity_element`.
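The two constructor changes can be modelled in plain Python as below. `Var`, `Const` and `make_comm_reducer` are hypothetical stand-ins, not TVM's C++ `CommReducer` constructor or its Python API; the sketch only tracks names and dtype strings. A real implementation working on TIR nodes would presumably also need to rewrite `result` so that it refers to the retyped variables; the sketch sidesteps this because its variables are identified by name.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for a TIR variable: only a name and a dtype."""
    name: str
    dtype: str


@dataclass(frozen=True)
class Const:
    """Hypothetical stand-in for an identity element such as tir.float32(0)."""
    value: object
    dtype: str


def make_comm_reducer(lhs, rhs, result, identity):
    """Sketch of the two constructor updates; not TVM's implementation."""
    n = len(identity)
    # 1. lhs, rhs, result and identity must all have the same length.
    if not (len(lhs) == len(rhs) == len(result) == n):
        raise ValueError(
            f"length mismatch: lhs={len(lhs)}, rhs={len(rhs)}, "
            f"result={len(result)}, identity={n}"
        )
    # 2. Retype each lhs/rhs variable to the dtype of its identity element.
    lhs = [replace(v, dtype=e.dtype) for v, e in zip(lhs, identity)]
    rhs = [replace(v, dtype=e.dtype) for v, e in zip(rhs, identity)]
    return lhs, rhs, list(result), list(identity)


if __name__ == "__main__":
    lhs, rhs, _, _ = make_comm_reducer(
        [Var("x", "int32")], [Var("y", "int32")], ["x + y"], [Const(0.0, "float32")]
    )
    print(lhs, rhs)  # [Var(name='x', dtype='float32')] [Var(name='y', dtype='float32')]
```

Under this scheme, the identities `tir.int32(-1)` and `tir.min_value("float32")` in the multiple-group example would make `x0`/`y0` `int32` and `x1`/`y1` `float32`.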
---
cc @Hzfengsy @junrushao1994 @spectrometerHBH