[
https://issues.apache.org/jira/browse/BEAM-14044?focusedWorklogId=772099&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-772099
]
ASF GitHub Bot logged work on BEAM-14044:
-----------------------------------------
Author: ASF GitHub Bot
Created on: 18/May/22 19:28
Start Date: 18/May/22 19:28
Worklog Time Spent: 10m
Work Description: yeandy commented on code in PR #17527:
URL: https://github.com/apache/beam/pull/17527#discussion_r876268059
##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -72,6 +72,21 @@ def process(self, prediction_result):
yield prediction_result.inference
+class FakeInferenceRunnerNeedsBigBatch(FakeInferenceRunner):
+ def run_inference(self, batch, unused_model):
+ if len(batch) < 100:
+ raise ValueError('Unexpectedly small batch')
+ return batch
+
+
+class FakeLoaderWithBatchArgForwarding(FakeModelLoader):
+ def get_inference_runner(self):
+ return FakeInferenceRunnerNeedsBigBatch()
+
+ def batch_elements_kwargs(self):
+ return {'min_batch_size': 9999}
Review Comment:
If the goal is to be able to set a max batch size of 1, I think it would
also be useful to explicitly test the case in which `batch_elements_kwargs`
returns `{'max_batch_size': 1}`.
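A minimal, Beam-free sketch of the suggested test case follows, assuming the loader's `batch_elements_kwargs()` dict is forwarded as keyword arguments to the batching step. `FakeInferenceRunnerNeedsSingletonBatch`, `FakeLoaderWithMaxBatchSizeOne`, `simple_batch` and `run_inference` are hypothetical stand-ins. `FakeInferenceRunner` and `FakeModelLoader` are redefined here only so the snippet runs on its own; they are not the versions in `base_test.py`. `simple_batch` uses fixed-size chunking, unlike Beam's adaptive `BatchElements`.

```python
# Hypothetical, Beam-free stand-ins that illustrate the suggested test case.


class FakeInferenceRunner:
    def run_inference(self, batch, unused_model):
        # Echo the batch back as its "predictions".
        return batch


class FakeInferenceRunnerNeedsSingletonBatch(FakeInferenceRunner):
    def run_inference(self, batch, unused_model):
        # Fail loudly if the batching step ignored max_batch_size=1.
        if len(batch) != 1:
            raise ValueError('Expected a batch of exactly one element')
        return batch


class FakeModelLoader:
    def get_inference_runner(self):
        return FakeInferenceRunner()

    def batch_elements_kwargs(self):
        # No overrides: the batching step uses its own defaults.
        return {}


class FakeLoaderWithMaxBatchSizeOne(FakeModelLoader):
    def get_inference_runner(self):
        return FakeInferenceRunnerNeedsSingletonBatch()

    def batch_elements_kwargs(self):
        return {'max_batch_size': 1}


def simple_batch(elements, min_batch_size=1, max_batch_size=10000):
    """Yield consecutive chunks of at most max_batch_size elements.

    Simplified stand-in for Beam's adaptive BatchElements: min_batch_size
    is only validated here and is not used to size batches.
    """
    if not 1 <= min_batch_size <= max_batch_size:
        raise ValueError('Need 1 <= min_batch_size <= max_batch_size')
    for start in range(0, len(elements), max_batch_size):
        yield elements[start:start + max_batch_size]


def run_inference(elements, loader):
    """Forward the loader's batching kwargs, then run each batch."""
    runner = loader.get_inference_runner()
    results = []
    for batch in simple_batch(elements, **loader.batch_elements_kwargs()):
        results.extend(runner.run_inference(batch, None))
    return results
```

With `FakeLoaderWithMaxBatchSizeOne`, every batch reaching the runner holds a single element. If the kwargs were not forwarded, the default chunk size would pass all elements in one batch and the runner would raise `ValueError`.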
Issue Time Tracking
-------------------
Worklog Id: (was: 772099)
Time Spent: 1.5h (was: 1h 20m)
> Hook In Batching DoFn Apis to RunInference
> ------------------------------------------
>
> Key: BEAM-14044
> URL: https://issues.apache.org/jira/browse/BEAM-14044
> Project: Beam
> Issue Type: Sub-task
> Components: sdk-py-core
> Reporter: Ryan Thompson
> Assignee: Brian Hulette
> Priority: P2
> Time Spent: 1.5h
> Remaining Estimate: 0h
>
> Hook the batching DoFn APIs into the base RunInference interface.
> We should also investigate what defaults we should set for batching, and
> perhaps make that part of the API.
> See
> [s.apache.org/batched-dofns|http://s.apache.org/batched-dofns]
> for more details.
--
This message was sent by Atlassian Jira
(v8.20.7#820007)