saavannanavati commented on a change in pull request #12352: URL: https://github.com/apache/beam/pull/12352#discussion_r471163295
########## File path: sdks/python/apache_beam/typehints/typecheck.py ########## @@ -265,3 +268,89 @@ def visit_transform(self, applied_transform): transform.get_type_hints(), applied_transform.full_label), applied_transform.full_label) + + +class PerformanceTypeCheckVisitor(pipeline.PipelineVisitor): + + _in_combine = False + combine_classes = ( + core.CombineFn, + core.CombinePerKey, + core.CombineValuesDoFn, + core.CombineValues, + core.CombineGlobally) + + def enter_composite_transform(self, applied_transform): + if isinstance(applied_transform.transform, self.combine_classes): + self._in_combine = True + + def leave_composite_transform(self, applied_transform): + if isinstance(applied_transform.transform, self.combine_classes): + self._in_combine = False + + def visit_transform(self, applied_transform): + transform = applied_transform.transform + if isinstance(transform, core.ParDo) and not self._in_combine: + # Prefix label with 'ParDo' if necessary + full_label = applied_transform.full_label + if not full_label.startswith('ParDo'): + full_label = 'ParDo(%s)' % full_label + + # Store output type hints in current transform + transform.fn._runtime_output_constraints = {} + output_type_hints = self.get_output_type_hints(transform) + if output_type_hints: + transform.fn._runtime_output_constraints[full_label] = ( + output_type_hints) + + # Store input type hints in producer transform + producer = applied_transform.inputs[0].producer + input_type_hints = self.get_input_type_hints(transform) + if input_type_hints: + producer.transform._add_type_constraint_from_consumer( + full_label, input_type_hints) + + def get_input_type_hints(self, transform): + type_hints = transform.get_type_hints() + + input_types = None + if type_hints.input_types: + normal_hints, kwarg_hints = type_hints.input_types + if kwarg_hints: + input_types = kwarg_hints + if normal_hints: + input_types = normal_hints + + parameter_name = 'Unknown Parameter' + try: + argspec = inspect.getfullargspec(transform.fn._process_argspec_fn()) + if len(argspec.args): + arg_index = 0 + if argspec.args[0] == 'self': + arg_index = 1 + parameter_name = argspec.args[arg_index] + if isinstance(input_types, dict): + input_types = (input_types[argspec.args[arg_index]], ) + except TypeError: Review comment: This was leftover from something else, I'll remove it. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org