This is an automated email from the ASF dual-hosted git repository.
jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 118c51404d9 Implement DeduplicateTensorPerRow in MLTransform (#31307)
118c51404d9 is described below
commit 118c51404d93132d4a0cae135f00ff68d4fd84bf
Author: Jack McCluskey <[email protected]>
AuthorDate: Thu May 16 10:29:04 2024 -0400
Implement DeduplicateTensorPerRow in MLTransform (#31307)
---
sdks/python/apache_beam/ml/transforms/tft.py | 22 ++++++++
sdks/python/apache_beam/ml/transforms/tft_test.py | 62 +++++++++++++++++++++++
2 files changed, 84 insertions(+)
diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py
index 550dbedbc7b..370043bc0d9 100644
--- a/sdks/python/apache_beam/ml/transforms/tft.py
+++ b/sdks/python/apache_beam/ml/transforms/tft.py
@@ -681,3 +681,25 @@ class HashStrings(TFTOperation):
name=self.name)
}
return output_dict
+
+
@register_input_dtype(str)
class DeduplicateTensorPerRow(TFTOperation):
  """Removes repeated values inside each row of the configured columns.

  Thin wrapper around ``tft.deduplicate_tensor_per_row``. Rows are taken
  along the 0th dimension, so after deduplication different rows may hold
  different numbers of values.
  """
  def __init__(self, columns: List[str], name: Optional[str] = None):
    """Deduplicates each row (0th dimension) of the provided tensor.

    Args:
      columns: A list of the columns to apply the transformation on.
      name: optional. A name for this operation; forwarded unchanged to
        ``tft.deduplicate_tensor_per_row``.
    """
    self.name = name
    super().__init__(columns)

  def apply_transform(
      self, data: common_types.TensorType,
      output_col_name: str) -> Dict[str, common_types.TensorType]:
    """Deduplicates ``data`` row by row.

    Args:
      data: the input tensor to deduplicate.
      output_col_name: key under which the result is returned.

    Returns:
      A single-entry dict mapping ``output_col_name`` to the tensor produced
      by ``tft.deduplicate_tensor_per_row``.
    """
    deduplicated = tft.deduplicate_tensor_per_row(
        input_tensor=data, name=self.name)
    return {output_col_name: deduplicated}
diff --git a/sdks/python/apache_beam/ml/transforms/tft_test.py b/sdks/python/apache_beam/ml/transforms/tft_test.py
index 6763032a8eb..5c42ecc012f 100644
--- a/sdks/python/apache_beam/ml/transforms/tft_test.py
+++ b/sdks/python/apache_beam/ml/transforms/tft_test.py
@@ -1009,5 +1009,67 @@ class HashWordsTest(unittest.TestCase):
assert_that(result, equal_to(expected_values, equals_fn=np.array_equal))
class DeduplicateTensorPerRowTest(unittest.TestCase):
  """End-to-end tests for tft.DeduplicateTensorPerRow run via MLTransform."""
  def setUp(self) -> None:
    self.artifact_location = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.artifact_location)

  def _assert_deduplicated(self, values, expected_output):
    """Runs DeduplicateTensorPerRow on column 'x' and checks the result.

    Shared by every test below so that each one only states its input and
    expected output.

    Args:
      values: input elements; each is a dict whose 'x' entry is a list of
        bytes values.
      expected_output: the expected deduplicated 'x' value for each input
        element, as np.array objects compared with np.array_equal.
    """
    with beam.Pipeline() as p:
      list_result = (
          p
          | "listCreate" >> beam.Create(values)
          | "listMLTransform" >> base.MLTransform(
              write_artifact_location=self.artifact_location).with_transform(
                  tft.DeduplicateTensorPerRow(columns=['x'])))
      # Each output element exposes the transformed column as attribute `x`.
      result = (list_result | beam.Map(lambda x: x.x))
      assert_that(result, equal_to(expected_output, equals_fn=np.array_equal))

  def test_deduplicate(self):
    values = [
        {'x': [b'a', b'b', b'a', b'b']},
        {'x': [b'b', b'c', b'b', b'c']},
    ]
    expected_output = [np.array([b'a', b'b']), np.array([b'b', b'c'])]
    self._assert_deduplicated(values, expected_output)

  def test_deduplicate_no_op(self):
    # Rows that contain no duplicates must come back unchanged.
    values = [
        {'x': [b'a', b'b']},
        {'x': [b'c', b'd']},
    ]
    expected_output = [np.array([b'a', b'b']), np.array([b'c', b'd'])]
    self._assert_deduplicated(values, expected_output)

  def test_deduplicate_different_output_sizes(self):
    # Rows with different numbers of duplicates produce outputs of different
    # lengths; values are expected in first-occurrence order.
    values = [
        {'x': [b'a', b'b', b'a', b'b']},
        {'x': [b'c', b'a', b'd', b'd']},
    ]
    expected_output = [np.array([b'a', b'b']), np.array([b'c', b'a', b'd'])]
    self._assert_deduplicated(values, expected_output)
+
+
if __name__ == '__main__':
  # Executing this module directly runs every TestCase defined in it.
  unittest.main()