Issue 151584
Summary MLIR Enum Python bindings infinite recursion
Labels mlir
Assignees
Reporter nsmithtt
This bug arises specifically when `I32BitEnumAttrCaseGroup` is used together with the Python bindings.

## Example repro

Example TableGen:
```td
def TTCore_ChipCapabilityPCIE : I32BitEnumAttrCaseBit<"PCIE", 0, "pcie">;
def TTCore_ChipCapabilityHostMMIO : I32BitEnumAttrCaseBit<"HostMMIO", 1, "host_mmio">;
def TTCore_ChipCapabilityAll : I32BitEnumAttrCaseGroup<"All",
    [TTCore_ChipCapabilityPCIE, TTCore_ChipCapabilityHostMMIO], "all">;

def TTCore_ChipCapability : I32BitEnumAttr<"ChipCapability", "TT Chip Capabilities",
    [
      TTCore_ChipCapabilityPCIE,
      TTCore_ChipCapabilityHostMMIO,
      TTCore_ChipCapabilityAll,
    ]> {
  let genSpecializedAttr = 1;
  let cppNamespace = "::mlir::tt::ttcore";
}
```

This generates the following Python binding:
```python
class ChipCapability(IntFlag):
    """TT Chip Capabilities"""

    PCIE = 1
    HostMMIO = 2
    All = 3

    def __iter__(self):
        return iter([case for case in type(self) if (self & case) is case])
    def __len__(self):
        return bin(self).count("1")

    def __str__(self):
        if len(self) > 1:
            return "|".join(map(str, self))
        if self is ChipCapability.PCIE:
            return "pcie"
        if self is ChipCapability.HostMMIO:
            return "host_mmio"
        if self is ChipCapability.All:
            return "all"
        raise ValueError("Unknown ChipCapability enum entry.")
```

The following sequence results in infinite recursion:
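1. `__str__` is called on `ChipCapability.All`. Since `len(self) > 1`, the first branch is taken and `str` is mapped over each element yielded by iterating `self`.
2. Iteration enters `__iter__`, which walks the members of the `IntFlag` enum class. It is especially useful to add `print(list(case for case in type(self)))` at this point, which shows: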
```
[<ChipCapability.PCIE: 1>, <ChipCapability.HostMMIO: 2>, <ChipCapability.All: 3>]
```
3. Because `<ChipCapability.All: 3>` is itself a member of the enum class, it satisfies `(self & case) is case` when `self` is `All`, so `__iter__` yields it. `__str__` is then called on `All` again, which recurses indefinitely.
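
The steps above can be reproduced outside of MLIR with a standalone sketch. This is not the generated binding itself: it iterates `type(self).__members__.values()` rather than `type(self)`, an assumption made here so that the `All` group case is visited on any Python version (on Python 3.11 and later, iterating a flag class directly appears to skip multi-bit members, which would hide the problem). The `str_recurses` helper is hypothetical and exists only for this illustration.

```python
from enum import IntFlag


class ChipCapability(IntFlag):
    """Standalone copy of the generated enum, for illustration only."""

    PCIE = 1
    HostMMIO = 2
    All = 3  # group case: PCIE | HostMMIO

    def __iter__(self):
        # Assumption: walk __members__ so the multi-bit `All` case is always
        # visited, matching the list printed in step 2.
        members = type(self).__members__.values()
        return iter([case for case in members if (self & case) is case])

    def __len__(self):
        return bin(self).count("1")

    def __str__(self):
        if len(self) > 1:
            return "|".join(map(str, self))
        if self is ChipCapability.PCIE:
            return "pcie"
        if self is ChipCapability.HostMMIO:
            return "host_mmio"
        if self is ChipCapability.All:
            return "all"
        raise ValueError("Unknown ChipCapability enum entry.")


def str_recurses(value):
    """Hypothetical helper: True if str(value) hits the recursion limit."""
    try:
        str(value)
    except RecursionError:
        return True
    return False


# `All` appears in its own iteration, so str(All) calls str(All) again.
print(ChipCapability.All in list(ChipCapability.All))
print(str_recurses(ChipCapability.All))
# Single-bit cases are unaffected because they skip the first branch.
print(str(ChipCapability.PCIE))
```

Only the group case triggers the recursion. The single-bit cases have `len(self) == 1` and fall through to the direct comparisons.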

## Proposed Fix

The proposed fix is to filter the iteration so that a case is yielded only when it is not `self`:
```python
    def __iter__(self):
        return iter([case for case in type(self) if (self & case) is case and self is not case])
```
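
A standalone sketch of the enum with this filter applied can be used to check that `__str__` terminates. It is not the generated binding: it iterates `type(self).__members__.values()` rather than `type(self)`, an assumption made here so that the `All` group case is considered on any Python version.

```python
from enum import IntFlag


class ChipCapability(IntFlag):
    """Standalone copy of the enum with the proposed filter, for illustration."""

    PCIE = 1
    HostMMIO = 2
    All = 3  # group case: PCIE | HostMMIO

    def __iter__(self):
        # Assumption: walk __members__ so the multi-bit `All` case is
        # considered on every Python version, then drop `self` as proposed.
        members = type(self).__members__.values()
        return iter([case for case in members
                     if (self & case) is case and self is not case])

    def __len__(self):
        return bin(self).count("1")

    def __str__(self):
        if len(self) > 1:
            return "|".join(map(str, self))
        if self is ChipCapability.PCIE:
            return "pcie"
        if self is ChipCapability.HostMMIO:
            return "host_mmio"
        if self is ChipCapability.All:
            return "all"
        raise ValueError("Unknown ChipCapability enum entry.")


print(str(ChipCapability.All))   # pcie|host_mmio
print(str(ChipCapability.PCIE))  # pcie
```

In this sketch `All` prints as the joined single-bit names rather than `"all"`, because the `len(self) > 1` branch is taken before the `self is ChipCapability.All` comparison is reached. Whether that is the desired string form is a separate question from the recursion.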
