jrmccluskey commented on code in PR #32237:
URL: https://github.com/apache/beam/pull/32237#discussion_r1725075931
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1283,13 +1304,33 @@ def expand(
**resource_hints)
if self._with_exception_handling:
+ # On timeouts, report back to the central model metadata
+ # that the model is invalid
+ model_tag = self._model_tag
+ share_across_processes =
self._model_handler.share_model_across_processes(
+ )
+ timeout = self._timeout
+
+ def failure_callback(exception: Exception, element: Any):
+ if type(exception) is not TimeoutError:
+ return
+ model_metadata = load_model_status(model_tag, share_across_processes)
+ model_metadata.try_mark_current_model_invalid(timeout)
+ logging.warning("Operation timed out, etc…….")
Review Comment:
Can we improve this warning message?
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1445,6 +1502,110 @@ def next_model_index(self, num_models):
return self._cur_index
+class _ModelStatus():
+ """A class holding any metadata about a model required by RunInference.
+
+ Currently, this only includes whether or not the model is valid. Uses the
+ model tag to map models to metadata.
+ """
+ def __init__(self, share_model_across_processes: bool):
+ self._active_tags = set()
+ self._invalid_tags = set()
+ self._tag_mapping = {}
+ self._model_first_seen = {}
+ self._pending_hard_delete = []
+ self._share_model_across_process = share_model_across_processes
+
+ def try_mark_current_model_invalid(self, min_model_life_seconds):
+ """Mark the current model invalid.
+
+ Since we don't have sufficient information to say which model is being
+ marked invalid, but there may be multiple active models, we will mark all
+ models currently in use as inactive so that they all get reloaded. To
+ avoid thrashing, however, we will only mark models as invalid if they've
+ been active at least min_model_life_seconds seconds.
+ """
+ cutoff_time = datetime.now() - timedelta(seconds=min_model_life_seconds)
+ for tag in list(self._active_tags):
+ if cutoff_time >= self._model_first_seen[tag]:
+ self._invalid_tags.add(tag)
+ # Delete old models after a grace period of 2 * the model life.
+ # This already happens automatically for shared.Shared models, so
+ # cleanup is only necessary for multi_process_shared models.
+ if self._share_model_across_process:
+ self._pending_hard_delete.append((
+ tag,
+ datetime.now() + 2 * timedelta(seconds=min_model_life_seconds)))
+ self._active_tags.remove(tag)
+
+ def get_valid_tag(self, tag: str) -> str:
+ """Takes in a proposed valid tag and returns a valid one.
+
+ Will always return a valid tag. If the passed in tag is valid, this
+ function will simply return it, otherwise it will deterministically
+ generate a new tag to use instead. The new tag will be the original tag
+ with an incrementing suffix (e.g. `my_tag_1`, `my_tag_2`) for each reload
Review Comment:
The incrementing suffix seems to apply only to the `{tag}_reload_{i}` tags
here, not to `{tag}_{i}`.
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1477,26 +1638,31 @@ def __init__(
model_handler: ModelHandler[ExampleT, PredictionT, Any],
clock,
metrics_namespace,
- enable_side_input_loading: bool = False,
+ load_model_at_runtime: bool = False,
Review Comment:
This is pretty convenient functionality for large models too — a nice extra
win here.
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1305,7 +1346,12 @@ def expand(
return results
def with_exception_handling(
- self, *, exc_class=Exception, use_subprocess=False, threshold=1):
+ self,
+ *,
+ exc_class=Exception,
+ use_subprocess=False,
+ threshold=1,
+ timeout=None):
Review Comment:
Could we type this as `Optional[int] = None`? It also needs a unit, either in
the variable name or in the docstring.
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1477,26 +1638,31 @@ def __init__(
model_handler: ModelHandler[ExampleT, PredictionT, Any],
clock,
metrics_namespace,
- enable_side_input_loading: bool = False,
+ load_model_at_runtime: bool = False,
model_tag: str = "RunInference"):
"""A DoFn implementation generic to frameworks.
Args:
model_handler: An implementation of ModelHandler.
clock: A clock implementing time_ns. *Used for unit testing.*
metrics_namespace: Namespace of the transform to collect metrics.
- enable_side_input_loading: Bool to indicate if model updates
- with side inputs.
+ enable_side_input_loading: Bool to indicate if model loading should be
Review Comment:
```suggestion
load_model_at_runtime: Bool to indicate if model loading should be
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]