robertwb commented on a change in pull request #14390:
URL: https://github.com/apache/beam/pull/14390#discussion_r616218047
##########
File path: sdks/python/apache_beam/pipeline_test.py
##########
@@ -960,6 +963,207 @@ def annotations(self):
transform.annotations['proto'], some_proto.SerializeToString())
self.assertEqual(seen, 2)
+ def test_runner_api_roundtrip_preserves_resource_hints(self):
+ p = beam.Pipeline()
+ _ = (
+ p | beam.Create([1, 2])
+ | beam.Map(lambda x: x + 1).with_resource_hints(accelerator='gpu'))
+
+ self.assertEqual(
+ p.transforms_stack[0].parts[1].transform.get_resource_hints(),
+ {common_urns.resource_hints.ACCELERATOR.urn: b'gpu'})
+
+ for _ in range(3):
+ # Verify that DEFAULT environments are recreated during multiple RunnerAPI
+ # translation and hints don't get lost.
+ p = Pipeline.from_runner_api(
+ Pipeline.to_runner_api(p, use_fake_coders=True), None, None)
+ self.assertEqual(
+ p.transforms_stack[0].parts[1].transform.get_resource_hints(),
+ {common_urns.resource_hints.ACCELERATOR.urn: b'gpu'})
+
+ def test_hints_on_composite_transforms_are_propagated_to_subtransforms(self):
+ class FooHint(ResourceHint):
+ urn = 'foo_urn'
+
+ class BarHint(ResourceHint):
+ urn = 'bar_urn'
+
+ class BazHint(ResourceHint):
+ urn = 'baz_urn'
+
+ class QuxHint(ResourceHint):
+ urn = 'qux_urn'
+
+ class UseMaxValueHint(ResourceHint):
+ urn = 'use_max_value_urn'
+
+ @classmethod
+ def get_merged_value(
+ cls, outer_value, inner_value): # type: (bytes, bytes) -> bytes
+ return ResourceHint._use_max(outer_value, inner_value)
+
+ ResourceHint.register_resource_hint('foo_hint', FooHint)
+ ResourceHint.register_resource_hint('bar_hint', BarHint)
+ ResourceHint.register_resource_hint('baz_hint', BazHint)
+ ResourceHint.register_resource_hint('qux_hint', QuxHint)
+ ResourceHint.register_resource_hint('use_max_value_hint', UseMaxValueHint)
+
+ @beam.ptransform_fn
+ def SubTransform(pcoll):
+ return pcoll | beam.Map(lambda x: x + 1).with_resource_hints(
+ foo_hint='set_on_subtransform', use_max_value_hint='10')
+
+ @beam.ptransform_fn
+ def CompositeTransform(pcoll):
+ return pcoll | beam.Map(lambda x: x * 2) | SubTransform()
+
+ p = beam.Pipeline()
+ _ = (
+ p | beam.Create([1, 2])
+ | CompositeTransform().with_resource_hints(
+ foo_hint='should_be_overriden_by_subtransform',
+ bar_hint='set_on_composite',
+ baz_hint='set_on_composite',
+ use_max_value_hint='100'))
+ options = PortableOptions([
+ '--resource_hint=baz_hint=should_be_overriden_by_composite',
+ '--resource_hint=qux_hint=set_via_options',
+ '--environment_type=PROCESS',
+ '--environment_option=process_command=foo',
+ '--sdk_location=container',
+ ])
+ environment = ProcessEnvironment.from_options(options)
+ proto = Pipeline.to_runner_api(
+ p, use_fake_coders=True, default_environment=environment)
+
+ for t in proto.components.transforms.values():
+ if "CompositeTransform/SubTransform/Map" in t.unique_name:
+ environment = proto.components.environments.get(t.environment_id)
+ self.assertEqual(
+ environment.resource_hints.get('foo_urn'), b'set_on_subtransform')
+ self.assertEqual(
+ environment.resource_hints.get('bar_urn'), b'set_on_composite')
+ self.assertEqual(
+ environment.resource_hints.get('baz_urn'), b'set_on_composite')
+ self.assertEqual(
+ environment.resource_hints.get('qux_urn'), b'set_via_options')
+ self.assertEqual(
+ environment.resource_hints.get('use_max_value_urn'), b'100')
+ found = True
+ assert found
+
+ def test_environments_with_same_resource_hints_are_reused(self):
+ class HintX(ResourceHint):
+ urn = 'X_urn'
+
+ class HintY(ResourceHint):
+ urn = 'Y_urn'
+
+ class HintIsOdd(ResourceHint):
+ urn = 'IsOdd_urn'
+
+ ResourceHint.register_resource_hint('X', HintX)
+ ResourceHint.register_resource_hint('Y', HintY)
+ ResourceHint.register_resource_hint('IsOdd', HintIsOdd)
+
+ p = beam.Pipeline()
+ num_iter = 4
+ for i in range(num_iter):
+ _ = (
+ p
+ | f'NoHintCreate_{i}' >> beam.Create([1, 2])
+ | f'NoHint_{i}' >> beam.Map(lambda x: x + 1))
+ _ = (
+ p
+ | f'XCreate_{i}' >> beam.Create([1, 2])
+ |
+ f'HintX_{i}' >> beam.Map(lambda x: x + 1).with_resource_hints(X='X'))
+ _ = (
+ p
+ | f'XYCreate_{i}' >> beam.Create([1, 2])
+ | f'HintXY_{i}' >> beam.Map(lambda x: x + 1).with_resource_hints(
+ X='X', Y='Y'))
+ _ = (
+ p
+ | f'IsOddCreate_{i}' >> beam.Create([1, 2])
+ | f'IsOdd_{i}' >>
+ beam.Map(lambda x: x + 1).with_resource_hints(IsOdd=str(i % 2 != 0)))
+
+ proto = Pipeline.to_runner_api(p, use_fake_coders=True)
+ count_x = count_xy = count_is_odd = count_no_hints = 0
+ env_ids = set()
+ for _, t in proto.components.transforms.items():
+ env = proto.components.environments[t.environment_id]
+ if t.unique_name.startswith('HintX_'):
+ count_x += 1
+ env_ids.add(t.environment_id)
+ self.assertEqual(env.resource_hints, {'X_urn': b'X'})
+
+ if t.unique_name.startswith('HintXY_'):
+ count_xy += 1
+ env_ids.add(t.environment_id)
+ self.assertEqual(env.resource_hints, {'X_urn': b'X', 'Y_urn': b'Y'})
+
+ if t.unique_name.startswith('NoHint_'):
+ count_no_hints += 1
+ env_ids.add(t.environment_id)
+ self.assertEqual(env.resource_hints, {})
+
+ if t.unique_name.startswith('IsOdd_'):
+ count_is_odd += 1
+ env_ids.add(t.environment_id)
+ self.assertTrue(
+ env.resource_hints == {'IsOdd_urn': b'True'} or
+ env.resource_hints == {'IsOdd_urn': b'False'})
+ assert count_x == count_is_odd == count_xy == count_no_hints == num_iter
Review comment:
Yeah, it's slow. I don't think much time has gone into optimizing it,
but speeding it up would especially help TFX pipelines, which tend to have
very many stages.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org