mohamedawnallah commented on code in PR #34398:
URL: https://github.com/apache/beam/pull/34398#discussion_r2280346195


##########
sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py:
##########
@@ -0,0 +1,659 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import re
+from abc import ABC
+from abc import abstractmethod
+from collections.abc import Callable
+from collections.abc import Mapping
+from dataclasses import dataclass
+from dataclasses import field
+from enum import Enum
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+import pg8000
+import pymysql
+import pytds
+from google.cloud.sql.connector import Connector as CloudSQLConnector
+from google.cloud.sql.connector.enums import RefreshStrategy
+from sqlalchemy import create_engine
+from sqlalchemy import text
+from sqlalchemy.engine import Connection as DBAPIConnection
+
+import apache_beam as beam
+from apache_beam.transforms.enrichment import EnrichmentSourceHandler
+
+QueryFn = Callable[[beam.Row], str]
+ConditionValueFn = Callable[[beam.Row], list[Any]]
+
+
+@dataclass
+class CustomQueryConfig:
+  """Configuration for using a custom query function."""
+  query_fn: QueryFn
+
+  def __post_init__(self):
+    if not self.query_fn:
+      raise ValueError("CustomQueryConfig must provide a valid query_fn")
+
+
+@dataclass
+class TableFieldsQueryConfig:
+  """Configuration for using table name, where clause, and field names."""
+  table_id: str
+  where_clause_template: str
+  where_clause_fields: List[str]
+
+  def __post_init__(self):
+    if not self.table_id or not self.where_clause_template:
+      raise ValueError(
+          "TableFieldsQueryConfig must provide table_id and " +
+          "where_clause_template")
+
+    if not self.where_clause_fields:
+      raise ValueError(
+          "TableFieldsQueryConfig must provide non-empty " +
+          "where_clause_fields")
+
+
+@dataclass
+class TableFunctionQueryConfig:
+  """Configuration for using table name, where clause, and a value function."""
+  table_id: str
+  where_clause_template: str
+  where_clause_value_fn: ConditionValueFn
+
+  def __post_init__(self):
+    if not self.table_id or not self.where_clause_template:
+      raise ValueError(
+          "TableFunctionQueryConfig must provide table_id and " +
+          "where_clause_template")
+
+    if not self.where_clause_value_fn:
+      raise ValueError(
+          "TableFunctionQueryConfig must provide " + "where_clause_value_fn")
+
+
+class DatabaseTypeAdapter(Enum):
+  POSTGRESQL = "pg8000"
+  MYSQL = "pymysql"
+  SQLSERVER = "pytds"
+
+  def to_sqlalchemy_dialect(self):
+    """Map the adapter type to its corresponding SQLAlchemy dialect.
+
+    Returns:
+        str: SQLAlchemy dialect string.
+    """
+    if self == DatabaseTypeAdapter.POSTGRESQL:
+      return f"postgresql+{self.value}"
+    elif self == DatabaseTypeAdapter.MYSQL:
+      return f"mysql+{self.value}"
+    elif self == DatabaseTypeAdapter.SQLSERVER:
+      return f"mssql+{self.value}"
+    else:
+      raise ValueError(f"Unsupported database adapter type: {self.name}")
+
+
+class ConnectionConfig(ABC):
+  @abstractmethod
+  def get_connector_handler(self) -> Callable[[], DBAPIConnection]:
+    pass
+
+  @abstractmethod
+  def get_db_url(self) -> str:
+    pass
+
+
+@dataclass
+class CloudSQLConnectionConfig(ConnectionConfig):
+  """Connects to Google Cloud SQL using Cloud SQL Python Connector.
+
+    Args:
+        db_adapter: The database adapter type (PostgreSQL, MySQL, SQL Server).
+        instance_connection_uri: URI for connecting to the Cloud SQL instance.
+        user: Username for authentication.
+        password: Password for authentication. Defaults to None.
+        db_id: Database identifier/name.
+        refresh_strategy: Strategy for refreshing connection (default: LAZY).
+        connector_kwargs: Additional keyword arguments for the
+          Cloud SQL Python Connector. Enables forward compatibility.
+        connect_kwargs: Additional keyword arguments for the client connect
+          method. Enables forward compatibility.
+    """
+  db_adapter: DatabaseTypeAdapter
+  instance_connection_uri: str
+  user: str = field(default_factory=str)
+  password: str = field(default_factory=str)
+  db_id: str = field(default_factory=str)
+  refresh_strategy: RefreshStrategy = RefreshStrategy.LAZY
+  connector_kwargs: Dict[str, Any] = field(default_factory=dict)
+  connect_kwargs: Dict[str, Any] = field(default_factory=dict)
+
+  def __post_init__(self):
+    if not self.instance_connection_uri:
+      raise ValueError("Instance connection URI cannot be empty")
+
+  def get_connector_handler(self) -> Callable[[], DBAPIConnection]:
+    """Returns a function that creates a new database connection.
+
+      The returned connector function creates database connections that should
+      be properly closed by the caller when no longer needed.
+      """
+    cloudsql_client = CloudSQLConnector(
+        refresh_strategy=self.refresh_strategy, **self.connector_kwargs)
+
+    cloudsql_connector = lambda: cloudsql_client.connect(
+        instance_connection_string=self.instance_connection_uri, driver=self.
+        db_adapter.value, user=self.user, password=self.password, db=self.db_id,
+        **self.connect_kwargs)
+
+    return cloudsql_connector
+
+  def get_db_url(self) -> str:
+    return self.db_adapter.to_sqlalchemy_dialect() + "://"
+
+
+@dataclass
+class ExternalSQLDBConnectionConfig(ConnectionConfig):
+  """Connects to External SQL DBs (PostgreSQL, MySQL, SQL Server) over TCP.
+
+    Args:
+        db_adapter: The database adapter type (PostgreSQL, MySQL, SQL Server).
+        host: Hostname or IP address of the database server.
+        port: Port number for the database connection.
+        user: Username for authentication.
+        password: Password for authentication.
+        db_id: Database identifier/name.
+        connect_kwargs: Additional keyword arguments for the client connect
+          method. Enables forward compatibility.
+    """
+  db_adapter: DatabaseTypeAdapter
+  host: str
+  port: int
+  user: str = field(default_factory=str)
+  password: str = field(default_factory=str)
+  db_id: str = field(default_factory=str)
+  connect_kwargs: Dict[str, Any] = field(default_factory=dict)
+
+  def __post_init__(self):
+    if not self.host:
+      raise ValueError("Database host cannot be empty")
+
+  def get_connector_handler(self) -> Callable[[], DBAPIConnection]:
+    """Returns a function that creates a new database connection.
+
+      The returned connector function creates database connections that should
+      be properly closed by the caller when no longer needed.
+      """
+    if self.db_adapter == DatabaseTypeAdapter.POSTGRESQL:
+      return lambda: pg8000.connect(
+          host=self.host, port=self.port, database=self.db_id, user=self.user,
+          password=self.password, **self.connect_kwargs)
+    elif self.db_adapter == DatabaseTypeAdapter.MYSQL:
+      return lambda: pymysql.connect(
+          host=self.host, port=self.port, database=self.db_id, user=self.user,
+          password=self.password, **self.connect_kwargs)
+    elif self.db_adapter == DatabaseTypeAdapter.SQLSERVER:
+      return lambda: pytds.connect(
+          dsn=self.host, port=self.port, database=self.db_id, user=self.user,
+          password=self.password, **self.connect_kwargs)
+    else:
+      raise ValueError(f"Unsupported database adapter: {self.db_adapter}")
+
+  def get_db_url(self) -> str:
+    return self.db_adapter.to_sqlalchemy_dialect() + "://"
+
+
+QueryConfig = Union[CustomQueryConfig,
+                    TableFieldsQueryConfig,
+                    TableFunctionQueryConfig]
+
+
+class CloudSQLEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]):
+  """Enrichment handler for Cloud SQL databases.
+
+  This handler is designed to work with the
+  :class:`apache_beam.transforms.enrichment.Enrichment` transform.
+
+  To use this handler, you need to provide one of the following query configs:
+    * CustomQueryConfig - For providing a custom query function
+    * TableFieldsQueryConfig - For specifying table, where clause, and fields
+    * TableFunctionQueryConfig - For specifying table, where clause, and val fn

Review Comment:
   > I think these shouldn't be indented. From the failing docs precommit:
   > 
   > ```
   > /runner/_work/beam/beam/sdks/python/test-suites/tox/pycommon/build/srcs/sdks/python/target/.tox-docs/docs/lib/python3.9/site-packages/apache_beam/transforms/enrichment_handlers/cloudsql.py:docstring of apache_beam.transforms.enrichment_handlers.cloudsql.CloudSQLEnrichmentHandler:30: ERROR: Unexpected indentation.
   > ```
   
   Addressed



##########
sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py:
##########
@@ -0,0 +1,659 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import re
+from abc import ABC
+from abc import abstractmethod
+from collections.abc import Callable
+from collections.abc import Mapping
+from dataclasses import dataclass
+from dataclasses import field
+from enum import Enum
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+import pg8000
+import pymysql
+import pytds
+from google.cloud.sql.connector import Connector as CloudSQLConnector
+from google.cloud.sql.connector.enums import RefreshStrategy
+from sqlalchemy import create_engine
+from sqlalchemy import text
+from sqlalchemy.engine import Connection as DBAPIConnection
+
+import apache_beam as beam
+from apache_beam.transforms.enrichment import EnrichmentSourceHandler
+
+QueryFn = Callable[[beam.Row], str]
+ConditionValueFn = Callable[[beam.Row], list[Any]]
+
+
+@dataclass
+class CustomQueryConfig:
+  """Configuration for using a custom query function."""
+  query_fn: QueryFn
+
+  def __post_init__(self):
+    if not self.query_fn:
+      raise ValueError("CustomQueryConfig must provide a valid query_fn")
+
+
+@dataclass
+class TableFieldsQueryConfig:
+  """Configuration for using table name, where clause, and field names."""
+  table_id: str
+  where_clause_template: str
+  where_clause_fields: List[str]
+
+  def __post_init__(self):
+    if not self.table_id or not self.where_clause_template:
+      raise ValueError(
+          "TableFieldsQueryConfig must provide table_id and " +
+          "where_clause_template")
+
+    if not self.where_clause_fields:
+      raise ValueError(
+          "TableFieldsQueryConfig must provide non-empty " +
+          "where_clause_fields")
+
+
+@dataclass
+class TableFunctionQueryConfig:
+  """Configuration for using table name, where clause, and a value function."""
+  table_id: str
+  where_clause_template: str
+  where_clause_value_fn: ConditionValueFn
+
+  def __post_init__(self):
+    if not self.table_id or not self.where_clause_template:
+      raise ValueError(
+          "TableFunctionQueryConfig must provide table_id and " +
+          "where_clause_template")
+
+    if not self.where_clause_value_fn:
+      raise ValueError(
+          "TableFunctionQueryConfig must provide " + "where_clause_value_fn")
+
+
+class DatabaseTypeAdapter(Enum):
+  POSTGRESQL = "pg8000"
+  MYSQL = "pymysql"
+  SQLSERVER = "pytds"
+
+  def to_sqlalchemy_dialect(self):
+    """Map the adapter type to its corresponding SQLAlchemy dialect.
+
+    Returns:
+        str: SQLAlchemy dialect string.
+    """
+    if self == DatabaseTypeAdapter.POSTGRESQL:
+      return f"postgresql+{self.value}"
+    elif self == DatabaseTypeAdapter.MYSQL:
+      return f"mysql+{self.value}"
+    elif self == DatabaseTypeAdapter.SQLSERVER:
+      return f"mssql+{self.value}"
+    else:
+      raise ValueError(f"Unsupported database adapter type: {self.name}")
+
+
+class ConnectionConfig(ABC):
+  @abstractmethod
+  def get_connector_handler(self) -> Callable[[], DBAPIConnection]:
+    pass
+
+  @abstractmethod
+  def get_db_url(self) -> str:
+    pass
+
+
+@dataclass
+class CloudSQLConnectionConfig(ConnectionConfig):
+  """Connects to Google Cloud SQL using Cloud SQL Python Connector.
+
+    Args:
+        db_adapter: The database adapter type (PostgreSQL, MySQL, SQL Server).
+        instance_connection_uri: URI for connecting to the Cloud SQL instance.
+        user: Username for authentication.
+        password: Password for authentication. Defaults to None.
+        db_id: Database identifier/name.
+        refresh_strategy: Strategy for refreshing connection (default: LAZY).
+        connector_kwargs: Additional keyword arguments for the
+          Cloud SQL Python Connector. Enables forward compatibility.
+        connect_kwargs: Additional keyword arguments for the client connect
+          method. Enables forward compatibility.
+    """
+  db_adapter: DatabaseTypeAdapter
+  instance_connection_uri: str
+  user: str = field(default_factory=str)
+  password: str = field(default_factory=str)
+  db_id: str = field(default_factory=str)
+  refresh_strategy: RefreshStrategy = RefreshStrategy.LAZY
+  connector_kwargs: Dict[str, Any] = field(default_factory=dict)
+  connect_kwargs: Dict[str, Any] = field(default_factory=dict)
+
+  def __post_init__(self):
+    if not self.instance_connection_uri:
+      raise ValueError("Instance connection URI cannot be empty")
+
+  def get_connector_handler(self) -> Callable[[], DBAPIConnection]:
+    """Returns a function that creates a new database connection.
+
+      The returned connector function creates database connections that should
+      be properly closed by the caller when no longer needed.
+      """
+    cloudsql_client = CloudSQLConnector(
+        refresh_strategy=self.refresh_strategy, **self.connector_kwargs)
+
+    cloudsql_connector = lambda: cloudsql_client.connect(
+        instance_connection_string=self.instance_connection_uri, driver=self.
+        db_adapter.value, user=self.user, password=self.password, db=self.db_id,
+        **self.connect_kwargs)
+
+    return cloudsql_connector
+
+  def get_db_url(self) -> str:
+    return self.db_adapter.to_sqlalchemy_dialect() + "://"
+
+
+@dataclass
+class ExternalSQLDBConnectionConfig(ConnectionConfig):
+  """Connects to External SQL DBs (PostgreSQL, MySQL, SQL Server) over TCP.
+
+    Args:
+        db_adapter: The database adapter type (PostgreSQL, MySQL, SQL Server).
+        host: Hostname or IP address of the database server.
+        port: Port number for the database connection.
+        user: Username for authentication.
+        password: Password for authentication.
+        db_id: Database identifier/name.
+        connect_kwargs: Additional keyword arguments for the client connect
+          method. Enables forward compatibility.
+    """
+  db_adapter: DatabaseTypeAdapter
+  host: str
+  port: int
+  user: str = field(default_factory=str)
+  password: str = field(default_factory=str)
+  db_id: str = field(default_factory=str)
+  connect_kwargs: Dict[str, Any] = field(default_factory=dict)
+
+  def __post_init__(self):
+    if not self.host:
+      raise ValueError("Database host cannot be empty")
+
+  def get_connector_handler(self) -> Callable[[], DBAPIConnection]:
+    """Returns a function that creates a new database connection.
+
+      The returned connector function creates database connections that should
+      be properly closed by the caller when no longer needed.
+      """
+    if self.db_adapter == DatabaseTypeAdapter.POSTGRESQL:
+      return lambda: pg8000.connect(
+          host=self.host, port=self.port, database=self.db_id, user=self.user,
+          password=self.password, **self.connect_kwargs)
+    elif self.db_adapter == DatabaseTypeAdapter.MYSQL:
+      return lambda: pymysql.connect(
+          host=self.host, port=self.port, database=self.db_id, user=self.user,
+          password=self.password, **self.connect_kwargs)
+    elif self.db_adapter == DatabaseTypeAdapter.SQLSERVER:
+      return lambda: pytds.connect(
+          dsn=self.host, port=self.port, database=self.db_id, user=self.user,
+          password=self.password, **self.connect_kwargs)
+    else:
+      raise ValueError(f"Unsupported database adapter: {self.db_adapter}")
+
+  def get_db_url(self) -> str:
+    return self.db_adapter.to_sqlalchemy_dialect() + "://"
+
+
+QueryConfig = Union[CustomQueryConfig,
+                    TableFieldsQueryConfig,
+                    TableFunctionQueryConfig]
+
+
+class CloudSQLEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]):
+  """Enrichment handler for Cloud SQL databases.
+
+  This handler is designed to work with the
+  :class:`apache_beam.transforms.enrichment.Enrichment` transform.
+
+  To use this handler, you need to provide one of the following query configs:
+    * CustomQueryConfig - For providing a custom query function
+    * TableFieldsQueryConfig - For specifying table, where clause, and fields
+    * TableFunctionQueryConfig - For specifying table, where clause, and val fn

Review Comment:
   > I think these shouldn't be indented. From the failing docs precommit:
   > 
   > ```
   > /runner/_work/beam/beam/sdks/python/test-suites/tox/pycommon/build/srcs/sdks/python/target/.tox-docs/docs/lib/python3.9/site-packages/apache_beam/transforms/enrichment_handlers/cloudsql.py:docstring of apache_beam.transforms.enrichment_handlers.cloudsql.CloudSQLEnrichmentHandler:30: ERROR: Unexpected indentation.
   > ```
   
   Addressed. The CI is now passing.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@beam.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to