ryanthompson591 commented on a change in pull request #16917:
URL: https://github.com/apache/beam/pull/16917#discussion_r817907842
##########
File path: sdks/python/apache_beam/ml/inference/api.py
##########

@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from dataclasses import dataclass
+import apache_beam as beam
+from typing import Tuple, TypeVar, Union
+# TODO: implement RunInferenceImpl

Review comment:
   I would just remove these comments for now. We'll remember them in the next PR.

##########
File path: sdks/python/apache_beam/ml/inference/api.py
##########

@@ -0,0 +1,84 @@ (license header and imports as above)
+
+@dataclass
+class BaseModelSpec:
+  model_url: str
+
+
+@dataclass
+class PyTorchModelSpec(BaseModelSpec):
+  device: str
+
+  def __post_init__(self):
+    self.device = self.device.upper()
+
+
+@dataclass
+class SklearnModelSpec(BaseModelSpec):
+  serialization_type: str
+
+  def __post_init__(self):

Review comment:
   I think an enum would remove the need for this code entirely.
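   For illustration, a minimal sketch of the enum version (SerializationType and its members are made-up names, not from this PR):

       import enum
       from dataclasses import dataclass


       @dataclass
       class BaseModelSpec:
         model_url: str


       class SerializationType(enum.Enum):
         # Hypothetical members; the actual supported formats are TBD.
         PICKLE = 1
         JOBLIB = 2


       @dataclass
       class SklearnModelSpec(BaseModelSpec):
         # An enum member is already canonical, so no __post_init__
         # normalization (such as .upper()) is needed.
         serialization_type: SerializationType

   Callers would then construct something like SklearnModelSpec('gs://path/to/model', SerializationType.PICKLE) instead of passing a free-form string.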
##########
File path: sdks/python/apache_beam/ml/inference/api.py
##########

@@ -0,0 +1,84 @@ (license header, imports, and model specs as above)
+
+_K = TypeVar('_K')
+_INPUT_TYPE = TypeVar('_INPUT_TYPE')
+_OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE')
+
+
+@dataclass
+class PredictionResult:
+  key: _K

Review comment:
   I don't think key should be in PredictionResult. If my understanding of Beam is correct, keys will be in tuples (since other Beam transforms expect that pattern). So keyed output would look something like:

   [(key1, PredictionResult(example, inference)), (key2, PredictionResult(...))]
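   As a concrete sketch of that shape (names here are illustrative, not from this PR):

       from dataclasses import dataclass
       from typing import Generic, TypeVar

       _INPUT_TYPE = TypeVar('_INPUT_TYPE')
       _OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE')


       @dataclass
       class PredictionResult(Generic[_INPUT_TYPE, _OUTPUT_TYPE]):
         # No key field; a key travels next to the result in a tuple.
         example: _INPUT_TYPE
         inference: _OUTPUT_TYPE


       # Unkeyed element:
       #   PredictionResult(example, inference)
       # Keyed elements, the pattern other Beam transforms expect:
       #   ('key1', PredictionResult(example1, inference1))
       #   ('key2', PredictionResult(example2, inference2))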
+ """ + pass + # TODO: implement RunInferenceImpl Review comment: I would just leave these comments out for now, we know we'll be adding this. something like pass #TODO add implementation. ########## File path: sdks/python/apache_beam/ml/inference/api.py ########## @@ -0,0 +1,84 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from dataclasses import dataclass +import apache_beam as beam +from typing import Tuple, TypeVar, Union +# TODO: implement RunInferenceImpl +# from apache_beam.ml.inference.base import RunInferenceImpl + + +@dataclass +class BaseModelSpec: + model_url: str + + +@dataclass +class PyTorchModel(BaseModelSpec): + device: str + + def __post_init__(self): + self.device = self.device.upper() + + +@dataclass +class SklearnModel(BaseModelSpec): + serialization_type: str + + def __post_init__(self): + self.serialization_type = self.serialization_type.upper() + + +_K = TypeVar('_K') +_INPUT_TYPE = TypeVar('_INPUT_TYPE') +_OUTPUT_TYPE = TypeVar('_OUTPUT_TYPE') + + +@dataclass +class PredictionResult: + key: _K + example: _INPUT_TYPE + inference: _OUTPUT_TYPE + + [email protected]_fn [email protected]_input_types(Union[_INPUT_TYPE, Tuple[_K, _INPUT_TYPE]]) [email protected]_output_types(PredictionResult) Review comment: yeah, TFX is doing it right I think. They either have output type or a key with output type. I'm not the typing expert, IMO its ok to have a typing specific PR that just concentrates on that too, if this gets messy. We need typing to support TFX's proto use case as well as ours new frameworks use case. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]