ahmedabu98 commented on code in PR #29834:
URL: https://github.com/apache/beam/pull/29834#discussion_r1474850603


##########
sdks/python/apache_beam/transforms/external_transform_provider_test.py:
##########
@@ -186,83 +193,139 @@ def setUp(self):
     os.mkdir(self.test_dir)
 
     self.assertTrue(
-        os.environ.get('EXPANSION_PORT'), "Expansion service port not found!")
+        os.environ.get('EXPANSION_PORTS'), "Expansion service port not found!")
+    logging.info("EXPANSION_PORTS: %s", os.environ.get('EXPANSION_PORTS'))
 
   def tearDown(self):
     shutil.rmtree(self.test_dir, ignore_errors=False)
 
-  def test_script_workflow(self):
+  def delete_and_validate(self):
+    delete_generated_files(self.test_dir)
+    self.assertEqual(len(os.listdir(self.test_dir)), 0)
+
+  def test_script_fails_with_invalid_destinations(self):
     expansion_service_config = {
         "gradle_target": 'sdks:java:io:expansion-service:shadowJar',
         'destinations': {
-            'python': f'apache_beam/transforms/{self.test_dir_name}'
+            'python': 'apache_beam/some_nonexistent_dir'
         }
     }
+    with self.assertRaises(ValueError):
+      self.create_and_check_transforms_config_exists(expansion_service_config)
+
+  def test_pretty_types(self):
+    types = [
+        typing.Optional[typing.List[str]],
+        numpy.int16,
+        str,
+        typing.Dict[str, numpy.float64],
+        typing.Optional[typing.Dict[str, typing.List[numpy.int64]]],
+        typing.Dict[int, typing.Optional[str]]
+    ]
+
+    expected_type_names = [('List[str]', True), ('numpy.int16', False),
+                           ('str', False), ('Dict[str, numpy.float64]', False),
+                           ('Dict[str, List[numpy.int64]]', True),
+                           ('Dict[int, Union[str, NoneType]]', False)]
+
+    for i in range(len(types)):
+      self.assertEqual(pretty_type(types[i]), expected_type_names[i])
+
+  def create_and_check_transforms_config_exists(self, 
expansion_service_config):
     with open(self.service_config_path, 'w') as f:
       yaml.dump([expansion_service_config], f)
 
-    # test that transform config YAML file is created
     generate_transforms_config(
         self.service_config_path, self.transform_config_path)
     self.assertTrue(os.path.exists(self.transform_config_path))
-    expected_destination = \
-      f'apache_beam/transforms/{self.test_dir_name}/generate_sequence'
-    # test that transform config is populated correctly
+
+  def create_and_validate_transforms_config(
+      self, expansion_service_config, expected_name, expected_destination):
+    self.create_and_check_transforms_config_exists(expansion_service_config)
+
     with open(self.transform_config_path) as f:
-      transforms = yaml.safe_load(f)
+      configs = yaml.safe_load(f)
       gen_seq_config = None
-      for transform in transforms:
-        if transform['identifier'] == self.GEN_SEQ_IDENTIFIER:
-          gen_seq_config = transform
+      for config in configs:
+        if config['identifier'] == self.GEN_SEQ_IDENTIFIER:
+          gen_seq_config = config
       self.assertIsNotNone(gen_seq_config)
       self.assertEqual(
           gen_seq_config['default_service'],
           expansion_service_config['gradle_target'])
-      self.assertEqual(gen_seq_config['name'], 'GenerateSequence')
+      self.assertEqual(gen_seq_config['name'], expected_name)
       self.assertEqual(
           gen_seq_config['destinations']['python'], expected_destination)
       self.assertIn("end", gen_seq_config['fields'])
       self.assertIn("start", gen_seq_config['fields'])
       self.assertIn("rate", gen_seq_config['fields'])
 
-    # test that the code for GenerateSequence is set to the right destination
+  def get_module(self, dest):
+    module_name = dest.replace('apache_beam/', '').replace('/', '_')
+    module = 'apache_beam.transforms.%s.%s' % (self.test_dir_name, module_name)
+    return import_module(module)
+
+  def write_wrappers_to_destinations_and_validate(
+      self, destinations: typing.List[str]):
+    """
+    Generate wrappers from the config path and validate all destinations are
+    included.
+    Then write wrappers to destinations and validate all destination paths
+    exist.
+
+    :return: Generated wrappers grouped by destination
+    """
     grouped_wrappers = get_wrappers_from_transform_configs(
         self.transform_config_path)
-    self.assertIn(expected_destination, grouped_wrappers)
-    # only the GenerateSequence wrapper is set to this destination
-    self.assertEqual(len(grouped_wrappers[expected_destination]), 1)
+    for dest in destinations:
+      self.assertIn(dest, grouped_wrappers)
+
+    # write to our test directory to avoid messing with other files
+    write_wrappers_to_destinations(grouped_wrappers, self.test_dir)
+
+    for dest in destinations:
+      self.assertTrue(
+          os.path.exists(
+              os.path.join(
+                  self.test_dir,
+                  dest.replace('apache_beam/', '').replace('/', '_') + ".py")))
+    return grouped_wrappers
+
+  def test_script_workflow(self):
+    expected_destination = 'apache_beam/transforms'
+    expansion_service_config = {
+        "gradle_target": 'sdks:java:io:expansion-service:shadowJar',
+        'destinations': {
+            'python': expected_destination
+        }
+    }
+
+    self.create_and_validate_transforms_config(
+        expansion_service_config, 'GenerateSequence', expected_destination)
+    grouped_wrappers = self.write_wrappers_to_destinations_and_validate(
+        [expected_destination])
+    # at least the GenerateSequence wrapper is set to this destination
+    self.assertGreaterEqual(len(grouped_wrappers[expected_destination]), 1)
 
-    # test that the correct destination is created
-    write_wrappers_to_destinations(grouped_wrappers)
-    self.assertTrue(
-        os.path.exists(
-            os.path.join(self.test_dir, 'generate_sequence' + PYTHON_SUFFIX)))
     # check the wrapper exists in this destination and has correct properties
-    generate_sequence_et = import_module(
-        expected_destination.replace('/', '.') + PYTHON_SUFFIX.rstrip('.py'))
-    self.assertTrue(hasattr(generate_sequence_et, 'GenerateSequence'))
+    output_module = self.get_module(expected_destination)
+    self.assertTrue(hasattr(output_module, 'GenerateSequence'))
+    self.assertTrue(hasattr(output_module, 'KafkaWrite'))  # also check that

Review Comment:
   We're skipping Kafka transforms in the `standard_expansion_services.yaml` config.
   In this particular test, we're creating a test config where nothing is being skipped.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to