AnandInguva commented on code in PR #28161:
URL: https://github.com/apache/beam/pull/28161#discussion_r1308781253
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -504,6 +605,74 @@ def validate_inference_args(self, inference_args:
Optional[Dict[str, Any]]):
for mh in self._id_to_mh_map.values():
mh.validate_inference_args(inference_args)
+ def update_model_paths(
+ self,
+ model: Union[ModelT, _ModelManager],
+ model_paths: List[KeyModelPathMapping[KeyT]] = None):
+ # When there are many models, the model handler is responsible for
+ # reorganizing the model handlers into cohorts and telling the model
+ # manager to update every cohort's associated model handler. The model
+ # manager is responsible for performing the updates and tracking which
+ # updates have already been applied.
+ if model_paths is None or len(model_paths) == 0 or model is None:
+ return
+ if self._single_model:
+ raise RuntimeError(
+ 'Invalid model update: sent many model paths to '
+ 'update, but KeyedModelHandler is wrapping a single '
+ 'model.')
+ # Map cohort ids to a dictionary mapping new model paths to the keys that
+ # were originally in that cohort. We will use this to construct our new
+ # cohorts.
+ # cohort_path_mapping will be structured as follows:
+ # {
+ # original_cohort_id: {
+ # 'update/path/1': ['key1FromOriginalCohort', key2FromOriginalCohort'],
+ # 'update/path/2': ['key3FromOriginalCohort', key4FromOriginalCohort'],
+ # }
+ # }
+ cohort_path_mapping: Dict[KeyT, Dict[str, List[KeyT]]] = {}
+ seen_keys = set()
+ for mp in model_paths:
+ keys = mp.keys
+ update_path = mp.update_path
+ if len(update_path) == 0:
+ raise ValueError(f'Invalid model update, path for {keys} is empty')
+ for key in keys:
+ if key in seen_keys:
+ raise ValueError(
+ f'Invalid model update: {key} appears in multiple '
+ 'update lists.')
Review Comment:
Can we make this error message clearer?
##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -934,6 +942,204 @@ def process(self, element):
assert_that(result_pcoll, equal_to(expected_result))
+ def test_run_inference_side_input_in_batch_per_key_models(self):
+ first_ts = math.floor(time.time()) - 30
+ interval = 7
+
+ sample_main_input_elements = ([
+ (first_ts - 2, 'key1'),
+ (first_ts + 1, 'key2'),
+ (first_ts + 8, 'key2'),
+ (first_ts + 15, 'key1'),
+ (first_ts + 22, 'key2'),
+ (first_ts + 29, 'key1'),
+ ])
+
+ sample_side_input_elements = [
+ (
+ first_ts + 1,
+ [
+ base.KeyModelPathMapping(
+ keys=['key1'], update_path='fake_model_id_default'),
+ base.KeyModelPathMapping(
+ keys=['key2'], update_path='fake_model_id_default')
+ ]),
+ # if model_id is empty string, we use the default model
+ # handler model URI.
+ (
+ first_ts + 8,
+ [
+ base.KeyModelPathMapping(
+ keys=['key1'], update_path='fake_model_id_1'),
+ base.KeyModelPathMapping(
+ keys=['key2'], update_path='fake_model_id_default')
+ ]),
+ (
+ first_ts + 15,
+ [
+ base.KeyModelPathMapping(
+ keys=['key1'], update_path='fake_model_id_1'),
+ base.KeyModelPathMapping(
+ keys=['key2'], update_path='fake_model_id_2')
+ ]),
+ ]
+
+ model_handler = base.KeyedModelHandler([
+ base.KeyMhMapping(['key1'],
+ FakeModelHandlerReturnsPredictionResult(
+ multi_process_shared=True, state=True)),
+ base.KeyMhMapping(['key2'],
+ FakeModelHandlerReturnsPredictionResult(
+ multi_process_shared=True, state=True))
+ ])
+
+ class _EmitElement(beam.DoFn):
+ def process(self, element):
+ for e in element:
+ yield e
+
+ with TestPipeline() as pipeline:
+ side_input = (
+ pipeline
+ |
+ "CreateSideInputElements" >> beam.Create(sample_side_input_elements)
+ | beam.Map(lambda x: TimestampedValue(x[1], x[0]))
+ | beam.WindowInto(
+ window.FixedWindows(interval),
+ accumulation_mode=trigger.AccumulationMode.DISCARDING)
+ | beam.Map(lambda x: ('key', x))
+ | beam.GroupByKey()
+ | beam.Map(lambda x: x[1])
+ | "EmitSideInput" >> beam.ParDo(_EmitElement()))
+
+ result_pcoll = (
+ pipeline
+ | beam.Create(sample_main_input_elements)
+ | "MapTimeStamp" >> beam.Map(lambda x: TimestampedValue(x, x[0]))
+ | "ApplyWindow" >> beam.WindowInto(window.FixedWindows(interval))
+ | beam.Map(lambda x: (x[1], x[0]))
Review Comment:
We can remove this `Map` and instead swap the order of the elements in the
main input tuples.
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -504,6 +605,74 @@ def validate_inference_args(self, inference_args:
Optional[Dict[str, Any]]):
for mh in self._id_to_mh_map.values():
mh.validate_inference_args(inference_args)
+ def update_model_paths(
+ self,
+ model: Union[ModelT, _ModelManager],
+ model_paths: List[KeyModelPathMapping[KeyT]] = None):
+ # When there are many models, the model handler is responsible for
+ # reorganizing the model handlers into cohorts and telling the model
+ # manager to update every cohort's associated model handler. The model
+ # manager is responsible for performing the updates and tracking which
+ # updates have already been applied.
+ if model_paths is None or len(model_paths) == 0 or model is None:
+ return
+ if self._single_model:
+ raise RuntimeError(
+ 'Invalid model update: sent many model paths to '
+ 'update, but KeyedModelHandler is wrapping a single '
+ 'model.')
+ # Map cohort ids to a dictionary mapping new model paths to the keys that
+ # were originally in that cohort. We will use this to construct our new
+ # cohorts.
+ # cohort_path_mapping will be structured as follows:
+ # {
+ # original_cohort_id: {
+ # 'update/path/1': ['key1FromOriginalCohort', key2FromOriginalCohort'],
+ # 'update/path/2': ['key3FromOriginalCohort', key4FromOriginalCohort'],
+ # }
+ # }
+ cohort_path_mapping: Dict[KeyT, Dict[str, List[KeyT]]] = {}
+ seen_keys = set()
+ for mp in model_paths:
+ keys = mp.keys
+ update_path = mp.update_path
+ if len(update_path) == 0:
+ raise ValueError(f'Invalid model update, path for {keys} is empty')
+ for key in keys:
+ if key in seen_keys:
+ raise ValueError(
+ f'Invalid model update: {key} appears in multiple '
+ 'update lists.')
+ seen_keys.add(key)
+ if key not in self._key_to_id_map:
+ raise ValueError(
+ f'Invalid model update: {key} appears in '
+ 'update, but not in the original configuration.')
+ cohort_id = self._key_to_id_map[key]
+ if cohort_id not in cohort_path_mapping:
Review Comment:
Just to confirm my understanding: `cohort_id` is the first key (index 0) in the
list of keys that are mapped to a single ModelHandler in
`self._key_to_id_map`, right?
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -504,6 +605,74 @@ def validate_inference_args(self, inference_args:
Optional[Dict[str, Any]]):
for mh in self._id_to_mh_map.values():
mh.validate_inference_args(inference_args)
+ def update_model_paths(
+ self,
+ model: Union[ModelT, _ModelManager],
+ model_paths: List[KeyModelPathMapping[KeyT]] = None):
+ # When there are many models, the model handler is responsible for
+ # reorganizing the model handlers into cohorts and telling the model
+ # manager to update every cohort's associated model handler. The model
+ # manager is responsible for performing the updates and tracking which
+ # updates have already been applied.
+ if model_paths is None or len(model_paths) == 0 or model is None:
+ return
+ if self._single_model:
+ raise RuntimeError(
+ 'Invalid model update: sent many model paths to '
+ 'update, but KeyedModelHandler is wrapping a single '
+ 'model.')
+ # Map cohort ids to a dictionary mapping new model paths to the keys that
+ # were originally in that cohort. We will use this to construct our new
+ # cohorts.
+ # cohort_path_mapping will be structured as follows:
+ # {
+ # original_cohort_id: {
+ # 'update/path/1': ['key1FromOriginalCohort', key2FromOriginalCohort'],
+ # 'update/path/2': ['key3FromOriginalCohort', key4FromOriginalCohort'],
+ # }
+ # }
+ cohort_path_mapping: Dict[KeyT, Dict[str, List[KeyT]]] = {}
+ seen_keys = set()
+ for mp in model_paths:
+ keys = mp.keys
+ update_path = mp.update_path
+ if len(update_path) == 0:
+ raise ValueError(f'Invalid model update, path for {keys} is empty')
+ for key in keys:
+ if key in seen_keys:
+ raise ValueError(
+ f'Invalid model update: {key} appears in multiple '
+ 'update lists.')
+ seen_keys.add(key)
+ if key not in self._key_to_id_map:
+ raise ValueError(
+ f'Invalid model update: {key} appears in '
+ 'update, but not in the original configuration.')
+ cohort_id = self._key_to_id_map[key]
+ if cohort_id not in cohort_path_mapping:
Review Comment:
Since each key belongs to exactly one cohort, you initially mapped every key
in a cohort to that cohort's `keys[0]`, and now we use that `keys[0]` to
construct a new cohort.
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -316,6 +366,32 @@ def increment_max_models(self, increment: int):
" models mode).")
self._max_models += increment
+ def update_model_handler(self, key, model_path, previous_key):
Review Comment:
Please add type annotations.
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -264,10 +306,17 @@ def __init__(
allow unlimited models.
"""
self._max_models = max_models
+ # Map keys to model handlers
self._mh_map: Dict[str, ModelHandler] = mh_map
- self._proxy_map: Dict[str, str] = {}
- self._tag_map: Dict[
- str, multi_process_shared.MultiProcessShared] = OrderedDict()
+ # Map keys to the last model update for that key
Review Comment:
```suggestion
# Map keys to the last updated model path for that key
```
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -504,6 +605,74 @@ def validate_inference_args(self, inference_args:
Optional[Dict[str, Any]]):
for mh in self._id_to_mh_map.values():
mh.validate_inference_args(inference_args)
+ def update_model_paths(
+ self,
+ model: Union[ModelT, _ModelManager],
+ model_paths: List[KeyModelPathMapping[KeyT]] = None):
+ # When there are many models, the model handler is responsible for
Review Comment:
```suggestion
# When there are many models, the keyed model handler is responsible for
```
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -504,6 +605,74 @@ def validate_inference_args(self, inference_args:
Optional[Dict[str, Any]]):
for mh in self._id_to_mh_map.values():
mh.validate_inference_args(inference_args)
+ def update_model_paths(
+ self,
+ model: Union[ModelT, _ModelManager],
+ model_paths: List[KeyModelPathMapping[KeyT]] = None):
+ # When there are many models, the model handler is responsible for
+ # reorganizing the model handlers into cohorts and telling the model
+ # manager to update every cohort's associated model handler. The model
+ # manager is responsible for performing the updates and tracking which
+ # updates have already been applied.
+ if model_paths is None or len(model_paths) == 0 or model is None:
+ return
+ if self._single_model:
+ raise RuntimeError(
+ 'Invalid model update: sent many model paths to '
+ 'update, but KeyedModelHandler is wrapping a single '
+ 'model.')
+ # Map cohort ids to a dictionary mapping new model paths to the keys that
+ # were originally in that cohort. We will use this to construct our new
+ # cohorts.
+ # cohort_path_mapping will be structured as follows:
+ # {
+ # original_cohort_id: {
+ # 'update/path/1': ['key1FromOriginalCohort', key2FromOriginalCohort'],
+ # 'update/path/2': ['key3FromOriginalCohort', key4FromOriginalCohort'],
+ # }
+ # }
+ cohort_path_mapping: Dict[KeyT, Dict[str, List[KeyT]]] = {}
+ seen_keys = set()
+ for mp in model_paths:
+ keys = mp.keys
+ update_path = mp.update_path
+ if len(update_path) == 0:
+ raise ValueError(f'Invalid model update, path for {keys} is empty')
+ for key in keys:
+ if key in seen_keys:
+ raise ValueError(
+ f'Invalid model update: {key} appears in multiple '
+ 'update lists.')
+ seen_keys.add(key)
+ if key not in self._key_to_id_map:
+ raise ValueError(
+ f'Invalid model update: {key} appears in '
+ 'update, but not in the original configuration.')
+ cohort_id = self._key_to_id_map[key]
+ if cohort_id not in cohort_path_mapping:
+ cohort_path_mapping[cohort_id] = defaultdict(list)
+ cohort_path_mapping[cohort_id][update_path].append(key)
+ for key in self._key_to_id_map:
+ if key not in seen_keys:
+ raise ValueError(
+ f'Invalid model update: {key} appears in the '
+ 'original configuration, but not the update.')
+
+ # We now have our new set of cohorts. For each one, update our local model
+ # handler configuration and send the results to the ModelManager
+ for old_cohort_id, path_key_mapping in cohort_path_mapping.items():
+ for updated_path, keys in path_key_mapping.items():
+ cohort_id = old_cohort_id
+ if old_cohort_id not in keys:
+ # Create new cohort
+ cohort_id = keys[0]
+ for key in keys:
+ self._key_to_id_map[key] = cohort_id
Review Comment:
This will be the same as the mapping set up during the `__init__` call, right?
Why do we need to update it here? I may be wrong.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]