claudevdm commented on code in PR #34398: URL: https://github.com/apache/beam/pull/34398#discussion_r2138380754
########## sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py: ########## @@ -0,0 +1,370 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections.abc import Callable +from collections.abc import Mapping +from dataclasses import dataclass +from enum import Enum +from typing import Any +from typing import List +from typing import Optional +from typing import Union + +from sqlalchemy import create_engine +from sqlalchemy import text + +import apache_beam as beam +from apache_beam.transforms.enrichment import EnrichmentSourceHandler + +QueryFn = Callable[[beam.Row], str] +ConditionValueFn = Callable[[beam.Row], list[Any]] + + +@dataclass +class CustomQueryConfig: + """Configuration for using a custom query function.""" + query_fn: QueryFn + + def __post_init__(self): + if not self.query_fn: + raise ValueError("CustomQueryConfig must provide a valid query_fn") + + +@dataclass +class TableFieldsQueryConfig: + """Configuration for using table name, where clause, and field names.""" + table_id: str + where_clause_template: str + where_clause_fields: List[str] + + def __post_init__(self): + if not self.table_id or not self.where_clause_template: + raise ValueError( + "TableFieldsQueryConfig and " + + 
"TableFunctionQueryConfig must provide table_id " + + "and where_clause_template") + + if not self.where_clause_fields: + raise ValueError( + "TableFieldsQueryConfig must provide non-empty " + + "where_clause_fields") + + +@dataclass +class TableFunctionQueryConfig: + """Configuration for using table name, where clause, and a value function.""" + table_id: str + where_clause_template: str + where_clause_value_fn: ConditionValueFn + + def __post_init__(self): + if not self.table_id or not self.where_clause_template: + raise ValueError( + "TableFieldsQueryConfig and " + + "TableFunctionQueryConfig must provide table_id " + + "and where_clause_template") + + if not self.where_clause_value_fn: + raise ValueError( + "TableFunctionQueryConfig must provide " + "where_clause_value_fn") + + +QueryConfig = Union[CustomQueryConfig, + TableFieldsQueryConfig, + TableFunctionQueryConfig] + + +class DatabaseTypeAdapter(Enum): + POSTGRESQL = "psycopg2" + MYSQL = "pymysql" + SQLSERVER = "pytds" + + def to_sqlalchemy_dialect(self): + """Map the adapter type to its corresponding SQLAlchemy dialect. + + Returns: + str: SQLAlchemy dialect string. + """ + if self == DatabaseTypeAdapter.POSTGRESQL: + return f"postgresql+{self.value}" + elif self == DatabaseTypeAdapter.MYSQL: + return f"mysql+{self.value}" + elif self == DatabaseTypeAdapter.SQLSERVER: + return f"mssql+{self.value}" + else: + raise ValueError(f"Unsupported adapter type: {self.name}") + + +class CloudSQLEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]): + """Enrichment handler for Cloud SQL databases. + + This handler is designed to work with the + :class:`apache_beam.transforms.enrichment.Enrichment` transform. 
+ + To use this handler, you need to provide one of the following query configs: + * CustomQueryConfig - For providing a custom query function + * TableFieldsQueryConfig - For specifying table, where clause, and fields + * TableFunctionQueryConfig - For specifying table, where clause, and val fn + + By default, the handler retrieves all columns from the specified table. + To limit the columns, use the `column_names` parameter to specify + the desired column names. + + This handler queries the Cloud SQL database per element by default. + To enable batching, set the `min_batch_size` and `max_batch_size` parameters. + These values control the batching behavior in the + :class:`apache_beam.transforms.utils.BatchElements` transform. + + NOTE: Batching is not supported when using the CustomQueryConfig. + """ + def __init__( + self, + database_type_adapter: DatabaseTypeAdapter, + database_address: str, + database_user: str, + database_password: str, + database_id: str, + *, + query_config: QueryConfig, + column_names: Optional[list[str]] = None, + min_batch_size: int = 1, + max_batch_size: int = 10000, + **kwargs, + ): + """ + Example Usage: + handler = CloudSQLEnrichmentHandler( + database_type_adapter=adapter, + database_address='127.0.0.1:5432', + database_user='user', + database_password='password', + database_id='my_database', + query_config=TableFieldsQueryConfig('my_table',"id = '{}'",['id']), + min_batch_size=2, + max_batch_size=100) + + Args: + database_type_adapter: Adapter to handle specific database type operations + (e.g., MySQL, PostgreSQL). + database_address (str): Address or hostname of the Cloud SQL database, in + the form `<ip>:<port>`. The port is optional if the database uses + the default port. + database_user (str): Username for accessing the database. + database_password (str): Password for accessing the database. + database_id (str): Identifier for the database to query. + query_config: Configuration for database queries. 
Must be one of: + * CustomQueryConfig: For providing a custom query function + * TableFieldsQueryConfig: specifies table, where clause, and field names + * TableFunctionQueryConfig: specifies table, where clause, and val func + column_names (Optional[list[str]]): List of column names to select from + the Cloud SQL table. If not provided, all columns (`*`) are selected. + min_batch_size (int): Minimum number of rows to batch together when + querying the database. Defaults to 1 if `query_fn` is not used. + max_batch_size (int): Maximum number of rows to batch together. Defaults + to 10,000 if `query_fn` is not used. + **kwargs: Additional keyword arguments for database connection or query + handling. + + Note: + * Cannot use `min_batch_size` or `max_batch_size` with `query_fn`. + * Either `where_clause_fields` or `where_clause_value_fn` must be provided + for query construction if `query_fn` is not provided. + * Ensure that the database user has the necessary permissions to query the + specified table. + """ + self._database_type_adapter = database_type_adapter + self._database_id = database_id + self._database_user = database_user + self._database_password = database_password + self._database_address = database_address Review Comment: Cloud SQL offers language connectors that support IAM authentication and let users connect to their database without allowlisting their IP and without having to provide the database IP. See https://cloud.google.com/sql/docs/mysql/connect-connectors#setup-and-usage for more detailed information. We don't have to add this in this PR, but I think we should change this interface so that the connection method is easily interchangeable. For example, we could have a ConnectionConfig that captures all the arguments needed to connect without a language connector (database_address, etc.).
Later, we could add a LanguageConnectorConfig that accepts the parameters needed to connect via a language connector (see https://cloud.google.com/sql/docs/mysql/connect-connectors#examples). Here is an example that uses a language connector to ingest embeddings into Cloud SQL (via JdbcIO): https://github.com/apache/beam/blob/91d6ec23fb5bc60e0558d7841ec5e220a595306c/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@beam.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org