gemini-code-assist[bot] commented on code in PR #18408:
URL: https://github.com/apache/tvm/pull/18408#discussion_r2477518058
##########
position_id_fix.py:
##########
@@ -0,0 +1,78 @@
+# sol-script-fixed.py
+import torch
+import torch.nn as nn
+from transformers import AutoModel
+from torch.export import export as torch_export
+from tvm.relax.frontend.torch import from_exported_program
+
+class StateDictWrapper(dict):
+    """Wrap exported state_dict and inject extra keys (non-persistent buffers)."""
+    def __init__(self, base_dict, extra):
+        super().__init__(base_dict)
+        self.extra = extra
+
+    def __getitem__(self, key):
+        if key in self.extra:
+            return self.extra[key]
+        return super().__getitem__(key)
+
+    def get(self, key, default=None):
+        if key in self.extra:
+            return self.extra[key]
+        return super().get(key, default)
+
+class M(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bert = AutoModel.from_pretrained("bert-base-multilingual-uncased")
+        self.cls = nn.Linear(self.bert.config.hidden_size, 2)
+
+    def forward(self, x, mask=None):
+        out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
+        return self.cls(out)
+
+def main():
+    torch.manual_seed(0)
+    m = M().eval()
+
+    x = torch.randint(0, 30522, (2, 16))
+    mask = torch.ones_like(x)
+
+    ep = torch_export(m, (x, mask))
+    print("\n torch.export completed successfully\n")
+
+    # --- Build extra buffers dict ---
+    extra = {}
+    for buf_name in m.bert.embeddings._non_persistent_buffers_set:
+        tensor = m.bert.embeddings._buffers.get(buf_name)
+        if tensor is not None:
+            extra[f"bert.embeddings.{buf_name}"] = tensor
+            print(f"Injecting buffer: bert.embeddings.{buf_name} -> shape {tensor.shape}")
+
+    # Wrap exported state_dict
+    sd_wrapped = StateDictWrapper(ep.state_dict, extra)
+
+    # EP wrapper to override state_dict access
+    class EPWrapper:
+        def __init__(self, ep, sd_wrapped):
+            self.__dict__["_ep"] = ep
+            self.__dict__["_sd"] = sd_wrapped
+
+        def __getattr__(self, name):
+            if name == "state_dict":
+                return self._sd
+            return getattr(self._ep, name)
+
+    ep_wrapped = EPWrapper(ep, sd_wrapped)
+
+    # Import to TVM
+    try:
+        mod = from_exported_program(ep_wrapped)
+        print("\n TVM import succeeded — all non-persistent buffers injected!\n")
+    except Exception as e:
+        print("\n TVM import failed with exception:")
+        import traceback
+        traceback.print_exc()
Review Comment:

It is a standard Python convention (PEP 8) to place all imports at the top
of the file. This improves readability and makes it easier to see the script's
dependencies at a glance. Please move `import traceback` to the top of the
script with the other imports.
```suggestion
        traceback.print_exc()
```
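
For reference, a minimal sketch of the resulting top-of-file import block once `import traceback` is hoisted out of the `except` clause (the other imports are taken from the script as-is; the grouping follows the standard-library-first convention):

```python
# Standard library first, then third-party packages.
import traceback

import torch
import torch.nn as nn
from torch.export import export as torch_export
from transformers import AutoModel

from tvm.relax.frontend.torch import from_exported_program
```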
##########
position_id_fix.py:
##########
@@ -0,0 +1,78 @@
+# sol-script-fixed.py
+import torch
+import torch.nn as nn
+from transformers import AutoModel
+from torch.export import export as torch_export
+from tvm.relax.frontend.torch import from_exported_program
+
+class StateDictWrapper(dict):
+    """Wrap exported state_dict and inject extra keys (non-persistent buffers)."""
+    def __init__(self, base_dict, extra):
+        super().__init__(base_dict)
+        self.extra = extra
+
+    def __getitem__(self, key):
+        if key in self.extra:
+            return self.extra[key]
+        return super().__getitem__(key)
+
+    def get(self, key, default=None):
+        if key in self.extra:
+            return self.extra[key]
+        return super().get(key, default)
+
+class M(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bert = AutoModel.from_pretrained("bert-base-multilingual-uncased")
+        self.cls = nn.Linear(self.bert.config.hidden_size, 2)
+
+    def forward(self, x, mask=None):
+        out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
+        return self.cls(out)
+
+def main():
+    torch.manual_seed(0)
+    m = M().eval()
+
+    x = torch.randint(0, 30522, (2, 16))
Review Comment:

The value `30522` is a magic number; it matches the vocabulary size of
`bert-base-uncased`, not necessarily that of the
`bert-base-multilingual-uncased` checkpoint loaded here. It's better to fetch
this value from the model's configuration, which improves readability and
maintainability and keeps the sampled token IDs in range if the model changes.
```suggestion
    x = torch.randint(0, m.bert.config.vocab_size, (2, 16))
```
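
If the random input ever needs to be built before the model itself is instantiated, the vocabulary size can also be read from the checkpoint configuration alone; a small sketch using `transformers.AutoConfig` (not part of the original script):

```python
from transformers import AutoConfig

# Loads only the config file, no weights, so it is cheap to call up front.
config = AutoConfig.from_pretrained("bert-base-multilingual-uncased")
vocab_size = config.vocab_size  # upper bound to pass to torch.randint
```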
##########
position_id_fix.py:
##########
@@ -0,0 +1,78 @@
+# sol-script-fixed.py
+import torch
+import torch.nn as nn
+from transformers import AutoModel
+from torch.export import export as torch_export
+from tvm.relax.frontend.torch import from_exported_program
+
+class StateDictWrapper(dict):
+    """Wrap exported state_dict and inject extra keys (non-persistent buffers)."""
+    def __init__(self, base_dict, extra):
+        super().__init__(base_dict)
+        self.extra = extra
+
+    def __getitem__(self, key):
+        if key in self.extra:
+            return self.extra[key]
+        return super().__getitem__(key)
+
+    def get(self, key, default=None):
+        if key in self.extra:
+            return self.extra[key]
+        return super().get(key, default)
+
+class M(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bert = AutoModel.from_pretrained("bert-base-multilingual-uncased")
+        self.cls = nn.Linear(self.bert.config.hidden_size, 2)
+
+    def forward(self, x, mask=None):
+        out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
+        return self.cls(out)
+
+def main():
+    torch.manual_seed(0)
+    m = M().eval()
+
+    x = torch.randint(0, 30522, (2, 16))
+    mask = torch.ones_like(x)
+
+    ep = torch_export(m, (x, mask))
+    print("\n torch.export completed successfully\n")
+
+    # --- Build extra buffers dict ---
+    extra = {}
+    for buf_name in m.bert.embeddings._non_persistent_buffers_set:
+        tensor = m.bert.embeddings._buffers.get(buf_name)
+        if tensor is not None:
+            extra[f"bert.embeddings.{buf_name}"] = tensor
+            print(f"Injecting buffer: bert.embeddings.{buf_name} -> shape {tensor.shape}")
+
+    # Wrap exported state_dict
+    sd_wrapped = StateDictWrapper(ep.state_dict, extra)
+
+    # EP wrapper to override state_dict access
+    class EPWrapper:
+        def __init__(self, ep, sd_wrapped):
+            self.__dict__["_ep"] = ep
+            self.__dict__["_sd"] = sd_wrapped
+
+        def __getattr__(self, name):
+            if name == "state_dict":
+                return self._sd
+            return getattr(self._ep, name)
+
Review Comment:

This proxy around the `ExportedProgram` is minimal. A more robust and
idiomatic way to build it is to expose `state_dict` as a property and leave
`__getattr__` to handle only the fallback delegation. This avoids
special-casing attribute names inside `__getattr__` and the
`self.__dict__[...]` assignment workaround, makes the intent clearer, and is
less prone to subtle bugs if `from_exported_program` has more complex
interactions with the object.
```suggestion
    class EPWrapper:
        def __init__(self, ep, sd_wrapped):
            self._ep = ep
            self._sd = sd_wrapped

        @property
        def state_dict(self):
            return self._sd

        def __getattr__(self, name):
            return getattr(self._ep, name)
```
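
A self-contained sketch of the delegation pattern this suggestion relies on (all names below are illustrative stand-ins, not part of the PR): attributes defined on the proxy class, including the `state_dict` property, shadow the wrapped object, while every other attribute falls through via `__getattr__`.

```python
class Inner:
    """Stand-in for an exported program exposing a state_dict and other attributes."""
    def __init__(self):
        self.state_dict = {"w": 1}
        self.graph = "original-graph"


class Proxy:
    def __init__(self, inner, sd):
        self._inner = inner
        self._sd = sd

    @property
    def state_dict(self):
        # Resolved on the proxy's class, shadowing the wrapped object's attribute.
        return self._sd

    def __getattr__(self, name):
        # Called only when normal attribute lookup on the proxy fails.
        return getattr(self._inner, name)


proxy = Proxy(Inner(), {"w": 1, "extra": 2})
assert proxy.state_dict == {"w": 1, "extra": 2}  # served by the property
assert proxy.graph == "original-graph"           # delegated to the wrapped object
```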
##########
position_id_fix.py:
##########
@@ -0,0 +1,78 @@
+# sol-script-fixed.py
+import torch
+import torch.nn as nn
+from transformers import AutoModel
+from torch.export import export as torch_export
+from tvm.relax.frontend.torch import from_exported_program
+
+class StateDictWrapper(dict):
+    """Wrap exported state_dict and inject extra keys (non-persistent buffers)."""
+    def __init__(self, base_dict, extra):
+        super().__init__(base_dict)
+        self.extra = extra
+
+    def __getitem__(self, key):
+        if key in self.extra:
+            return self.extra[key]
+        return super().__getitem__(key)
+
+    def get(self, key, default=None):
+        if key in self.extra:
+            return self.extra[key]
+        return super().get(key, default)
+
+class M(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bert = AutoModel.from_pretrained("bert-base-multilingual-uncased")
+        self.cls = nn.Linear(self.bert.config.hidden_size, 2)
+
+    def forward(self, x, mask=None):
+        out = self.bert(x, attention_mask=mask).last_hidden_state[:, 0, :]
+        return self.cls(out)
+
+def main():
+    torch.manual_seed(0)
+    m = M().eval()
+
+    x = torch.randint(0, 30522, (2, 16))
+    mask = torch.ones_like(x)
+
+    ep = torch_export(m, (x, mask))
+    print("\n torch.export completed successfully\n")
+
+    # --- Build extra buffers dict ---
+    extra = {}
+    for buf_name in m.bert.embeddings._non_persistent_buffers_set:
+        tensor = m.bert.embeddings._buffers.get(buf_name)
+        if tensor is not None:
+            extra[f"bert.embeddings.{buf_name}"] = tensor
+            print(f"Injecting buffer: bert.embeddings.{buf_name} -> shape {tensor.shape}")
+
+    # Wrap exported state_dict
+    sd_wrapped = StateDictWrapper(ep.state_dict, extra)
Review Comment:

The custom `StateDictWrapper` class can be replaced by
`collections.ChainMap` for a more concise and idiomatic implementation.
`ChainMap` is designed for linking multiple dictionaries.
After this change, you can remove the `StateDictWrapper` class definition
(lines 8-22) and add `import collections` to the top of the file.
```suggestion
    sd_wrapped = collections.ChainMap(extra, ep.state_dict)
```
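
A quick sketch of the `ChainMap` lookup semantics that make this a drop-in replacement: keys are resolved left to right, so entries in `extra` shadow same-named entries in the exported state dict. One caveat is that `ChainMap` is a `collections.abc.Mapping` but not a `dict` subclass, so this assumes `from_exported_program` only needs mapping-style access to the state dict.

```python
import collections

base = {"bert.encoder.layer.0.attention.self.query.weight": "base-tensor"}
extra = {"bert.embeddings.position_ids": "injected-tensor"}

sd = collections.ChainMap(extra, base)  # extra takes priority over base
assert sd["bert.embeddings.position_ids"] == "injected-tensor"
assert sd["bert.encoder.layer.0.attention.self.query.weight"] == "base-tensor"
assert sd.get("missing-key") is None  # .get falls back to the default, like dict.get
```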
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]