bzablocki commented on code in PR #28310:
URL: https://github.com/apache/beam/pull/28310#discussion_r1319893045


##########
sdks/python/apache_beam/yaml/yaml_transform.py:
##########
@@ -208,6 +229,77 @@ def get_outputs(self, transform_name):
   def compute_outputs(self, transform_id):
     return expand_transform(self._transforms_by_uuid[transform_id], self)
 
+  def best_provider(
+      self, t, input_providers: yaml_provider.Iterable[yaml_provider.Provider]):
+    if isinstance(t, dict):
+      spec = t
+    else:
+      spec = self._transforms_by_uuid[self.get_transform_id(t)]
+    possible_providers = [
+        p for p in self.providers[spec['type']] if p.available()
+    ]
+    if not possible_providers:
+      raise ValueError(
+          'No available provider for type %r at %s' %
+          (spec['type'], identify_object(spec)))
+    # From here on, we have the invariant that possible_providers is not empty.
+
+    # Only one possible provider, no need to rank further.
+    if len(possible_providers) == 1:
+      return possible_providers[0]
+
+    def best_matches(
+        possible_providers: Iterable[yaml_provider.Provider],
+        adjacent_provider_options: Iterable[Iterable[yaml_provider.Provider]]
+    ) -> List[yaml_provider.Provider]:
+      """Given a set of possible providers, and a set of providers for each
+      adjacent transform, returns the top possible providers as ranked by
+      affinity to the adjacent transforms' providers.
+      """
+      providers_by_score = collections.defaultdict(list)
+      for p in possible_providers:
+        # The sum of the affinity of the best provider
+        # for each adjacent transform.
+        providers_by_score[sum(
+            max(p.affinity(ap) for ap in apo)
+            for apo in adjacent_provider_options)].append(p)
+      return providers_by_score[max(providers_by_score.keys())]
+
+    # If there are any inputs, prefer to match them.
+    if input_providers:
+      possible_providers = best_matches(
+          possible_providers, [[p] for p in input_providers])
+
+    # Without __uuid__ we can't find downstream operations.
+    if '__uuid__' not in spec:
+      return possible_providers[0]
+
+    # Match against downstream transforms, continuing until there is no tie
+    # or we run out of downstream transforms.
+    if len(possible_providers) > 1:
+      adjacent_transforms = list(self.followers(spec['__uuid__']))
+      while adjacent_transforms:
+        # This is a list of all possible providers for each adjacent transform.
+        adjacent_provider_options = [[
+            p for p in self.providers[self._transforms_by_uuid[t]['type']]
+            if p.available()
+        ] for t in adjacent_transforms]
+        if any(not apo for apo in adjacent_provider_options):
+          # One of the transforms had no available providers.
+          # We will throw an error later, doesn't matter what we return.
+          break
+        # Filter down the set of possible providers to the best ones.
+        possible_providers = best_matches(
+            possible_providers, adjacent_provider_options)
+        # If we are down to one option, no ned to go further.

Review Comment:
   ```suggestion
           # If we are down to one option, no need to go further.
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to