This is an automated email from the ASF dual-hosted git repository.
anandinguva pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new aa6c8dc6fcd Add MLTransform to __init__ and change input types of transforms (#27687)
aa6c8dc6fcd is described below
commit aa6c8dc6fcdda9222ae00e442b5125b4ca005b4d
Author: Anand Inguva <[email protected]>
AuthorDate: Wed Jul 26 10:45:00 2023 -0400
Add MLTransform to __init__ and change input types of transforms (#27687)
* Add MLTransform to __init__
* Update sdks/python/apache_beam/ml/transforms/tft.py
---
sdks/python/apache_beam/ml/__init__.py | 2 ++
sdks/python/apache_beam/ml/transforms/tft.py | 11 +++++++----
2 files changed, 9 insertions(+), 4 deletions(-)
diff --git a/sdks/python/apache_beam/ml/__init__.py b/sdks/python/apache_beam/ml/__init__.py
index f319fa046f7..6813137f519 100644
--- a/sdks/python/apache_beam/ml/__init__.py
+++ b/sdks/python/apache_beam/ml/__init__.py
@@ -16,3 +16,5 @@
#
"""Contains packages for supported machine learning transforms."""
+
+from apache_beam.ml.transforms.base import MLTransform
diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py
index 4f4c30e6270..e3260d604d4 100644
--- a/sdks/python/apache_beam/ml/transforms/tft.py
+++ b/sdks/python/apache_beam/ml/transforms/tft.py
@@ -448,7 +448,8 @@ class TFIDF(TFTOperation):
self.tfidf_weight = None
def apply_transform(
- self, data: tf.SparseTensor, output_column_name: str) -> tf.SparseTensor:
+ self, data: common_types.TensorType,
+ output_column_name: str) -> common_types.TensorType:
if self.vocab_size is None:
try:
@@ -507,7 +508,8 @@ class ScaleByMinMax(TFTOperation):
raise ValueError('max_value must be greater than min_value')
def apply_transform(
- self, data: tf.Tensor, output_column_name: str) -> tf.Tensor:
+ self, data: common_types.TensorType,
+ output_column_name: str) -> common_types.TensorType:
output = tft.scale_by_min_max(
x=data, output_min=self.min_value, output_max=self.max_value)
@@ -545,8 +547,9 @@ class NGrams(TFTOperation):
self.name = name
self.split_string_by_delimiter = split_string_by_delimiter
- def apply_transform(self, data: tf.SparseTensor,
- output_column_name: str) -> Dict[str, tf.SparseTensor]:
+ def apply_transform(
+ self, data: common_types.TensorType,
+ output_column_name: str) -> Dict[str, common_types.TensorType]:
if self.split_string_by_delimiter:
data = self._split_string_with_delimiter(
data, self.split_string_by_delimiter)