This is an automated email from the ASF dual-hosted git repository.

jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 118c51404d9 Implement DeduplicateTensorPerRow in MLTransform (#31307)
118c51404d9 is described below

commit 118c51404d93132d4a0cae135f00ff68d4fd84bf
Author: Jack McCluskey <[email protected]>
AuthorDate: Thu May 16 10:29:04 2024 -0400

    Implement DeduplicateTensorPerRow in MLTransform (#31307)
---
 sdks/python/apache_beam/ml/transforms/tft.py      | 22 ++++++++
 sdks/python/apache_beam/ml/transforms/tft_test.py | 62 +++++++++++++++++++++++
 2 files changed, 84 insertions(+)

diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py
index 550dbedbc7b..370043bc0d9 100644
--- a/sdks/python/apache_beam/ml/transforms/tft.py
+++ b/sdks/python/apache_beam/ml/transforms/tft.py
@@ -681,3 +681,25 @@ class HashStrings(TFTOperation):
             name=self.name)
     }
     return output_dict
+
+
+@register_input_dtype(str)
+class DeduplicateTensorPerRow(TFTOperation):
+  def __init__(self, columns: List[str], name: Optional[str] = None):
+    """ Deduplicates each row (0th dimension) of the provided tensor.
+
+    Args:
+      columns: A list of the columns to apply the transformation on.
+      name: optional. A name for this operation. 
+    """
+    self.name = name
+    super().__init__(columns)
+
+  def apply_transform(
+      self, data: common_types.TensorType,
+      output_col_name: str) -> Dict[str, common_types.TensorType]:
+    output_dict = {
+        output_col_name: tft.deduplicate_tensor_per_row(
+            input_tensor=data, name=self.name)
+    }
+    return output_dict
diff --git a/sdks/python/apache_beam/ml/transforms/tft_test.py b/sdks/python/apache_beam/ml/transforms/tft_test.py
index 6763032a8eb..5c42ecc012f 100644
--- a/sdks/python/apache_beam/ml/transforms/tft_test.py
+++ b/sdks/python/apache_beam/ml/transforms/tft_test.py
@@ -1009,5 +1009,67 @@ class HashWordsTest(unittest.TestCase):
       assert_that(result, equal_to(expected_values, equals_fn=np.array_equal))
 
 
+class DeduplicateTensorPerRowTest(unittest.TestCase):
+  def setUp(self) -> None:
+    self.artifact_location = tempfile.mkdtemp()
+
+  def tearDown(self):
+    shutil.rmtree(self.artifact_location)
+
+  def test_deduplicate(self):
+    values = [{
+        'x': [b'a', b'b', b'a', b'b'],
+    }, {
+        'x': [b'b', b'c', b'b', b'c']
+    }]
+
+    expected_output = [np.array([b'a', b'b']), np.array([b'b', b'c'])]
+    with beam.Pipeline() as p:
+      list_result = (
+          p
+          | "listCreate" >> beam.Create(values)
+          | "listMLTransform" >> base.MLTransform(
+              write_artifact_location=self.artifact_location).with_transform(
+                  tft.DeduplicateTensorPerRow(columns=['x'])))
+      result = (list_result | beam.Map(lambda x: x.x))
+      assert_that(result, equal_to(expected_output, equals_fn=np.array_equal))
+
+  def test_deduplicate_no_op(self):
+    values = [{
+        'x': [b'a', b'b'],
+    }, {
+        'x': [b'c', b'd']
+    }]
+
+    expected_output = [np.array([b'a', b'b']), np.array([b'c', b'd'])]
+    with beam.Pipeline() as p:
+      list_result = (
+          p
+          | "listCreate" >> beam.Create(values)
+          | "listMLTransform" >> base.MLTransform(
+              write_artifact_location=self.artifact_location).with_transform(
+                  tft.DeduplicateTensorPerRow(columns=['x'])))
+      result = (list_result | beam.Map(lambda x: x.x))
+      assert_that(result, equal_to(expected_output, equals_fn=np.array_equal))
+
+  def test_deduplicate_different_output_sizes(self):
+    values = [{
+        'x': [b'a', b'b', b'a', b'b'],
+    }, {
+        'x': [b'c', b'a', b'd', b'd']
+    }]
+
+    expected_output = [np.array([b'a', b'b']), np.array([b'c', b'a', b'd'])]
+    with beam.Pipeline() as p:
+      list_result = (
+          p
+          | "listCreate" >> beam.Create(values)
+          | "listMLTransform" >> base.MLTransform(
+              write_artifact_location=self.artifact_location).with_transform(
+                  tft.DeduplicateTensorPerRow(columns=['x'])))
+      result = (list_result | beam.Map(lambda x: x.x))
+      assert_that(result, equal_to(expected_output, equals_fn=np.array_equal))
+
+
 if __name__ == '__main__':
   unittest.main()

Reply via email to