yeandy commented on code in PR #17762:
URL: https://github.com/apache/beam/pull/17762#discussion_r882910668


##########
sdks/python/apache_beam/ml/inference/sklearn_inference.py:
##########
@@ -41,7 +41,9 @@ class ModelFileType(enum.Enum):
   JOBLIB = 2
 
 
-class SklearnInferenceRunner(InferenceRunner):
+class SklearnInferenceRunner(InferenceRunner[numpy.ndarray,
+                                             PredictionResult,
+                                             Any]):

Review Comment:
   Correct me if I'm wrong @ryanthompson591, but I think the base class for
sklearn models is `BaseEstimator` (`from sklearn.base import BaseEstimator`).
Could we use it as the model type instead of `Any`?
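A minimal sketch of the suggestion. It assumes a simplified, hypothetical stand-in for `InferenceRunner` that is generic over example, prediction, and model types; the real base class in `base.py` may differ. So that the sketch runs without NumPy or Beam, a `list` of floats stands in for `numpy.ndarray` and `float` stands in for `PredictionResult`. If scikit-learn is not installed, a local placeholder replaces `BaseEstimator`.

```python
from typing import Generic, Iterable, List, TypeVar, get_args

try:
    from sklearn.base import BaseEstimator
except ImportError:
    # Placeholder so the sketch still runs when scikit-learn is not installed.
    class BaseEstimator:  # type: ignore[no-redef]
        pass

ExampleT = TypeVar('ExampleT')
PredictionT = TypeVar('PredictionT')
ModelT = TypeVar('ModelT')


class InferenceRunner(Generic[ExampleT, PredictionT, ModelT]):
    """Hypothetical, simplified stand-in for the generic runner base class."""
    def run_inference(
        self, batch: List[ExampleT], model: ModelT) -> Iterable[PredictionT]:
        raise NotImplementedError


# The third type argument names the model type instead of Any, so a static
# checker can reject a model argument that is not a BaseEstimator.
class SklearnInferenceRunner(
        InferenceRunner[List[float], float, BaseEstimator]):
    def run_inference(
        self, batch: List[List[float]],
        model: BaseEstimator) -> Iterable[float]:
        # BaseEstimator does not declare predict(); concrete estimators do.
        return model.predict(batch)


# The declared model type is also visible at runtime.
model_type = get_args(SklearnInferenceRunner.__orig_bases__[0])[2]
```

One trade-off to check is that `BaseEstimator` provides `get_params`/`set_params` but not `predict`. A strict type checker may therefore still need a `Protocol` or a cast at the `predict` call.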



##########
sdks/python/apache_beam/transforms/ptransform.py:
##########
@@ -328,7 +331,7 @@ def visit_dict(self, pvalueish, sibling, pairs, context):
         self.visit(p, sibling, pairs, context)
 
 
-class PTransform(WithTypeHints, HasDisplayData):
+class PTransform(WithTypeHints, HasDisplayData, Generic[InputT, OutputT]):

Review Comment:
   Adding `Generic[InputT, OutputT]` breaks this test: 
`<apache_beam.transforms.periodicsequence_test.PeriodicSequenceTest 
testMethod=test_periodicimpulse_default_start>`. 
   
   ```
       def test_periodicimpulse_default_start(self):
         default_parameters = inspect.signature(PeriodicImpulse).parameters
   >     it = default_parameters["start_timestamp"].default
   E     KeyError: 'start_timestamp'
   ```
   
   Do you know why this happens?
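A likely cause, which is worth confirming against the interpreter the test ran on: up to and including Python 3.8, `typing.Generic` defined its own pure-Python `__new__(cls, *args, **kwds)`, and that method was removed in 3.9. On those versions, `inspect.signature(cls)` prefers a user-defined `__new__`, even an inherited one, over `__init__`. The reported parameters then become `args`/`kwds`, and the `"start_timestamp"` lookup raises `KeyError`. The sketch below emulates this with the hypothetical classes `BaseWithNew` and `Impulse` (standing in for `PeriodicImpulse`). It also shows one version-independent way to read the defaults: inspect `__init__` directly.

```python
import inspect


class BaseWithNew:
    # Emulates typing.Generic on Python <= 3.8, which defined
    # __new__(cls, *args, **kwds) in pure Python.
    def __new__(cls, *args, **kwds):
        return super().__new__(cls)


class Impulse(BaseWithNew):
    """Hypothetical stand-in for PeriodicImpulse."""
    def __init__(
        self, start_timestamp=0.0, stop_timestamp=10.0, fire_interval=1.0):
        self.start_timestamp = start_timestamp
        self.stop_timestamp = stop_timestamp
        self.fire_interval = fire_interval


# On interpreters that prefer the inherited __new__, this reports only
# (*args, **kwds), which matches the KeyError in the failing test.
class_params = inspect.signature(Impulse).parameters

# Inspecting __init__ directly is unaffected by any inherited __new__.
init_params = inspect.signature(Impulse.__init__).parameters
default_start = init_params['start_timestamp'].default
```

If this is the cause, one option is to have the test read `inspect.signature(PeriodicImpulse.__init__)`. The `PTransform` base classes would then stay unchanged.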



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -153,14 +160,12 @@ def update(
     self._inference_request_batch_byte_size.update(examples_byte_size)
 
 
-class _RunInferenceDoFn(beam.DoFn):
+class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
   """A DoFn implementation generic to frameworks."""
-  def __init__(self, model_loader: ModelLoader, clock=None):
+  def __init__(
+      self, model_loader: ModelLoader[ExampleT, PredictionT, Any], clock=None):

Review Comment:
   Why are we using `Any` instead of `ModelT` here? (same with line 99)
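A minimal sketch of threading `ModelT` through instead of `Any`. It uses simplified, hypothetical stand-ins: `ModelLoader` here folds loading and inference into one class, and `RunInferenceSketch` replaces `_RunInferenceDoFn`. The `beam.DoFn` base and the metrics are omitted so the sketch runs without Beam. Making the DoFn generic over `ModelT` as well lets the model type flow from the loader into the stored model attribute rather than being erased.

```python
from typing import Generic, Iterable, List, Optional, TypeVar

ExampleT = TypeVar('ExampleT')
PredictionT = TypeVar('PredictionT')
ModelT = TypeVar('ModelT')


class ModelLoader(Generic[ExampleT, PredictionT, ModelT]):
    """Hypothetical loader; folds loading and inference into one class."""
    def load_model(self) -> ModelT:
        raise NotImplementedError

    def run_inference(
        self, batch: List[ExampleT], model: ModelT) -> Iterable[PredictionT]:
        raise NotImplementedError


class RunInferenceSketch(Generic[ExampleT, PredictionT, ModelT]):
    """Stand-in for _RunInferenceDoFn; the beam.DoFn base is omitted."""
    def __init__(
        self, model_loader: ModelLoader[ExampleT, PredictionT, ModelT]):
        self._model_loader = model_loader
        # Typed as ModelT rather than Any, so the loaded model keeps its type.
        self._model: Optional[ModelT] = None

    def setup(self) -> None:
        self._model = self._model_loader.load_model()

    def process(self, batch: List[ExampleT]) -> Iterable[PredictionT]:
        if self._model is None:
            raise RuntimeError('setup() must be called before process()')
        return self._model_loader.run_inference(batch, self._model)


# Hypothetical usage: the "model" is just an int multiplier.
class DoublingLoader(ModelLoader[int, int, int]):
    def load_model(self) -> int:
        return 2

    def run_inference(self, batch: List[int], model: int) -> Iterable[int]:
        return [x * model for x in batch]


dofn = RunInferenceSketch(DoublingLoader())
dofn.setup()
results = list(dofn.process([1, 2, 3]))
```

With `Any` in the annotation, a checker accepts any loader and loses track of what `load_model()` returns. With `ModelT` bound at the class level, `self._model` keeps the concrete type that the loader declares.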



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
