damccorm commented on code in PR #24656:
URL: https://github.com/apache/beam/pull/24656#discussion_r1049758451


##########
sdks/python/apache_beam/examples/inference/multi_language/README.md:
##########
@@ -0,0 +1,49 @@
+<!--
+    Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+-->
+## Setting up the expansion service
+In order to start the python expansion service, run the following command:

Review Comment:
   ```suggestion
   In order to start the python expansion service locally, run the following command:
   ```
   
   @chamikaramj does this happen automatically for some runners (e.g., Dataflow)?



##########
sdks/python/apache_beam/examples/inference/multi_language/expansion_service/run_inference_expansion.py:
##########
@@ -0,0 +1,209 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# pytype: skip-file
+
+import argparse
+import logging
+import signal
+import sys
+import typing
+
+import grpc
+
+import apache_beam as beam
+from apache_beam.coders import RowCoder
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
+from apache_beam.pipeline import PipelineOptions
+from apache_beam.portability.api import beam_artifact_api_pb2_grpc
+from apache_beam.portability.api import beam_expansion_api_pb2_grpc
+from apache_beam.portability.api import external_transforms_pb2
+from apache_beam.runners.portability import artifact_service
+from apache_beam.runners.portability import expansion_service
+from apache_beam.transforms import fully_qualified_named_transform
+from apache_beam.transforms import ptransform
+from apache_beam.transforms.environments import PyPIArtifactRegistry
+from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
+from apache_beam.utils import thread_pool_executor
+from transformers import BertConfig
+from transformers import BertForMaskedLM
+from transformers import BertTokenizer
+
+# This script provides an expansion service for a run inference transform
+# with pre and post processing.
+# The model used is a BertLM, base uncased model.
+_LOGGER = logging.getLogger(__name__)
+
+# A transform that runs inference on a Bertmodel.
+TEST_RUN_BERT_URN = "beam:transforms:xlang:test:run_bert"
+
+
+@ptransform.PTransform.register_urn(TEST_RUN_BERT_URN, None)
+class RunInferenceTransform(ptransform.PTransform):
+  class PytorchModelHandlerKeyedTensorWrapper(PytorchModelHandlerKeyedTensor):
+    """Wrapper to PytorchModelHandler to limit batch size to 1.
+        The tokenized strings generated from BertTokenizer may have different
+        lengths, which doesn't work with torch.stack() in current RunInference
+        implementation since stack() requires tensors to be the same size.
+        Restricting max_batch_size to 1 means there is only 1 example per
+        `batch` in the run_inference() call.
+        """
+    def batch_elements_kwargs(self):
+      return {'max_batch_size': 1}

Review Comment:
   Does BERT normally handle batches of elements? Would it make sense to use a custom inference function instead of a full wrapper? A custom inference function allows you to supply a function to run instead of the default one (https://github.com/apache/beam/blob/f16d5d51b7551c12a854833676fc5a20bf8b6702/sdks/python/apache_beam/ml/inference/pytorch_inference.py#L145).
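
   For illustration, a minimal sketch of that alternative (hypothetical names; it assumes the `inference_fn` keyword argument the PyTorch model handlers accept for overriding the default function):

   ```python
   import torch

   from apache_beam.ml.inference.base import PredictionResult

   # Hypothetical custom inference function: run each tokenized example
   # individually, so variable-length tensors never need torch.stack().
   def run_bert_per_example(
       batch, model, device, inference_args=None, model_id=None):
     for example in batch:
       # Restore the batch dimension that torch.squeeze() removed upstream.
       inputs = {key: val.unsqueeze(0).to(device) for key, val in example.items()}
       with torch.no_grad():
         prediction = model(**inputs)
       yield PredictionResult(example, prediction)

   # The handler could then be built without the wrapper subclass, e.g.:
   #   PytorchModelHandlerKeyedTensor(..., inference_fn=run_bert_per_example)
   ```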



##########
sdks/python/apache_beam/examples/inference/multi_language/last_word_prediction/src/main/java/org/MultiLangRunInference.java:
##########
@@ -0,0 +1,93 @@
+package org;
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms;
+import org.apache.beam.runners.core.construction.External;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.RowCoder;
+import org.apache.beam.sdk.io.TextIO;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.options.Validation.Required;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.schemas.SchemaTranslation;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.util.ByteStringOutputStream;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.PDone;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class MultiLangRunInference {
+    public interface MultiLangueageOptions extends PipelineOptions {
+
+        @Description("Path to an input file that contains labels and pixels to 
feed into the model")
+        @Required
+        String getInputFile();
+
+        void setInputFile(String value);
+
+        @Description("Path to an input file that contains labels and pixels to 
feed into the model")
+        @Required
+        String getOutputFile();
+
+        void setOutputFile(String value);
+    }
+
+    private static byte[] toStringPayloadBytes(String model) {
+        Row configRow = Row.withSchema(Schema.of(Field.of("model", FieldType.STRING)))
+                .withFieldValue("model", model)
+                .build();
+
+        ByteStringOutputStream outputStream = new ByteStringOutputStream();
+
+        try {
+            RowCoder.of(configRow.getSchema()).encode(configRow, outputStream);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+        ExternalTransforms.ExternalConfigurationPayload payload = ExternalTransforms.ExternalConfigurationPayload
+                .newBuilder()
+                .setSchema(SchemaTranslation.schemaToProto(configRow.getSchema(), false))
+                .setPayload(outputStream.toByteString())
+                .build();
+        return payload.toByteArray();
+    }
+
+    public static void main(String[] args) {
+
+        final String TEST_RUN_BERT_URN = "beam:transforms:xlang:test:run_bert";
        MultiLangueageOptions options = PipelineOptionsFactory.fromArgs(args).withValidation()
+                .as(MultiLangueageOptions.class);
+
+        Pipeline pipeline = Pipeline.create(options);
+        PCollection<String> predictions = pipeline.apply("Read Input", 
TextIO.read().from(options.getInputFile()))
+            .apply("Run Inference" ,External.of(TEST_RUN_BERT_URN, 
toStringPayloadBytes("bert-base-uncased"), "localhost:12345"));

Review Comment:
   What is `localhost:12345`? Is this the expansion service port? If so, should it be configurable/passed in as an arg?



##########
sdks/python/apache_beam/examples/inference/multi_language/last_word_prediction/src/main/java/org/MultiLangRunInference.java:
##########
@@ -0,0 +1,93 @@
+package org;
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms;
+import org.apache.beam.runners.core.construction.External;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.RowCoder;
+import org.apache.beam.sdk.io.TextIO;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.options.Validation.Required;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.schemas.SchemaTranslation;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.util.ByteStringOutputStream;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.PDone;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class MultiLangRunInference {
+    public interface MultiLangueageOptions extends PipelineOptions {
+
+        @Description("Path to an input file that contains labels and pixels to 
feed into the model")
+        @Required
+        String getInputFile();
+
+        void setInputFile(String value);
+
+        @Description("Path to an input file that contains labels and pixels to 
feed into the model")
+        @Required
+        String getOutputFile();
+
+        void setOutputFile(String value);
+    }
+
+    private static byte[] toStringPayloadBytes(String model) {
+        Row configRow = Row.withSchema(Schema.of(Field.of("model", FieldType.STRING)))
+                .withFieldValue("model", model)
+                .build();
+
+        ByteStringOutputStream outputStream = new ByteStringOutputStream();
+
+        try {
+            RowCoder.of(configRow.getSchema()).encode(configRow, outputStream);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+        ExternalTransforms.ExternalConfigurationPayload payload = ExternalTransforms.ExternalConfigurationPayload
+                .newBuilder()
+                .setSchema(SchemaTranslation.schemaToProto(configRow.getSchema(), false))
+                .setPayload(outputStream.toByteString())
+                .build();
+        return payload.toByteArray();
+    }

Review Comment:
   @chamikaramj there's an easier way to pass strings, right? Shouldn't this encoding happen automatically?
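
   For comparison, the Python SDK's payload builder produces the same single-field payload in one line; a sketch (the model name is this example's value):

   ```python
   from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder

   # Builds the schema'd {'model': <str>} payload that the Java
   # Row/RowCoder code above assembles by hand.
   payload_bytes = ImplicitSchemaPayloadBuilder({'model': 'bert-base-uncased'}).payload()
   ```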



##########
sdks/python/apache_beam/examples/inference/multi_language/expansion_service/run_inference_expansion.py:
##########
@@ -0,0 +1,209 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# pytype: skip-file
+
+import argparse
+import logging
+import signal
+import sys
+import typing
+
+import grpc
+
+import apache_beam as beam
+from apache_beam.coders import RowCoder
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
+from apache_beam.pipeline import PipelineOptions
+from apache_beam.portability.api import beam_artifact_api_pb2_grpc
+from apache_beam.portability.api import beam_expansion_api_pb2_grpc
+from apache_beam.portability.api import external_transforms_pb2
+from apache_beam.runners.portability import artifact_service
+from apache_beam.runners.portability import expansion_service
+from apache_beam.transforms import fully_qualified_named_transform
+from apache_beam.transforms import ptransform
+from apache_beam.transforms.environments import PyPIArtifactRegistry
+from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
+from apache_beam.utils import thread_pool_executor
+from transformers import BertConfig
+from transformers import BertForMaskedLM
+from transformers import BertTokenizer
+
+# This script provides an expansion service for a run inference transform
+# with pre and post processing.
+# The model used is a BertLM, base uncased model.
+_LOGGER = logging.getLogger(__name__)
+
+# A transform that runs inference on a Bertmodel.

Review Comment:
   ```suggestion
   # This URN will be used to register a transform that runs inference on a BERT model.
   ```



##########
sdks/python/apache_beam/examples/inference/multi_language/README.md:
##########
@@ -0,0 +1,49 @@
+<!--
+    Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+-->
+## Setting up the expansion service
+In order to start the python expansion service, run the following command:
+
+```
+python -m expansion_service.run_inference_expansion \
+    --port=<port to host expansion service> \

Review Comment:
   @chamikaramj how should we choose the port here? Does it matter? A small comment explaining this would be helpful.
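
   For instance (an illustrative snippet, not part of the PR): any free local port should work, since the service only needs to be reachable while the pipeline is being constructed. One way to let the OS pick one:

   ```python
   import socket

   # Bind to port 0 so the OS assigns any currently free port.
   with socket.socket() as s:
     s.bind(('localhost', 0))
     print(s.getsockname()[1])  # pass this as --port to the expansion service
   ```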



##########
sdks/python/apache_beam/examples/inference/multi_language/expansion_service/run_inference_expansion.py:
##########
@@ -0,0 +1,209 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# pytype: skip-file
+
+import argparse
+import logging
+import signal
+import sys
+import typing
+
+import grpc
+
+import apache_beam as beam
+from apache_beam.coders import RowCoder
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
+from apache_beam.pipeline import PipelineOptions
+from apache_beam.portability.api import beam_artifact_api_pb2_grpc
+from apache_beam.portability.api import beam_expansion_api_pb2_grpc
+from apache_beam.portability.api import external_transforms_pb2
+from apache_beam.runners.portability import artifact_service
+from apache_beam.runners.portability import expansion_service
+from apache_beam.transforms import fully_qualified_named_transform
+from apache_beam.transforms import ptransform
+from apache_beam.transforms.environments import PyPIArtifactRegistry
+from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
+from apache_beam.utils import thread_pool_executor
+from transformers import BertConfig
+from transformers import BertForMaskedLM
+from transformers import BertTokenizer
+
+# This script provides an expansion service for a run inference transform
+# with pre and post processing.
+# The model used is a BertLM, base uncased model.
+_LOGGER = logging.getLogger(__name__)
+
+# A transform that runs inference on a Bertmodel.
+TEST_RUN_BERT_URN = "beam:transforms:xlang:test:run_bert"
+
+
+@ptransform.PTransform.register_urn(TEST_RUN_BERT_URN, None)
+class RunInferenceTransform(ptransform.PTransform):
+  class PytorchModelHandlerKeyedTensorWrapper(PytorchModelHandlerKeyedTensor):
+    """Wrapper to PytorchModelHandler to limit batch size to 1.
+        The tokenized strings generated from BertTokenizer may have different
+        lengths, which doesn't work with torch.stack() in current RunInference
+        implementation since stack() requires tensors to be the same size.
+        Restricting max_batch_size to 1 means there is only 1 example per
+        `batch` in the run_inference() call.
+        """
+    def batch_elements_kwargs(self):
+      return {'max_batch_size': 1}
+
+  class Preprocess(beam.DoFn):
+    def __init__(self, tokenizer):
+      # self._model_name = model_name
+      logging.info('Starting Preprocess')
+      # self._tokenizer = BertTokenizer.from_pretrained(self._model_name)

Review Comment:
   ```suggestion
         logging.info('Starting Preprocess')
   ```
   
   Nit



##########
sdks/python/apache_beam/examples/inference/multi_language/README.md:
##########
@@ -0,0 +1,49 @@
+<!--
+    Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+-->
+## Setting up the expansion service
+In order to start the python expansion service, run the following command:
+
+```
+python -m expansion_service.run_inference_expansion \
+    --port=<port to host expansion service> \
+    --env_config=<python container>
+```
+If you use a custom python container, make sure it is publicly accessible for the workers to pull.

Review Comment:
   @chamikaramj is there a way to do authenticated pulls of containers?



##########
sdks/python/apache_beam/examples/inference/multi_language/expansion_service/run_inference_expansion.py:
##########
@@ -0,0 +1,209 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# pytype: skip-file
+
+import argparse
+import logging
+import signal
+import sys
+import typing
+
+import grpc
+
+import apache_beam as beam
+from apache_beam.coders import RowCoder
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
+from apache_beam.pipeline import PipelineOptions
+from apache_beam.portability.api import beam_artifact_api_pb2_grpc
+from apache_beam.portability.api import beam_expansion_api_pb2_grpc
+from apache_beam.portability.api import external_transforms_pb2
+from apache_beam.runners.portability import artifact_service
+from apache_beam.runners.portability import expansion_service
+from apache_beam.transforms import fully_qualified_named_transform
+from apache_beam.transforms import ptransform
+from apache_beam.transforms.environments import PyPIArtifactRegistry
+from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
+from apache_beam.utils import thread_pool_executor
+from transformers import BertConfig
+from transformers import BertForMaskedLM
+from transformers import BertTokenizer
+
+# This script provides an expansion service for a run inference transform
+# with pre and post processing.
+# The model used is a BertLM, base uncased model.
+_LOGGER = logging.getLogger(__name__)
+
+# A transform that runs inference on a Bertmodel.
+TEST_RUN_BERT_URN = "beam:transforms:xlang:test:run_bert"
+
+
+@ptransform.PTransform.register_urn(TEST_RUN_BERT_URN, None)
+class RunInferenceTransform(ptransform.PTransform):
+  class PytorchModelHandlerKeyedTensorWrapper(PytorchModelHandlerKeyedTensor):
+    """Wrapper to PytorchModelHandler to limit batch size to 1.
+        The tokenized strings generated from BertTokenizer may have different
+        lengths, which doesn't work with torch.stack() in current RunInference
+        implementation since stack() requires tensors to be the same size.
+        Restricting max_batch_size to 1 means there is only 1 example per
+        `batch` in the run_inference() call.
+        """
+    def batch_elements_kwargs(self):
+      return {'max_batch_size': 1}
+
+  class Preprocess(beam.DoFn):
+    def __init__(self, tokenizer):
+      # self._model_name = model_name
+      logging.info('Starting Preprocess')
+      # self._tokenizer = BertTokenizer.from_pretrained(self._model_name)
+      self._tokenizer = tokenizer
+      logging.info('Tokenizer loaded')
+
+    def process(self, text: str):
+      import torch
+      if len(text.strip()) > 0:
+        logging.info('Preprocessing Line: %s', text)
+        text_list = text.split()
+        masked_text = ' '.join(text_list[:-2] + ['[MASK]', text_list[-1]])
+        tokens = self._tokenizer(masked_text, return_tensors='pt')
+        tokens = {key: torch.squeeze(val) for key, val in tokens.items()}
+        return [(text, tokens)]
+
+  class Postprocess(beam.DoFn):
+    def __init__(self, bert_tokenizer):
+      logging.info('Starting Postprocess')
+      self.bert_tokenizer = bert_tokenizer
+
+    def process(self, element: typing.Tuple[str, PredictionResult]) \
+        -> typing.Iterable[str]:
+      text, prediction_result = element
+      inputs = prediction_result.example
+      logits = prediction_result.inference['logits']
+      mask_token_index = (
+          inputs['input_ids'] == self.bert_tokenizer.mask_token_id).nonzero(
+              as_tuple=True)[0]
+      predicted_token_id = logits[mask_token_index].argmax(axis=-1)
+      decoded_word = self.bert_tokenizer.decode(predicted_token_id)
+      text = text.replace('.', '')
+      yield text + '\n Predicted word: ' + decoded_word.upper()
+
+  def __init__(self, model):
+    self._model = model
+    # can also save the model config and tokenizer in gcs and load in
+    self._model_config = BertConfig.from_pretrained(self._model)
+    self._tokenizer = BertTokenizer.from_pretrained(self._model)
+    self._model_handler = self.PytorchModelHandlerKeyedTensorWrapper(
+        state_dict_path=(
+            "gs://apache-beam-x-lang-testing/input/"
+            "bert-model/bert-base-uncased.pth"),

Review Comment:
   This should probably be configurable/a parameter passed in from Java.
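
   A hypothetical sketch of that change, reusing the PR's own payload mechanism (the extra `state_dict_path` parameter and its wiring are illustrative, modifying the `__init__` shown above):

   ```python
   # The path would arrive in the payload Row alongside 'model', so the
   # Java pipeline controls it instead of a hard-coded GCS location.
   def __init__(self, model, state_dict_path):
     self._model = model
     self._model_config = BertConfig.from_pretrained(self._model)
     self._tokenizer = BertTokenizer.from_pretrained(self._model)
     self._model_handler = self.PytorchModelHandlerKeyedTensorWrapper(
         state_dict_path=state_dict_path,
         model_class=BertForMaskedLM,
         model_params={'config': self._model_config},
         device='cuda:0')

   # ...and from_runner_api_parameter would read both fields:
   #   config = parse_string_payload(payload)
   #   return RunInferenceTransform(config['model'], config['state_dict_path'])
   ```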



##########
sdks/python/apache_beam/examples/inference/multi_language/expansion_service/run_inference_expansion.py:
##########
@@ -0,0 +1,209 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# pytype: skip-file
+
+import argparse
+import logging
+import signal
+import sys
+import typing
+
+import grpc
+
+import apache_beam as beam
+from apache_beam.coders import RowCoder
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
+from apache_beam.pipeline import PipelineOptions
+from apache_beam.portability.api import beam_artifact_api_pb2_grpc
+from apache_beam.portability.api import beam_expansion_api_pb2_grpc
+from apache_beam.portability.api import external_transforms_pb2
+from apache_beam.runners.portability import artifact_service
+from apache_beam.runners.portability import expansion_service
+from apache_beam.transforms import fully_qualified_named_transform
+from apache_beam.transforms import ptransform
+from apache_beam.transforms.environments import PyPIArtifactRegistry
+from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
+from apache_beam.utils import thread_pool_executor
+from transformers import BertConfig
+from transformers import BertForMaskedLM
+from transformers import BertTokenizer
+
+# This script provides an expansion service for a run inference transform
+# with pre and post processing.
+# The model used is a BertLM, base uncased model.
+_LOGGER = logging.getLogger(__name__)
+
+# A transform that runs inference on a Bertmodel.
+TEST_RUN_BERT_URN = "beam:transforms:xlang:test:run_bert"
+
+
+@ptransform.PTransform.register_urn(TEST_RUN_BERT_URN, None)
+class RunInferenceTransform(ptransform.PTransform):
+  class PytorchModelHandlerKeyedTensorWrapper(PytorchModelHandlerKeyedTensor):
+    """Wrapper to PytorchModelHandler to limit batch size to 1.
+        The tokenized strings generated from BertTokenizer may have different
+        lengths, which doesn't work with torch.stack() in current RunInference
+        implementation since stack() requires tensors to be the same size.
+        Restricting max_batch_size to 1 means there is only 1 example per
+        `batch` in the run_inference() call.
+        """
+    def batch_elements_kwargs(self):
+      return {'max_batch_size': 1}
+
+  class Preprocess(beam.DoFn):
+    def __init__(self, tokenizer):
+      # self._model_name = model_name
+      logging.info('Starting Preprocess')
+      # self._tokenizer = BertTokenizer.from_pretrained(self._model_name)
+      self._tokenizer = tokenizer
+      logging.info('Tokenizer loaded')
+
+    def process(self, text: str):
+      import torch
+      if len(text.strip()) > 0:
+        logging.info('Preprocessing Line: %s', text)
+        text_list = text.split()
+        masked_text = ' '.join(text_list[:-2] + ['[MASK]', text_list[-1]])
+        tokens = self._tokenizer(masked_text, return_tensors='pt')
+        tokens = {key: torch.squeeze(val) for key, val in tokens.items()}
+        return [(text, tokens)]
+
+  class Postprocess(beam.DoFn):
+    def __init__(self, bert_tokenizer):
+      logging.info('Starting Postprocess')
+      self.bert_tokenizer = bert_tokenizer
+
+    def process(self, element: typing.Tuple[str, PredictionResult]) \
+        -> typing.Iterable[str]:
+      text, prediction_result = element
+      inputs = prediction_result.example
+      logits = prediction_result.inference['logits']
+      mask_token_index = (
+          inputs['input_ids'] == self.bert_tokenizer.mask_token_id).nonzero(
+              as_tuple=True)[0]
+      predicted_token_id = logits[mask_token_index].argmax(axis=-1)
+      decoded_word = self.bert_tokenizer.decode(predicted_token_id)
+      text = text.replace('.', '')
+      yield text + '\n Predicted word: ' + decoded_word.upper()
+
+  def __init__(self, model):
+    self._model = model
+    # can also save the model config and tokenizer in gcs and load in
+    self._model_config = BertConfig.from_pretrained(self._model)
+    self._tokenizer = BertTokenizer.from_pretrained(self._model)
+    self._model_handler = self.PytorchModelHandlerKeyedTensorWrapper(
+        state_dict_path=(
+            "gs://apache-beam-x-lang-testing/input/"
+            "bert-model/bert-base-uncased.pth"),
+        model_class=BertForMaskedLM,
+        model_params={'config': self._model_config},
+        device='cuda:0')
+
+  def expand(self, pcoll):
+    return (
+        pcoll
+        | 'Preprocess' >> beam.ParDo(self.Preprocess(self._tokenizer))
+        | 'Inference' >> RunInference(KeyedModelHandler(self._model_handler))
+        | 'Postprocess' >> beam.ParDo(self.Postprocess(
+            self._tokenizer)).with_input_types(typing.Iterable[str]))
+
+  def to_runner_api_parameter(self, unused_context):
+    return TEST_RUN_BERT_URN, ImplicitSchemaPayloadBuilder(
+        {'model': self._model}).payload()
+
+  @staticmethod
+  def from_runner_api_parameter(unused_ptransform, payload, unused_context):
+    return RunInferenceTransform(parse_string_payload(payload)['model'])
+
+
+@ptransform.PTransform.register_urn('payload', bytes)

Review Comment:
   What is this transform doing? Could you add some explanatory comments?
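
   For context, a decoding helper along these lines (a sketch modeled on Beam's expansion service tests, not necessarily this PR's exact code) is what unpacks the Row sent from Java:

   ```python
   from apache_beam.coders import RowCoder
   from apache_beam.portability.api import external_transforms_pb2

   # Decode an ExternalConfigurationPayload (schema proto + encoded row)
   # like the one built by the Java toStringPayloadBytes() helper into a
   # plain dict, e.g. {'model': 'bert-base-uncased'}.
   def parse_string_payload(input_bytes):
     payload = external_transforms_pb2.ExternalConfigurationPayload()
     payload.ParseFromString(input_bytes)
     return RowCoder(payload.schema).decode(payload.payload)._asdict()
   ```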



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
