damccorm commented on code in PR #27603:
URL: https://github.com/apache/beam/pull/27603#discussion_r1281147278


##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -964,6 +982,117 @@ def test_child_class_without_env_vars(self):
       actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars())
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  def test_model_manager_loads_shared_model(self):

Review Comment:
   I added a check for this in the test.



##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -964,6 +982,117 @@ def test_child_class_without_env_vars(self):
       actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars())
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  def test_model_manager_loads_shared_model(self):
+    mhs = {
+        'key1': FakeModelHandler(state=1),
+        'key2': FakeModelHandler(state=2),
+        'key3': FakeModelHandler(state=3)
+    }
+    mm = base._ModelManager(mh_map=mhs)
+    tag1 = mm.load('key1')
+    # Use bad_mh's load function to make sure we're actually loading the
+    # version already stored
+    bad_mh = FakeModelHandler(state=100)
+    model1 = multi_process_shared.MultiProcessShared(
+        bad_mh.load_model, tag=tag1).acquire()
+    self.assertEqual(1, model1.predict(10))
+
+    tag2 = mm.load('key2')
+    tag3 = mm.load('key3')
+    model2 = multi_process_shared.MultiProcessShared(
+        bad_mh.load_model, tag=tag2).acquire()
+    model3 = multi_process_shared.MultiProcessShared(
+        bad_mh.load_model, tag=tag3).acquire()
+    self.assertEqual(2, model2.predict(10))
+    self.assertEqual(3, model3.predict(10))
+
+  def test_model_manager_evicts_models(self):
+    mh1 = FakeModelHandler(state=1)
+    mh2 = FakeModelHandler(state=2)
+    mh3 = FakeModelHandler(state=3)
+    mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
+    mm = base._ModelManager(mh_map=mhs, max_models=2)
+    tag1 = mm.load('key1')
+    sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
+    model1 = sh1.acquire()
+    self.assertEqual(1, model1.predict(10))
+    model1.increment_state(5)
+    self.assertEqual(6, model1.predict(10))
+    sh1.release(model1)
+
+    tag2 = mm.load('key2')
+    tag3 = mm.load('key3')
+    sh2 = multi_process_shared.MultiProcessShared(mh2.load_model, tag=tag2)
+    model2 = sh2.acquire()
+    sh3 = multi_process_shared.MultiProcessShared(mh3.load_model, tag=tag3)
+    model3 = sh3.acquire()
+    model2.increment_state(5)
+    model3.increment_state(5)
+    self.assertEqual(7, model2.predict(10))
+    self.assertEqual(8, model3.predict(10))
+    sh2.release(model2)
+    sh3.release(model3)
+
+    # This should get recreated, so it shouldn't have the state updates
+    model1 = multi_process_shared.MultiProcessShared(
+        mh1.load_model, tag=tag1).acquire()
+    self.assertEqual(1, model1.predict(10))
+
+    # These should not get recreated, so they should have the state updates
+    model2 = multi_process_shared.MultiProcessShared(
+        mh2.load_model, tag=tag2).acquire()
+    self.assertEqual(7, model2.predict(10))
+    model3 = multi_process_shared.MultiProcessShared(
+        mh3.load_model, tag=tag3).acquire()
+    self.assertEqual(8, model3.predict(10))
+
+  def test_model_manager_evicts_correct_num_of_models_after_being_incremented(
+      self):
+    mh1 = FakeModelHandler(state=1)
+    mh2 = FakeModelHandler(state=2)
+    mh3 = FakeModelHandler(state=3)
+    mhs = {'key1': mh1, 'key2': mh2, 'key3': mh3}
+    mm = base._ModelManager(mh_map=mhs, max_models=1)
+    mm.increment_max_models(1)
+    tag1 = mm.load('key1')
+    sh1 = multi_process_shared.MultiProcessShared(mh1.load_model, tag=tag1)
+    model1 = sh1.acquire()
+    self.assertEqual(1, model1.predict(10))
+    model1.increment_state(5)
+    self.assertEqual(6, model1.predict(10))
+    sh1.release(model1)
+
+    tag2 = mm.load('key2')
+    tag3 = mm.load('key3')
+    sh2 = multi_process_shared.MultiProcessShared(mh2.load_model, tag=tag2)
+    model2 = sh2.acquire()
+    sh3 = multi_process_shared.MultiProcessShared(mh3.load_model, tag=tag3)
+    model3 = sh3.acquire()
+    model2.increment_state(5)
+    model3.increment_state(5)
+    self.assertEqual(7, model2.predict(10))
+    self.assertEqual(8, model3.predict(10))
+    sh2.release(model2)
+    sh3.release(model3)
+
+    # This should get recreated, so it shouldn't have the state updates
+    model1 = multi_process_shared.MultiProcessShared(
+        mh1.load_model, tag=tag1).acquire()
+    self.assertEqual(1, model1.predict(10))
+
+    # These should not get recreated, so they should have the state updates
+    model2 = multi_process_shared.MultiProcessShared(
+        mh2.load_model, tag=tag2).acquire()
+    self.assertEqual(7, model2.predict(10))
+    model3 = multi_process_shared.MultiProcessShared(
+        mh3.load_model, tag=tag3).acquire()
+    self.assertEqual(8, model3.predict(10))

Review Comment:
   Do we have an easy mechanism for doing this? If so, I'm on board. That said,
I'm not sure I see the utility in doing so; what advantage do you see it giving
us?
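The two eviction tests quoted above assert that, with room for two models, loading a third evicts `key1`, so reacquiring it yields a freshly loaded model without the `increment_state(5)` update, while `key2` and `key3` keep theirs. Below is a minimal in-process sketch of that behavior, assuming a least-recently-used policy (the quoted tests are consistent with LRU but do not by themselves rule out FIFO). `LruModelCacheSketch`, `make_cache` and this `FakeModel` are hypothetical names, not Beam APIs, and the real `_ModelManager` hands out `MultiProcessShared` tags rather than model objects.

```python
from collections import OrderedDict
from typing import Any, Callable, Dict


class FakeModel:
    """Hypothetical stand-in for the test's fake model: predict() returns its state."""

    def __init__(self, state: int):
        self.state = state

    def predict(self, example: int) -> int:
        return self.state

    def increment_state(self, amount: int) -> None:
        self.state += amount


class LruModelCacheSketch:
    """Hypothetical LRU model cache; NOT Beam's base._ModelManager.

    Models live in an in-process OrderedDict instead of MultiProcessShared,
    and load() returns the model itself rather than a shared-memory tag.
    """

    def __init__(self, loaders: Dict[str, Callable[[], Any]], max_models: int):
        self._loaders = loaders
        self._max_models = max_models
        self._models: "OrderedDict[str, Any]" = OrderedDict()

    def increment_max_models(self, increment: int) -> None:
        # Raising the limit lets more models stay resident before eviction.
        self._max_models += increment

    def load(self, key: str) -> Any:
        if key in self._models:
            # Cache hit: mark the key as most recently used.
            self._models.move_to_end(key)
            return self._models[key]
        # Cache miss: build a fresh model, so any earlier state is lost.
        model = self._loaders[key]()
        self._models[key] = model
        # Evict least recently used models until we are within the limit.
        while len(self._models) > self._max_models:
            self._models.popitem(last=False)
        return model


def make_cache(max_models: int) -> LruModelCacheSketch:
    # Default argument binds each state at definition time, not at call time.
    loaders = {
        key: (lambda s=state: FakeModel(s))
        for key, state in [('key1', 1), ('key2', 2), ('key3', 3)]
    }
    return LruModelCacheSketch(loaders, max_models=max_models)
```

In this sketch, `make_cache(max_models=1)` followed by `increment_max_models(1)` behaves exactly like `make_cache(max_models=2)`, which is the property the third test checks.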



##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -964,6 +982,117 @@ def test_child_class_without_env_vars(self):
       actual = pcoll | base.RunInference(FakeModelHandlerNoEnvVars())
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  def test_model_manager_loads_shared_model(self):
+    mhs = {
+        'key1': FakeModelHandler(state=1),
+        'key2': FakeModelHandler(state=2),
+        'key3': FakeModelHandler(state=3)
+    }
+    mm = base._ModelManager(mh_map=mhs)
+    tag1 = mm.load('key1')
+    # Use bad_mh's load function to make sure we're actually loading the
+    # version already stored
+    bad_mh = FakeModelHandler(state=100)
+    model1 = multi_process_shared.MultiProcessShared(
+        bad_mh.load_model, tag=tag1).acquire()
+    self.assertEqual(1, model1.predict(10))
+
+    tag2 = mm.load('key2')

Review Comment:
   We shouldn't, because of the randomness introduced when we generate the tag
(the uuid part): every reference to a model gets its own unique tag, regardless
of the key it represents. I also just ran the suite 1000 times and it passed
every time.
   
   Note that this is actually an important property if we have multiple model
managers operating at once (e.g. in a pipeline that runs several sets of
sequential inference).
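A minimal sketch of the tagging property described above, assuming the tag is built as the key plus a random `uuid4()` suffix. `make_model_tag` and the exact tag format are hypothetical, not Beam's implementation; the point is only that two loads of the same key, for example from two model managers in one pipeline, get distinct tags.

```python
import uuid


def make_model_tag(key: str) -> str:
    """Hypothetical tag builder: the model key plus a random uuid4 hex suffix.

    uuid4() is random, so repeated calls with the same key return different
    tags, and neither caller can pick up the other's shared model.
    """
    return f"{key}_{uuid.uuid4().hex}"
```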



-- 
This is an automated message from the Apache Git Service.