This is an automated email from the ASF dual-hosted git repository.

anandinguva pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new aa6c8dc6fcd Add MLTransform to __init__ and change input types of transforms (#27687)
aa6c8dc6fcd is described below

commit aa6c8dc6fcdda9222ae00e442b5125b4ca005b4d
Author: Anand Inguva <[email protected]>
AuthorDate: Wed Jul 26 10:45:00 2023 -0400

    Add MLTransform to __init__ and change input types of transforms (#27687)
    
    * Add MLTransform to __init__
    
    * Update sdks/python/apache_beam/ml/transforms/tft.py
---
 sdks/python/apache_beam/ml/__init__.py       |  2 ++
 sdks/python/apache_beam/ml/transforms/tft.py | 11 +++++++----
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/sdks/python/apache_beam/ml/__init__.py b/sdks/python/apache_beam/ml/__init__.py
index f319fa046f7..6813137f519 100644
--- a/sdks/python/apache_beam/ml/__init__.py
+++ b/sdks/python/apache_beam/ml/__init__.py
@@ -16,3 +16,5 @@
 #
 
 """Contains packages for supported machine learning transforms."""
+
+from apache_beam.ml.transforms.base import MLTransform
diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py
index 4f4c30e6270..e3260d604d4 100644
--- a/sdks/python/apache_beam/ml/transforms/tft.py
+++ b/sdks/python/apache_beam/ml/transforms/tft.py
@@ -448,7 +448,8 @@ class TFIDF(TFTOperation):
     self.tfidf_weight = None
 
   def apply_transform(
-      self, data: tf.SparseTensor, output_column_name: str) -> tf.SparseTensor:
+      self, data: common_types.TensorType,
+      output_column_name: str) -> common_types.TensorType:
 
     if self.vocab_size is None:
       try:
@@ -507,7 +508,8 @@ class ScaleByMinMax(TFTOperation):
       raise ValueError('max_value must be greater than min_value')
 
   def apply_transform(
-      self, data: tf.Tensor, output_column_name: str) -> tf.Tensor:
+      self, data: common_types.TensorType,
+      output_column_name: str) -> common_types.TensorType:
 
     output = tft.scale_by_min_max(
         x=data, output_min=self.min_value, output_max=self.max_value)
@@ -545,8 +547,9 @@ class NGrams(TFTOperation):
     self.name = name
     self.split_string_by_delimiter = split_string_by_delimiter
 
-  def apply_transform(self, data: tf.SparseTensor,
-                      output_column_name: str) -> Dict[str, tf.SparseTensor]:
+  def apply_transform(
+      self, data: common_types.TensorType,
+      output_column_name: str) -> Dict[str, common_types.TensorType]:
     if self.split_string_by_delimiter:
       data = self._split_string_with_delimiter(
           data, self.split_string_by_delimiter)

Reply via email to