[https://issues.apache.org/jira/browse/BEAM-14337?focusedWorklogId=775543&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-775543]
ASF GitHub Bot logged work on BEAM-14337:
-----------------------------------------
Author: ASF GitHub Bot
Created on: 27/May/22 17:53
Start Date: 27/May/22 17:53
Worklog Time Spent: 10m
Work Description: yeandy commented on code in PR #17470:
URL: https://github.com/apache/beam/pull/17470#discussion_r883863798
##########
sdks/python/apache_beam/ml/inference/pytorch_test.py:
##########
@@ -43,10 +43,14 @@
raise unittest.SkipTest('PyTorch dependencies are not installed')
-def _compare_prediction_result(a, b):
- return (
- torch.equal(a.inference, b.inference) and
- torch.equal(a.example, b.example))
+def _compare_prediction_result(x, y):
Review Comment:
As opposed to the other pattern
```
for actual, expected in zip(predictions, expected_predictions):
    self.assertTrue(_compare_prediction_result(actual, expected))
```
which won't report anything meaningful beyond saying that `False` is
not `True` (which I've already fixed), the call
```
assert_that(
    predictions,
    equal_to(expected_predictions, equals_fn=_compare_prediction_result))
```
will output the details of how `a` and `b` are not equal. So I will keep
this function name the same.
Issue Time Tracking
-------------------
Worklog Id: (was: 775543)
Time Spent: 6h 20m (was: 6h 10m)
> Support **kwargs for PyTorch models.
> ------------------------------------
>
> Key: BEAM-14337
> URL: https://issues.apache.org/jira/browse/BEAM-14337
> Project: Beam
> Issue Type: Sub-task
> Components: sdk-py-core
> Reporter: Anand Inguva
> Assignee: Andy Ye
> Priority: P2
> Time Spent: 6h 20m
> Remaining Estimate: 0h
>
> Some PyTorch models that subclass torch.nn.Module take extra parameters
> in their forward function. These extra parameters can be passed as a dict
> (unpacked into keyword arguments) or as positional arguments.
> Examples of PyTorch models supported by Hugging Face:
> [https://huggingface.co/bert-base-uncased]
> [Some torch models on Hugging face|https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py]
> E.g.:
> [https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel]
> {code:python}
> from transformers import BertModel
>
> # tensor1, tensor2 and tensor3 are placeholder tensors for the tokenized input.
> inputs = {
>     "input_ids": tensor1,
>     "attention_mask": tensor2,
>     "token_type_ids": tensor3,
> }
> # BertModel is a subclass of torch.nn.Module.
> model = BertModel.from_pretrained("bert-base-uncased")
> # The model's forward method receives the keys of `inputs` as keyword arguments.
> outputs = model(**inputs)
> {code}
>
> The [Transformers|https://pytorch.org/hub/huggingface_pytorch-transformers/]
> integration in PyTorch Hub is supported by Hugging Face as well.
>
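The dict-unpacking call described in the issue can be sketched without PyTorch installed. Below is a minimal, stdlib-only illustration. It assumes a hypothetical `FakeBertLikeModel` that imitates only the way `torch.nn.Module.__call__` dispatches to `forward()` (real modules also run hooks). It also assumes a hypothetical `run_model` helper that passes a batch positionally and any extra `forward()` parameters as keyword arguments. Neither name is part of Beam or PyTorch.

```python
from typing import Any, Dict, Optional


class FakeBertLikeModel:
    """Hypothetical stand-in for a torch.nn.Module subclass.

    Calling the instance delegates to forward(), mirroring nn.Module's dispatch
    (without hooks). The extra forward() parameters mirror BertModel's inputs.
    """

    def __call__(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        return self.forward(*args, **kwargs)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        # Echo the received arguments so the call can be inspected.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }


def run_model(model, batch, extra_kwargs: Optional[Dict[str, Any]] = None):
    """Hypothetical helper: the batch goes in positionally, and each key in
    extra_kwargs becomes a keyword argument of forward()."""
    return model(batch, **(extra_kwargs or {}))


if __name__ == "__main__":
    model = FakeBertLikeModel()
    # Positional input plus one extra keyword argument.
    print(run_model(model, [101, 102], {"attention_mask": [1, 1]}))
    # All inputs supplied as a dict and unpacked, as in model(**inputs).
    inputs = {"input_ids": [1], "attention_mask": [1], "token_type_ids": [0]}
    print(model(**inputs))
```

Because the keys of the dict are matched to `forward()` parameter names, an unknown key raises a `TypeError` at call time. This is why the keys must line up with the model's signature.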