damccorm commented on code in PR #28161:
URL: https://github.com/apache/beam/pull/28161#discussion_r1309027057
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -504,6 +605,74 @@ def validate_inference_args(self, inference_args:
Optional[Dict[str, Any]]):
for mh in self._id_to_mh_map.values():
mh.validate_inference_args(inference_args)
+ def update_model_paths(
+ self,
+ model: Union[ModelT, _ModelManager],
+ model_paths: List[KeyModelPathMapping[KeyT]] = None):
+ # When there are many models, the model handler is responsible for
+ # reorganizing the model handlers into cohorts and telling the model
+ # manager to update every cohort's associated model handler. The model
+ # manager is responsible for performing the updates and tracking which
+ # updates have already been applied.
+ if model_paths is None or len(model_paths) == 0 or model is None:
+ return
+ if self._single_model:
+ raise RuntimeError(
+ 'Invalid model update: sent many model paths to '
+ 'update, but KeyedModelHandler is wrapping a single '
+ 'model.')
+ # Map cohort ids to a dictionary mapping new model paths to the keys that
+ # were originally in that cohort. We will use this to construct our new
+ # cohorts.
+ # cohort_path_mapping will be structured as follows:
+ # {
+ # original_cohort_id: {
+ # 'update/path/1': ['key1FromOriginalCohort', key2FromOriginalCohort'],
+ # 'update/path/2': ['key3FromOriginalCohort', key4FromOriginalCohort'],
+ # }
+ # }
+ cohort_path_mapping: Dict[KeyT, Dict[str, List[KeyT]]] = {}
+ seen_keys = set()
+ for mp in model_paths:
+ keys = mp.keys
+ update_path = mp.update_path
+ if len(update_path) == 0:
+ raise ValueError(f'Invalid model update, path for {keys} is empty')
+ for key in keys:
+ if key in seen_keys:
+ raise ValueError(
+ f'Invalid model update: {key} appears in multiple '
+ 'update lists.')
+ seen_keys.add(key)
+ if key not in self._key_to_id_map:
+ raise ValueError(
+ f'Invalid model update: {key} appears in '
+ 'update, but not in the original configuration.')
+ cohort_id = self._key_to_id_map[key]
+ if cohort_id not in cohort_path_mapping:
Review Comment:
Yes, that's correct
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -504,6 +605,74 @@ def validate_inference_args(self, inference_args:
Optional[Dict[str, Any]]):
for mh in self._id_to_mh_map.values():
mh.validate_inference_args(inference_args)
+ def update_model_paths(
+ self,
+ model: Union[ModelT, _ModelManager],
+ model_paths: List[KeyModelPathMapping[KeyT]] = None):
+ # When there are many models, the model handler is responsible for
+ # reorganizing the model handlers into cohorts and telling the model
+ # manager to update every cohort's associated model handler. The model
+ # manager is responsible for performing the updates and tracking which
+ # updates have already been applied.
+ if model_paths is None or len(model_paths) == 0 or model is None:
+ return
+ if self._single_model:
+ raise RuntimeError(
+ 'Invalid model update: sent many model paths to '
+ 'update, but KeyedModelHandler is wrapping a single '
+ 'model.')
+ # Map cohort ids to a dictionary mapping new model paths to the keys that
+ # were originally in that cohort. We will use this to construct our new
+ # cohorts.
+ # cohort_path_mapping will be structured as follows:
+ # {
+ # original_cohort_id: {
+ # 'update/path/1': ['key1FromOriginalCohort', key2FromOriginalCohort'],
+ # 'update/path/2': ['key3FromOriginalCohort', key4FromOriginalCohort'],
+ # }
+ # }
+ cohort_path_mapping: Dict[KeyT, Dict[str, List[KeyT]]] = {}
+ seen_keys = set()
+ for mp in model_paths:
+ keys = mp.keys
+ update_path = mp.update_path
+ if len(update_path) == 0:
+ raise ValueError(f'Invalid model update, path for {keys} is empty')
+ for key in keys:
+ if key in seen_keys:
+ raise ValueError(
+ f'Invalid model update: {key} appears in multiple '
+ 'update lists.')
Review Comment:
Updated. Let me know if it is clearer, or please suggest wording that would
help.
##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -934,6 +942,204 @@ def process(self, element):
assert_that(result_pcoll, equal_to(expected_result))
+ def test_run_inference_side_input_in_batch_per_key_models(self):
+ first_ts = math.floor(time.time()) - 30
+ interval = 7
+
+ sample_main_input_elements = ([
+ (first_ts - 2, 'key1'),
+ (first_ts + 1, 'key2'),
+ (first_ts + 8, 'key2'),
+ (first_ts + 15, 'key1'),
+ (first_ts + 22, 'key2'),
+ (first_ts + 29, 'key1'),
+ ])
+
+ sample_side_input_elements = [
+ (
+ first_ts + 1,
+ [
+ base.KeyModelPathMapping(
+ keys=['key1'], update_path='fake_model_id_default'),
+ base.KeyModelPathMapping(
+ keys=['key2'], update_path='fake_model_id_default')
+ ]),
+ # if model_id is empty string, we use the default model
+ # handler model URI.
+ (
+ first_ts + 8,
+ [
+ base.KeyModelPathMapping(
+ keys=['key1'], update_path='fake_model_id_1'),
+ base.KeyModelPathMapping(
+ keys=['key2'], update_path='fake_model_id_default')
+ ]),
+ (
+ first_ts + 15,
+ [
+ base.KeyModelPathMapping(
+ keys=['key1'], update_path='fake_model_id_1'),
+ base.KeyModelPathMapping(
+ keys=['key2'], update_path='fake_model_id_2')
+ ]),
+ ]
+
+ model_handler = base.KeyedModelHandler([
+ base.KeyMhMapping(['key1'],
+ FakeModelHandlerReturnsPredictionResult(
+ multi_process_shared=True, state=True)),
+ base.KeyMhMapping(['key2'],
+ FakeModelHandlerReturnsPredictionResult(
+ multi_process_shared=True, state=True))
+ ])
+
+ class _EmitElement(beam.DoFn):
+ def process(self, element):
+ for e in element:
+ yield e
+
+ with TestPipeline() as pipeline:
+ side_input = (
+ pipeline
+ |
+ "CreateSideInputElements" >> beam.Create(sample_side_input_elements)
+ | beam.Map(lambda x: TimestampedValue(x[1], x[0]))
+ | beam.WindowInto(
+ window.FixedWindows(interval),
+ accumulation_mode=trigger.AccumulationMode.DISCARDING)
+ | beam.Map(lambda x: ('key', x))
+ | beam.GroupByKey()
+ | beam.Map(lambda x: x[1])
+ | "EmitSideInput" >> beam.ParDo(_EmitElement()))
+
+ result_pcoll = (
+ pipeline
+ | beam.Create(sample_main_input_elements)
+ | "MapTimeStamp" >> beam.Map(lambda x: TimestampedValue(x, x[0]))
+ | "ApplyWindow" >> beam.WindowInto(window.FixedWindows(interval))
+ | beam.Map(lambda x: (x[1], x[0]))
Review Comment:
Yeah, I initially felt that it read more clearly this way when looking at the
examples, but I think it causes more confusion than it's worth.
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -504,6 +605,74 @@ def validate_inference_args(self, inference_args:
Optional[Dict[str, Any]]):
for mh in self._id_to_mh_map.values():
mh.validate_inference_args(inference_args)
+ def update_model_paths(
+ self,
+ model: Union[ModelT, _ModelManager],
+ model_paths: List[KeyModelPathMapping[KeyT]] = None):
+ # When there are many models, the model handler is responsible for
+ # reorganizing the model handlers into cohorts and telling the model
+ # manager to update every cohort's associated model handler. The model
+ # manager is responsible for performing the updates and tracking which
+ # updates have already been applied.
+ if model_paths is None or len(model_paths) == 0 or model is None:
+ return
+ if self._single_model:
+ raise RuntimeError(
+ 'Invalid model update: sent many model paths to '
+ 'update, but KeyedModelHandler is wrapping a single '
+ 'model.')
+ # Map cohort ids to a dictionary mapping new model paths to the keys that
+ # were originally in that cohort. We will use this to construct our new
+ # cohorts.
+ # cohort_path_mapping will be structured as follows:
+ # {
+ # original_cohort_id: {
+ # 'update/path/1': ['key1FromOriginalCohort', key2FromOriginalCohort'],
+ # 'update/path/2': ['key3FromOriginalCohort', key4FromOriginalCohort'],
+ # }
+ # }
+ cohort_path_mapping: Dict[KeyT, Dict[str, List[KeyT]]] = {}
+ seen_keys = set()
+ for mp in model_paths:
+ keys = mp.keys
+ update_path = mp.update_path
+ if len(update_path) == 0:
+ raise ValueError(f'Invalid model update, path for {keys} is empty')
+ for key in keys:
+ if key in seen_keys:
+ raise ValueError(
+ f'Invalid model update: {key} appears in multiple '
+ 'update lists.')
+ seen_keys.add(key)
+ if key not in self._key_to_id_map:
+ raise ValueError(
+ f'Invalid model update: {key} appears in '
+ 'update, but not in the original configuration.')
+ cohort_id = self._key_to_id_map[key]
+ if cohort_id not in cohort_path_mapping:
+ cohort_path_mapping[cohort_id] = defaultdict(list)
+ cohort_path_mapping[cohort_id][update_path].append(key)
+ for key in self._key_to_id_map:
+ if key not in seen_keys:
+ raise ValueError(
+ f'Invalid model update: {key} appears in the '
+ 'original configuration, but not the update.')
+
+ # We now have our new set of cohorts. For each one, update our local model
+ # handler configuration and send the results to the ModelManager
+ for old_cohort_id, path_key_mapping in cohort_path_mapping.items():
+ for updated_path, keys in path_key_mapping.items():
+ cohort_id = old_cohort_id
+ if old_cohort_id not in keys:
+ # Create new cohort
+ cohort_id = keys[0]
+ for key in keys:
+ self._key_to_id_map[key] = cohort_id
Review Comment:
It could be different. For example, if we have:
`['key1', 'key2', 'key3'] => MH1` as our initial configuration, `key1` will
be the cohort leader for all 3 keys.
If we then get the following update:
```
['key1'] => 'path/1'
['key2', 'key3'] => 'path/2'
```
the cohort has split and we need to elect a new cohort leader for `key2` and
`key3`
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]