ryanthompson591 commented on code in PR #23456:
URL: https://github.com/apache/beam/pull/23456#discussion_r1004721652


##########
sdks/python/apache_beam/examples/inference/tfx_bsl/build_tensorflow_model.py:
##########
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Intended only for internal testing.
+
+from typing import Dict
+from typing import Optional
+
+import tensorflow as tf
+
+
+class TFModelWrapperWithSignature(tf.keras.Model):
+  def __init__(
+      self,
+      model,
+      preprocess_input=None,
+      input_dtype=None,
+      feature_description=None,
+      **kwargs):

Review Comment:
   If we can be specific about what kwargs should be, that would be ideal. If not,
it's fine.



##########
sdks/python/apache_beam/examples/inference/tfx_bsl/build_tensorflow_model.py:
##########
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Intended only for internal testing.
+
+from typing import Dict
+from typing import Optional
+
+import tensorflow as tf
+
+
+class TFModelWrapperWithSignature(tf.keras.Model):
+  def __init__(
+      self,
+      model,
+      preprocess_input=None,
+      input_dtype=None,
+      feature_description=None,
+      **kwargs):
+    super().__init__()
+    self.model = model
+    self.preprocess_input = preprocess_input
+    self.input_dtype = input_dtype
+    self.feature_description = feature_description
+    if not feature_description:
+      self.feature_description = {'image': tf.io.FixedLenFeature((), 
tf.string)}
+    self._kwargs = kwargs
+
+  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
+  def call(self, serialized_examples):
+    features = tf.io.parse_example(
+        serialized_examples, features=self.feature_description)
+
+    # Initialize a TensorArray to store the deserialized values.
+    # For more details, please look at
+    # 
https://github.com/tensorflow/tensorflow/issues/39323#issuecomment-627586602
+    batch = len(features['image'])

Review Comment:
   Rename to `batch_len`?



