gemini-code-assist[bot] commented on code in PR #38215: URL: https://github.com/apache/beam/pull/38215#discussion_r3292856825
########## sdks/python/apache_beam/examples/ml_transform/mltransform_generate_vocab.py: ########## @@ -0,0 +1,247 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Batch-only vocabulary generation pipeline using MLTransform. + +This pipeline creates a vocabulary artifact from one or more input columns. + +Key properties: +- Batch only (no streaming path). +- Vocabulary generation via MLTransform ComputeAndApplyVocabulary. +- Output format: one token per line. +""" + +import argparse +import json +import logging +import tempfile +from typing import Any + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.transforms.base import MLTransform +from apache_beam.ml.transforms.tft import ComputeAndApplyVocabulary +from apache_beam.options.pipeline_options import PipelineOptions + + +def parse_bool_flag(value: str) -> bool: + value_lc = value.strip().lower() + if value_lc in ('1', 'true', 't', 'yes', 'y'): + return True + if value_lc in ('0', 'false', 'f', 'no', 'n'): + return False + raise ValueError( + f'Invalid boolean value {value!r}. 
Expected true/false style value.') + + +def normalize_text(value: Any, lowercase: bool = True) -> str: + if value is None: + return '' + text = str(value).strip() + if lowercase: + text = text.lower() + return text + + +def _parse_json_line(line: str) -> dict[str, Any]: + try: + parsed = json.loads(line) + except json.JSONDecodeError: + # Treat plain-text rows as values for the default "text" column. + return {'text': line} + if not isinstance(parsed, dict): + raise ValueError( + f'Input JSON line must decode to an object, got: {parsed!r}') + return parsed + + +def _extract_column_values(row: dict[str, Any], + columns: list[str]) -> list[str]: + values: list[str] = [] + for col in columns: + if col not in row: + continue + val = row[col] + if val is None: + continue + if isinstance(val, list): + values.extend(str(item) for item in val if item is not None) + else: + values.append(str(val)) + return values + + +def _build_vocab_text( + row: dict[str, Any], columns: list[str], lowercase: bool) -> str: + values = _extract_column_values(row, columns) + normalized_values = [ + normalize_text(value, lowercase=lowercase) for value in values + ] + non_empty_values = [value for value in normalized_values if value] + return ' '.join(non_empty_values) + + +def _resolve_vocab_asset_path( + artifact_location: str, vocab_filename: str, column_name: str) -> str: + asset_name = f'{vocab_filename}_{column_name}' + pattern = ( + f'{artifact_location.rstrip("/")}' + f'/*/transform_fn/assets/{asset_name}') + matches = FileSystems.match([pattern])[0].metadata_list + if not matches: + raise ValueError( + f'Could not locate vocabulary artifact {asset_name!r} under ' + f'{artifact_location!r}.') + return matches[0].path Review Comment:  The glob pattern used here may match multiple directories if multiple runs have been performed in the same `artifact_location`. `MLTransform` typically creates a timestamped subdirectory for each run. 
Picking `matches[0]` without sorting by timestamp or ensuring uniqueness could lead to non-deterministic results, potentially reading a vocabulary from a previous run instead of the current one. ########## sdks/python/apache_beam/examples/ml_transform/mltransform_generate_vocab.py: ########## @@ -0,0 +1,247 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Batch-only vocabulary generation pipeline using MLTransform. + +This pipeline creates a vocabulary artifact from one or more input columns. + +Key properties: +- Batch only (no streaming path). +- Vocabulary generation via MLTransform ComputeAndApplyVocabulary. +- Output format: one token per line. 
+""" + +import argparse +import json +import logging +import tempfile +from typing import Any + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.transforms.base import MLTransform +from apache_beam.ml.transforms.tft import ComputeAndApplyVocabulary +from apache_beam.options.pipeline_options import PipelineOptions + + +def parse_bool_flag(value: str) -> bool: + value_lc = value.strip().lower() + if value_lc in ('1', 'true', 't', 'yes', 'y'): + return True + if value_lc in ('0', 'false', 'f', 'no', 'n'): + return False + raise ValueError( + f'Invalid boolean value {value!r}. Expected true/false style value.') + + +def normalize_text(value: Any, lowercase: bool = True) -> str: + if value is None: + return '' + text = str(value).strip() + if lowercase: + text = text.lower() + return text + + +def _parse_json_line(line: str) -> dict[str, Any]: + try: + parsed = json.loads(line) + except json.JSONDecodeError: + # Treat plain-text rows as values for the default "text" column. 
+ return {'text': line} + if not isinstance(parsed, dict): + raise ValueError( + f'Input JSON line must decode to an object, got: {parsed!r}') + return parsed + + +def _extract_column_values(row: dict[str, Any], + columns: list[str]) -> list[str]: + values: list[str] = [] + for col in columns: + if col not in row: + continue + val = row[col] + if val is None: + continue + if isinstance(val, list): + values.extend(str(item) for item in val if item is not None) + else: + values.append(str(val)) + return values + + +def _build_vocab_text( + row: dict[str, Any], columns: list[str], lowercase: bool) -> str: + values = _extract_column_values(row, columns) + normalized_values = [ + normalize_text(value, lowercase=lowercase) for value in values + ] + non_empty_values = [value for value in normalized_values if value] + return ' '.join(non_empty_values) + + +def _resolve_vocab_asset_path( + artifact_location: str, vocab_filename: str, column_name: str) -> str: + asset_name = f'{vocab_filename}_{column_name}' + pattern = ( + f'{artifact_location.rstrip("/")}' + f'/*/transform_fn/assets/{asset_name}') + matches = FileSystems.match([pattern])[0].metadata_list + if not matches: + raise ValueError( + f'Could not locate vocabulary artifact {asset_name!r} under ' + f'{artifact_location!r}.') + return matches[0].path + + +def _read_vocab_tokens(vocab_asset_path: str) -> list[str]: + tokens = [] + with FileSystems.open(vocab_asset_path) as f: + for raw_line in f: + token = raw_line.decode('utf-8').rstrip('\n') + if token: + tokens.append(token) + return tokens + + +def _write_vocab_file(output_path: str, tokens: list[str]) -> None: + with FileSystems.create(output_path) as f: + for token in tokens: + f.write((token + '\n').encode('utf-8')) + + +def parse_known_args(argv): + parser = argparse.ArgumentParser( + description='Generate vocabulary from batch input with MLTransform.') + parser.add_argument('--input_file', help='Input JSONL file path.') + parser.add_argument( + 
'--input_table', + help='Input BigQuery table path in PROJECT:DATASET.TABLE format.') + parser.add_argument('--output_vocab', help='Output vocab file prefix/path.') + parser.add_argument( + '--columns', + help='Comma-separated source columns to include in vocabulary.') + parser.add_argument( + '--vocab_size', + type=int, + default=50000, + help='Maximum vocabulary size (top-K by frequency).') + parser.add_argument( + '--min_frequency', + type=int, + default=1, + help='Minimum token frequency required to keep token.') + parser.add_argument( + '--lowercase', + default='true', + help='Whether to lowercase text before vocabulary generation.') + parser.add_argument( + '--input_expand_factor', + type=int, + default=1, + help=( + 'Batch-only: repeat each input line this many times to scale volume ' + 'for load/perf testing.')) + parser.add_argument( + '--artifact_location', + default='', + help=( + 'Artifact directory for MLTransform output. If empty, a temporary ' + 'local directory is used.')) + return parser.parse_known_args(argv) + + +def validate_args(args) -> list[str]: + has_input_file = bool(args.input_file) + has_input_table = bool(args.input_table) + if not has_input_file and not has_input_table: + raise ValueError('One of --input_file or --input_table is required.') + if has_input_file and has_input_table: + raise ValueError('Use exactly one of --input_file or --input_table.') + if not args.output_vocab: + raise ValueError('--output_vocab is required.') + if not args.columns: + raise ValueError('--columns is required.') + if args.vocab_size is None or args.vocab_size <= 0: + raise ValueError('--vocab_size must be > 0.') + if args.min_frequency is None or args.min_frequency < 1: + raise ValueError('--min_frequency must be >= 1.') + if args.input_expand_factor is None or args.input_expand_factor < 1: + raise ValueError('--input_expand_factor must be >= 1.') + return [col.strip() for col in args.columns.split(',') if col.strip()] + + +def run(argv=None, 
test_pipeline=None): + known_args, pipeline_args = parse_known_args(argv) + columns = validate_args(known_args) + lowercase = parse_bool_flag(known_args.lowercase) + artifact_location = known_args.artifact_location or tempfile.mkdtemp( + prefix='mltransform_generate_vocab_artifacts_') Review Comment:  Using `tempfile.mkdtemp()` as a default for `artifact_location` will cause issues when running on the `DataflowRunner`. `MLTransform` writes artifacts to this location from the workers. If the location is local to the driver machine, workers on Dataflow will not be able to write to it in a way that the driver can later access. For remote runners, this should ideally default to a sub-path of the pipeline's `temp_location` on GCS to ensure workers can write to a shared, accessible location. ########## sdks/python/apache_beam/examples/ml_transform/mltransform_generate_vocab_test.py: ########## @@ -0,0 +1,191 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import json +import os +import tempfile +import unittest + +try: + from apache_beam.examples.ml_transform import mltransform_generate_vocab +except ImportError: + raise unittest.SkipTest('tensorflow_transform is not installed.') + + +class MLTransformGenerateVocabUnitTest(unittest.TestCase): + def test_normalize_text(self): + text = mltransform_generate_vocab.normalize_text(' Hello Beam ', True) + self.assertEqual(text, 'hello beam') + + def test_null_and_empty_handling_helpers(self): + normalized_none = mltransform_generate_vocab.normalize_text(None, True) + self.assertEqual(normalized_none, '') + self.assertEqual( + mltransform_generate_vocab._build_vocab_text( + {'text': ['Beam', None, ' ', 'Flow']}, ['text'], lowercase=True), + 'beam flow') + + +class MLTransformGenerateVocabCliValidationTest(unittest.TestCase): + def test_missing_required_args(self): + args, _ = mltransform_generate_vocab.parse_known_args([]) + with self.assertRaisesRegex(ValueError, 'input_file or --input_table'): + mltransform_generate_vocab.validate_args(args) + + def test_invalid_numeric_values(self): + args, _ = mltransform_generate_vocab.parse_known_args([ + '--input_file=a.jsonl', + '--output_vocab=/tmp/vocab', + '--columns=text', + '--vocab_size=0', + '--min_frequency=0', + ]) + with self.assertRaisesRegex(ValueError, 'vocab_size'): + mltransform_generate_vocab.validate_args(args) + + def test_invalid_input_expand_factor(self): + args, _ = mltransform_generate_vocab.parse_known_args([ + '--input_file=a.jsonl', + '--output_vocab=/tmp/vocab', + '--columns=text', + '--input_expand_factor=0', + ]) + with self.assertRaisesRegex(ValueError, 'input_expand_factor'): + mltransform_generate_vocab.validate_args(args) + + +class MLTransformGenerateVocabIntegrationTest(unittest.TestCase): + def test_batch_pipeline_exact_output_order(self): + with tempfile.TemporaryDirectory() as tmpdir: + input_path = os.path.join(tmpdir, 'input.jsonl') + output_prefix = os.path.join(tmpdir, 'vocab.txt') + + 
rows = [ + { + 'id': '1', 'text': 'Beam beam ML pipeline' + }, + { + 'id': '2', 'text': 'Beam pipeline dataflow' + }, + { + 'id': '3', 'text': 'ML transform beam' + }, + { + 'id': '4', 'text': 'vocab vocab vocab test' + }, + { + 'id': '5', 'text': 'rare_token_once' + }, + { + 'id': '6', 'text': '' + }, + { + 'id': '7', 'text': None + }, + ] + with open(input_path, 'w', encoding='utf-8') as f: + for row in rows: + f.write(json.dumps(row) + '\n') + + mltransform_generate_vocab.run([ + f'--input_file={input_path}', + f'--output_vocab={output_prefix}', + '--columns=text', + '--vocab_size=3', + '--min_frequency=2', + '--lowercase=true', + '--runner=DirectRunner', + ]) + + output_path = output_prefix + with open(output_path, 'r', encoding='utf-8') as f: + output_tokens = [line.rstrip('\n') for line in f] + + self.assertEqual(set(output_tokens), {'beam', 'vocab', 'ml'}) + self.assertEqual(len(output_tokens), 3) Review Comment:  The test name `test_batch_pipeline_exact_output_order` suggests that the order of tokens in the vocabulary is being verified. However, the assertion compares `set(output_tokens)`, which ignores the order. If the order is important (e.g., for benchmarking stability or verifying frequency-based sorting), the test should compare lists directly. If the order is not guaranteed, the test name should be updated to reflect that. ########## sdks/python/apache_beam/testing/load_tests/dataflow_cost_benchmark.py: ########## @@ -193,6 +193,29 @@ def _get_throughput_metrics( """Query Cloud Monitoring for per-PCollection throughput.""" name = ( pcollection_name if pcollection_name is not None else self.pcollection) + + def _point_numeric_value(point) -> float: + value = point.value + # point.value is proto-plus, so use the underlying protobuf oneof. 
+ raw_value = getattr(value, '_pb', None) + if raw_value is not None: + active_field = raw_value.WhichOneof('value') + if active_field == 'double_value': + return float(value.double_value) + if active_field == 'int64_value': + return float(value.int64_value) + if active_field == 'distribution_value': + # Use aligned mean for distribution-valued points. + distribution = value.distribution_value + if distribution.count > 0: + return float(distribution.mean) + return 0.0 + if active_field == 'money_value': + money = value.money_value + nanos = getattr(money, 'nanos', 0) or 0 Review Comment:  The `or 0` at the end of this line is redundant because `getattr(money, 'nanos', 0)` already provides a default value of `0` if the attribute is missing. The `or 0` would only change the result if `nanos` existed but were `None`, which cannot happen here: `nanos` is an `int32` field of `google.type.Money`, and proto-plus returns `0` for an unset scalar field. ```suggestion nanos = getattr(money, 'nanos', 0) ``` ########## sdks/python/apache_beam/ml/anomaly/transforms.py: ########## @@ -400,7 +400,6 @@ def expand( ret = ( input - | beam.Reshuffle() | f"Score and Learn ({model_uuid})" >> RunScoreAndLearn(self._detector)) Review Comment:  The `beam.Reshuffle()` transform was removed here. Reshuffle is often used to prevent fusion and ensure that work is distributed across workers, especially before heavy processing steps like `RunScoreAndLearn`. Removing it might impact the scalability and performance of the anomaly detection pipeline by allowing the runner to fuse the scoring step with preceding transforms, potentially limiting parallelism. Could you provide the rationale for this change? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
