AnandInguva commented on code in PR #27430:
URL: https://github.com/apache/beam/pull/27430#discussion_r1261709341


##########
sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py:
##########
@@ -61,49 +61,75 @@ def parse_args():
   return parser.parse_known_args()
 
 
-def run(args):
-  data = [
-      dict(x=["Let's", "go", "to", "the", "park"]),
-      dict(x=["I", "enjoy", "going", "to", "the", "park"]),
-      dict(x=["I", "enjoy", "reading", "books"]),
-      dict(x=["Beam", "can", "be", "fun"]),
-      dict(x=["The", "weather", "is", "really", "nice", "today"]),
-      dict(x=["I", "love", "to", "go", "to", "the", "park"]),
-      dict(x=["I", "love", "to", "read", "books"]),
-      dict(x=["I", "love", "to", "program"]),
-  ]
-
+def preprocess_data_for_ml_training(train_data, artifact_mode, args):
   with beam.Pipeline() as p:
-    input_data = p | beam.Create(data)
-
-    # arfifacts produce mode.
-    input_data |= (
-        'MLTransform' >> MLTransform(
+    input_data = (p | "CreateData" >> beam.Create(train_data))
+    transformed_data_pcoll = (
+        input_data
+        | 'MLTransform' >> MLTransform(
             artifact_location=args.artifact_location,
-            artifact_mode=ArtifactMode.PRODUCE,
+            artifact_mode=artifact_mode,
         ).with_transform(ComputeAndApplyVocabulary(
             columns=['x'])).with_transform(TFIDF(columns=['x'])))
 
-    # _ =  input_data | beam.Map(logging.info)
+    _ = transformed_data_pcoll | beam.Map(logging.info)
+  return transformed_data_pcoll
+
 
+def preprocess_data_for_ml_inference(test_data, artifact_mode, args):
   with beam.Pipeline() as p:
-    input_data = [
-        dict(x=['I', 'love', 'books']), dict(x=['I', 'love', 'Apache', 'Beam'])
-    ]
-    input_data = p | beam.Create(input_data)
-
-    # artifacts consume mode.
-    input_data |= (
-        MLTransform(
+
+    test_data_pcoll = (p | beam.Create(test_data))
+    # During inference phase, we want the pipeline to consume the artifacts
+    # produced by the previous run of MLTransform.
+    # So, we set artifact_mode to ArtifactMode.CONSUME.
+    transformed_data_pcoll = (
+        test_data_pcoll
+        | "MLTransformOnTestData" >> MLTransform(
             artifact_location=args.artifact_location,
-            artifact_mode=ArtifactMode.CONSUME,
+            artifact_mode=artifact_mode,
             # you don't need to specify transforms as they are already saved
             # in the artifacts.
         ))
+    _ = transformed_data_pcoll | beam.Map(logging.info)
+  return transformed_data_pcoll
 
-    _ = input_data | beam.Map(logging.info)
 
-  # To fetch the artifacts after the pipeline is run
+def run(args):
+  train_data = [
+      dict(x=["Let's", "go", "to", "the", "park"]),
+      dict(x=["I", "enjoy", "going", "to", "the", "park"]),
+      dict(x=["I", "enjoy", "reading", "books"]),
+      dict(x=["Beam", "can", "be", "fun"]),
+      dict(x=["The", "weather", "is", "really", "nice", "today"]),
+      dict(x=["I", "love", "to", "go", "to", "the", "park"]),
+      dict(x=["I", "love", "to", "read", "books"]),
+      dict(x=["I", "love", "to", "program"]),
+  ]
+
+  test_data = [
+      dict(x=['I', 'love', 'books']), dict(x=['I', 'love', 'Apache', 'Beam'])
+  ]
+
+  # Preprocess the data for ML training.

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to