##########
sdks/python/apache_beam/examples/inference/tfx_bsl/build_tensorflow_model.py:
##########
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Intended only for internal testing.
+
+from typing import Dict
+from typing import Optional
+
+import tensorflow as tf
+
+
+class TFModelWrapperWithSignature(tf.keras.Model):
+  def __init__(
+      self,
+      model,
+      preprocess_input=None,
+      input_dtype=None,
+      feature_description=None,
+      **kwargs):
+    super().__init__()
+    self.model = model
+    self.preprocess_input = preprocess_input
+    self.input_dtype = input_dtype
+    self.feature_description = feature_description
+    if not feature_description:
+      self.feature_description = {'image': tf.io.FixedLenFeature((), 
tf.string)}
+    self._kwargs = kwargs
+
+  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
+  def call(self, serialized_examples):
+    features = tf.io.parse_example(
+        serialized_examples, features=self.feature_description)
+
+    # Initialize a TensorArray to store the deserialized values.
+    # For more details, please look at
+    # 
https://github.com/tensorflow/tensorflow/issues/39323#issuecomment-627586602
+    batch = len(features['image'])
+    deserialized_vectors = tf.TensorArray(
+        self.input_dtype, size=batch, dynamic_size=True)
+    # Vectorized version of tf.io.parse_tensor is not available.
+    # Use for loop to vectorize the tensor. For more details, refer
+    # https://github.com/tensorflow/tensorflow/issues/43706
+    for i in range(batch):
+      deserialized_value = tf.io.parse_tensor(
+          features['image'][i], out_type=self.input_dtype)
+      # 
http://github.com/tensorflow/tensorflow/issues/30409#issuecomment-508962873
+      # In Graph mode, return value must get assigned in order to
+      # update the array
+      deserialized_vectors = deserialized_vectors.write(i, deserialized_value)
+
+    deserialized_tensor = deserialized_vectors.stack()
+    if self.preprocess_input:
+      deserialized_tensor = self.preprocess_input(deserialized_tensor)
+    return self.model(deserialized_tensor, **self._kwargs)
+
+
+def save_tf_model_with_signature(
+    path_to_save_model,
+    model=None,
+    preprocess_input=None,
+    input_dtype=tf.float32,
+    feature_description: Optional[Dict] = None,
+    **kwargs,
+):
+  """
+  Helper function used to save the Tensorflow Model with a serving signature.
+  This is intended only for internal testing.
+  Args:
+   path_to_save_model: Path to save the model with modified signature.
+   model: Base tensorflow model used for TFX-BSL RunInference transform.
+   preprocess_input: Preprocess method to be included as part of the
+   Model's serving signature.
+   input_dtype: dtype of the inputs to the model.

Review Comment:
   Can you call out whether this is a list, or otherwise what the expected input is?



##########
sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py:
##########
@@ -0,0 +1,192 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A pipeline that uses TFX RunInference API to perform image classification.
+Please look at https://github.com/tensorflow/tfx-bsl/tree/master/tfx_bsl/beam.

Review Comment:
   How about linking here for documentation?
   
https://www.tensorflow.org/tfx/tfx_bsl/api_docs/python/tfx_bsl/public/beam/run_inference/CreateModelHandler



##########
sdks/python/apache_beam/examples/inference/tfx_bsl/build_tensorflow_model.py:
##########
@@ -0,0 +1,71 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tensorflow as tf
+
+
+class TFModelWrapperWithSignature(tf.keras.Model):

Review Comment:
   Not done. Adding comments shouldn't affect the postcommit.



##########
sdks/python/apache_beam/examples/inference/tfx_bsl/tfx_bsl_inference_it_test.py:
##########
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+import uuid
+
+import pytest
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.testing.test_pipeline import TestPipeline
+
+# pylint: disable=ungrouped-imports
+try:
+  import tfx_bsl
+  import tensorflow as tf
+  from apache_beam.examples.inference.tfx_bsl import 
tensorflow_image_classification
+  from apache_beam.examples.inference.tfx_bsl.build_tensorflow_model import 
save_tf_model_with_signature
+except ImportError as e:
+  tfx_bsl = None
+# pylint: disable=line-too-long
+_EXPECTED_OUTPUTS = {
+    
'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005001.JPEG':
 '681',
+    
'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005002.JPEG':
 '333',
+    
'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005003.JPEG':
 '711',
+    
'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005004.JPEG':
 '286',
+    
'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005005.JPEG':
 '445',
+    
'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005006.JPEG':
 '288',
+    
'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005007.JPEG':
 '880',
+    
'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005008.JPEG':
 '534',
+    
'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005009.JPEG':
 '888',
+    
'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005010.JPEG':
 '996',
+    
'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005011.JPEG':
 '327',
+    
'gs://apache-beam-ml/datasets/imagenet/raw-data/validation/ILSVRC2012_val_00005012.JPEG':
 '573'
+}
+
+
+def process_outputs(filepath):
+  with FileSystems().open(filepath) as f:
+    lines = f.readlines()
+  lines = [l.decode('utf-8').strip('\n') for l in lines]
+  return lines
+
+
+@unittest.skipIf(
+    tfx_bsl is None, 'Missing dependencies. '
+    'Test depends on tfx_bsl')
+class TFXRunInferenceTests(unittest.TestCase):
+  @pytest.mark.uses_tensorflow
+  @pytest.mark.it_postcommit
+  def test_tfx_run_inference_mobilenetv2(self):

Review Comment:
   Why are we using MobileNetV2? Wouldn't a GPU model be better?



##########
sdks/python/apache_beam/examples/inference/tfx_bsl/build_tensorflow_model.py:
##########
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Intended only for internal testing.
+
+from typing import Dict
+from typing import Optional
+
+import tensorflow as tf
+
+
+class TFModelWrapperWithSignature(tf.keras.Model):
+  def __init__(
+      self,
+      model,
+      preprocess_input=None,
+      input_dtype=None,
+      feature_description=None,
+      **kwargs):
+    super().__init__()
+    self.model = model
+    self.preprocess_input = preprocess_input
+    self.input_dtype = input_dtype
+    self.feature_description = feature_description
+    if not feature_description:
+      self.feature_description = {'image': tf.io.FixedLenFeature((), 
tf.string)}
+    self._kwargs = kwargs
+
+  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
+  def call(self, serialized_examples):
+    features = tf.io.parse_example(
+        serialized_examples, features=self.feature_description)
+
+    # Initialize a TensorArray to store the deserialized values.
+    # For more details, please look at
+    # 
https://github.com/tensorflow/tensorflow/issues/39323#issuecomment-627586602
+    batch = len(features['image'])
+    deserialized_vectors = tf.TensorArray(
+        self.input_dtype, size=batch, dynamic_size=True)
+    # Vectorized version of tf.io.parse_tensor is not available.
+    # Use for loop to vectorize the tensor. For more details, refer
+    # https://github.com/tensorflow/tensorflow/issues/43706
+    for i in range(batch):
+      deserialized_value = tf.io.parse_tensor(
+          features['image'][i], out_type=self.input_dtype)
+      # 
http://github.com/tensorflow/tensorflow/issues/30409#issuecomment-508962873

Review Comment:
   I like the comment format:
   
   ```
   # Some description of why this exists. 
   # See http://mylink...
   ```
   
   Here you have it flipped.



##########
sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py:
##########
@@ -0,0 +1,192 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A pipeline that uses TFX RunInference API to perform image classification.

Review Comment:
   A sample pipeline illustrating how to use Apache Beam's RunInference API
with tfx-bsl's CreateModelHandler API.
   
    ---- some more suggested text you might add---
   The images for this example are taken from <insert something>.



##########
sdks/python/apache_beam/examples/inference/tfx_bsl/build_tensorflow_model.py:
##########
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Intended only for internal testing.
+
+from typing import Dict
+from typing import Optional
+
+import tensorflow as tf
+
+
+class TFModelWrapperWithSignature(tf.keras.Model):
+  def __init__(
+      self,
+      model,
+      preprocess_input=None,
+      input_dtype=None,
+      feature_description=None,
+      **kwargs):
+    super().__init__()
+    self.model = model
+    self.preprocess_input = preprocess_input
+    self.input_dtype = input_dtype
+    self.feature_description = feature_description
+    if not feature_description:
+      self.feature_description = {'image': tf.io.FixedLenFeature((), 
tf.string)}
+    self._kwargs = kwargs
+
+  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
+  def call(self, serialized_examples):
+    features = tf.io.parse_example(
+        serialized_examples, features=self.feature_description)
+
+    # Initialize a TensorArray to store the deserialized values.
+    # For more details, please look at
+    # 
https://github.com/tensorflow/tensorflow/issues/39323#issuecomment-627586602
+    batch = len(features['image'])
+    deserialized_vectors = tf.TensorArray(
+        self.input_dtype, size=batch, dynamic_size=True)
+    # Vectorized version of tf.io.parse_tensor is not available.
+    # Use for loop to vectorize the tensor. For more details, refer
+    # https://github.com/tensorflow/tensorflow/issues/43706
+    for i in range(batch):
+      deserialized_value = tf.io.parse_tensor(
+          features['image'][i], out_type=self.input_dtype)
+      # 
http://github.com/tensorflow/tensorflow/issues/30409#issuecomment-508962873
+      # In Graph mode, return value must get assigned in order to
+      # update the array
+      deserialized_vectors = deserialized_vectors.write(i, deserialized_value)
+
+    deserialized_tensor = deserialized_vectors.stack()
+    if self.preprocess_input:
+      deserialized_tensor = self.preprocess_input(deserialized_tensor)
+    return self.model(deserialized_tensor, **self._kwargs)

Review Comment:
   What are you expecting in kwargs? Is it possible to make it known through a
comment or explicitly?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to