This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 28fd2b2c9df Mysql embeddings (#35393)
28fd2b2c9df is described below
commit 28fd2b2c9dfd7286c16e8da4504e2dc7f3605984
Author: claudevdm <[email protected]>
AuthorDate: Mon Jul 7 16:54:10 2025 -0400
Mysql embeddings (#35393)
* Add MySQL vector writer.
* Trigger tests again.
* Comments.
* Fix lints etc.
* Comment.
* Fix typo
* Lint fix.
* Update sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py
Co-authored-by: Danny McCormick <[email protected]>
---------
Co-authored-by: Claude <[email protected]>
Co-authored-by: Danny McCormick <[email protected]>
---
.../schemaio-expansion-service/build.gradle | 2 +
.../apache_beam/ml/rag/ingestion/cloudsql.py | 61 +-
.../ml/rag/ingestion/cloudsql_it_test.py | 1056 +++++++++++++++++---
sdks/python/apache_beam/ml/rag/ingestion/mysql.py | 268 +++++
.../apache_beam/ml/rag/ingestion/mysql_common.py | 433 ++++++++
.../apache_beam/ml/rag/ingestion/test_utils.py | 9 -
6 files changed, 1679 insertions(+), 150 deletions(-)
diff --git a/sdks/java/extensions/schemaio-expansion-service/build.gradle b/sdks/java/extensions/schemaio-expansion-service/build.gradle
index 15873d58e61..12ee92a9e10 100644
--- a/sdks/java/extensions/schemaio-expansion-service/build.gradle
+++ b/sdks/java/extensions/schemaio-expansion-service/build.gradle
@@ -64,6 +64,8 @@ dependencies {
permitUnusedDeclared 'com.google.cloud:alloydb-jdbc-connector:1.2.0'
implementation 'com.google.cloud.sql:postgres-socket-factory:1.25.0'
permitUnusedDeclared 'com.google.cloud.sql:postgres-socket-factory:1.25.0'
+ implementation 'com.google.cloud.sql:mysql-socket-factory-connector-j-8:1.25.0'
+ permitUnusedDeclared 'com.google.cloud.sql:mysql-socket-factory-connector-j-8:1.25.0'
testImplementation library.java.junit
testImplementation library.java.mockito_core
runtimeOnly ("org.xerial:sqlite-jdbc:3.49.1.0")
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py
index 69ead961a76..d3710a7f70a 100644
--- a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py
+++ b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py
@@ -21,12 +21,12 @@ from typing import Dict
from typing import List
from typing import Optional
+from apache_beam.ml.rag.ingestion import mysql
+from apache_beam.ml.rag.ingestion import mysql_common
+from apache_beam.ml.rag.ingestion import postgres
+from apache_beam.ml.rag.ingestion import postgres_common
from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
-from apache_beam.ml.rag.ingestion.postgres import ColumnSpecsBuilder
-from apache_beam.ml.rag.ingestion.postgres import PostgresVectorWriterConfig
-from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec
-from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution
@dataclass
@@ -138,7 +138,7 @@ class _PostgresConnectorConfig(LanguageConnectorConfig):
return cls(**asdict(config))
-class CloudSQLPostgresVectorWriterConfig(PostgresVectorWriterConfig):
+class CloudSQLPostgresVectorWriterConfig(postgres.PostgresVectorWriterConfig):
def __init__(
self,
connection_config: LanguageConnectorConfig,
@@ -146,10 +146,11 @@ class CloudSQLPostgresVectorWriterConfig(PostgresVectorWriterConfig):
*,
# pylint: disable=dangerous-default-value
write_config: WriteConfig = WriteConfig(),
- column_specs: List[ColumnSpec] = ColumnSpecsBuilder.with_defaults().build(
- ),
- conflict_resolution: Optional[ConflictResolution] = ConflictResolution(
- on_conflict_fields=[], action='IGNORE')):
+ column_specs: List[postgres_common.ColumnSpec] = postgres_common.
+ ColumnSpecsBuilder.with_defaults().build(),
+ conflict_resolution: Optional[
+ postgres_common.ConflictResolution] = postgres_common.
+ ConflictResolution(on_conflict_fields=[], action='IGNORE')):
"""Configuration for writing vectors to ClouSQL Postgres.
Supports flexible schema configuration through column specifications and
@@ -218,3 +219,45 @@ class CloudSQLPostgresVectorWriterConfig(PostgresVectorWriterConfig):
table_name=table_name,
column_specs=column_specs,
conflict_resolution=conflict_resolution)
+
+
+@dataclass
+class _MySQLConnectorConfig(LanguageConnectorConfig):
+ def to_jdbc_url(self) -> str:
+ """Convert options to a properly formatted MySQL JDBC URL."""
+ return self._build_jdbc_url(
+ socketFactory="com.google.cloud.sql.mysql.SocketFactory",
+ database_type="mysql")
+
+ def additional_jdbc_args(self) -> Dict[str, List[Any]]:
+ return {
+ 'classpath': [
+ "mysql:mysql-connector-java:8.0.22",
+ "com.google.cloud.sql:mysql-socket-factory-connector-j-8:1.25.0"
+ ]
+ }
+
+ @classmethod
+ def from_base_config(cls, config: LanguageConnectorConfig):
+ return cls(**asdict(config))
+
+
+class CloudSQLMySQLVectorWriterConfig(mysql.MySQLVectorWriterConfig):
+ def __init__(
+ self,
+ connection_config: LanguageConnectorConfig,
+ table_name: str,
+ *,
+ write_config: WriteConfig = WriteConfig(),
+ # pylint: disable=dangerous-default-value
+ column_specs: List[mysql_common.ColumnSpec] = mysql_common.
+ ColumnSpecsBuilder.with_defaults().build(),
+ conflict_resolution: Optional[mysql_common.ConflictResolution] = None):
+ self.connector_config = _MySQLConnectorConfig.from_base_config(
+ connection_config)
+ super().__init__(
+ connection_config=self.connector_config.to_connection_config(),
+ write_config=write_config,
+ table_name=table_name,
+ column_specs=column_specs,
+ conflict_resolution=conflict_resolution)
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py
index 959e4cadb13..7ae49ba5182 100644
--- a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py
+++ b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py
@@ -15,209 +15,1001 @@
# limitations under the License.
#
+import json
import logging
import os
import secrets
import time
import unittest
+from dataclasses import dataclass
+from typing import Any
+from typing import List
+from typing import Literal
+from typing import Optional
import pytest
import sqlalchemy
from google.cloud.sql.connector import Connector
+from parameterized import parameterized
from sqlalchemy import text
import apache_beam as beam
from apache_beam.io.jdbc import ReadFromJdbc
+from apache_beam.ml.rag.ingestion import mysql_common
+from apache_beam.ml.rag.ingestion import postgres_common
from apache_beam.ml.rag.ingestion import test_utils
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteTransform
+from apache_beam.ml.rag.ingestion.cloudsql import CloudSQLMySQLVectorWriterConfig
from apache_beam.ml.rag.ingestion.cloudsql import CloudSQLPostgresVectorWriterConfig
from apache_beam.ml.rag.ingestion.cloudsql import LanguageConnectorConfig
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.types import Content
+from apache_beam.ml.rag.types import Embedding
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
+_LOGGER = logging.getLogger(__name__)
+
+
+@dataclass
+class DatabaseTestConfig:
+ """Database-specific test configuration."""
+ database_type: Literal["postgresql", "mysql"]
+ writer_config_class: type
+ jdbc_driver: str
+ connector_module: Literal["pg8000", "pymysql"]
+ table_prefix: str
+
+ password_env_var: str
+ username: str
+ database: str
+ instance_uri: str
+
+ vector_column_type: str
+ metadata_column_type: str
+ common_module: Any
+ id_column_type: str = "VARCHAR(255)"
+
+
+class DatabaseTestHelper:
+ """Helper class to manage database setup, connections, and operations."""
+ def __init__(self, db_config: DatabaseTestConfig, table_suffix: str):
+ self.db_config = db_config
+ self.table_suffix = table_suffix
+ self.connector = None
+ self.engine = None
+ self.connection_config = None
+
+ self.default_table_name = f"{db_config.table_prefix}{table_suffix}"
+ self.custom_table_name = f"{db_config.table_prefix}_custom_{table_suffix}"
+ self.metadata_conflicts_table = f"{db_config.table_prefix}_meta_conf_" \
+ f"{table_suffix}"
+
+ self._setup_read_queries()
+
+ def _setup_read_queries(self):
+ if self.db_config.database_type == "postgresql":
+ self.read_queries = {
+ self.default_table_name: f"""
+ SELECT
+ CAST(id AS VARCHAR(255)),
+ CAST(content AS VARCHAR(255)),
+ CAST(embedding AS text),
+ CAST(metadata AS text)
+ FROM {self.default_table_name}
+ """,
+ self.custom_table_name: f"""
+ SELECT
+ CAST(custom_id AS VARCHAR(255)),
+ CAST(embedding_vec AS text),
+ CAST(content_col AS VARCHAR(255)),
+ CAST(metadata AS text)
+ FROM {self.custom_table_name}
+ ORDER BY custom_id
+ """,
+ self.metadata_conflicts_table: f"""
+ SELECT
+ CAST(id AS VARCHAR(255)),
+ CAST(embedding AS text),
+ CAST(content AS VARCHAR(255)),
+ CAST(source AS VARCHAR(255)),
+ CAST(timestamp AS VARCHAR(255))
+ FROM {self.metadata_conflicts_table}
+ ORDER BY timestamp, id
+ """
+ }
+ elif self.db_config.database_type == "mysql":
+ self.read_queries = {
+ self.default_table_name: f"""
+ SELECT
+ CAST(id AS CHAR(255)) as id,
+ CAST(content AS CHAR(255)) as content,
+ vector_to_string(embedding) as embedding,
+ CAST(metadata AS CHAR(10000)) as metadata
+ FROM {self.default_table_name}
+ """,
+ self.custom_table_name: f"""
+ SELECT
+ CAST(custom_id AS CHAR(255)) as custom_id,
+ vector_to_string(embedding_vec) as embedding_vec,
+ CAST(content_col AS CHAR(255)) as content_col,
+ CAST(metadata AS CHAR(10000)) as metadata
+ FROM {self.custom_table_name}
+ ORDER BY custom_id
+ """,
+ self.metadata_conflicts_table: f"""
+ SELECT
+ CAST(id AS CHAR(255)) as id,
+ vector_to_string(embedding) as embedding,
+ CAST(content AS CHAR(255)) as content,
+ CAST(source AS CHAR(255)) as source,
+ CAST(timestamp AS CHAR(255)) as timestamp
+ FROM {self.metadata_conflicts_table}
+ ORDER BY timestamp, id
+ """
+ }
+
+ def get_read_query(self, table_name: str) -> str:
+ if table_name not in self.read_queries:
+ raise ValueError(f"No read query defined for table: {table_name}")
+ return self.read_queries[table_name]
+
+ def setup_connection(self):
+ """Set up database connection and engine."""
+ if not os.environ.get(self.db_config.password_env_var):
+ raise ValueError("Password environment variable not set.")
+ password = os.environ.get(self.db_config.password_env_var)
+
+ self.connection_config = LanguageConnectorConfig(
+ username=self.db_config.username,
+ password=password,
+ database_name=self.db_config.database,
+ instance_name=self.db_config.instance_uri)
+
+ self.connector = Connector(refresh_strategy="LAZY")
[email protected]_gcp_java_expansion_service
[email protected](
- os.environ.get('EXPANSION_JARS'),
- "EXPANSION_JARS environment var is not provided, "
- "indicating that jars have not been built")
[email protected](
- os.environ.get('ALLOYDB_PASSWORD'),
- "ALLOYDB_PASSWORD environment var is not provided")
-class CloudSQLPostgresVectorWriterConfigTest(unittest.TestCase):
- POSTGRES_TABLE_PREFIX = 'python_rag_postgres_'
-
- @classmethod
- def _create_engine(cls):
- """Create SQLAlchemy engine using Cloud SQL connector."""
def getconn():
- conn = cls.connector.connect(
- cls.instance_uri,
- "pg8000",
- user=cls.username,
- password=cls.password,
- db=cls.database,
+ return self.connector.connect(
+ self.db_config.instance_uri,
+ self.db_config.connector_module,
+ user=self.db_config.username,
+ password=password,
+ db=self.db_config.database,
)
- return conn
-
- engine = sqlalchemy.create_engine(
- "postgresql+pg8000://",
- creator=getconn,
- )
- return engine
-
- @classmethod
- def setUpClass(cls):
- cls.database = os.environ.get('POSTGRES_DATABASE', 'postgres')
- cls.username = os.environ.get('POSTGRES_USERNAME', 'postgres')
- if not os.environ.get('ALLOYDB_PASSWORD'):
- raise ValueError('ALLOYDB_PASSWORD env not set')
- cls.password = os.environ.get('ALLOYDB_PASSWORD')
- cls.instance_uri = os.environ.get(
- 'POSTGRES_INSTANCE_URI',
- 'apache-beam-testing:us-central1:beam-integration-tests')
-
- # Create unique table name suffix
- cls.table_suffix = '%d%s' % (int(time.time()), secrets.token_hex(3))
-
- # Setup database connection
- cls.connector = Connector(refresh_strategy="LAZY")
- cls.engine = cls._create_engine()
- def skip_if_dataflow_runner(self):
- if self._runner and "dataflowrunner" in self._runner.lower():
- self.skipTest(
- "Skipping some tests on Dataflow Runner to avoid bloat and timeouts")
+ dialect = "postgresql+pg8000" \
+ if self.db_config.database_type == "postgresql" else "mysql+pymysql"
+ self.engine = sqlalchemy.create_engine(f"{dialect}://", creator=getconn)
- def setUp(self):
- self.write_test_pipeline = TestPipeline(is_integration_test=True)
- self.read_test_pipeline = TestPipeline(is_integration_test=True)
- self._runner = type(self.read_test_pipeline.runner).__name__
+ def create_all_tables(self):
+ if not self.engine:
+ raise ValueError("Engine not initialized. Call setup_connection() first.")
- self.default_table_name = f"{self.POSTGRES_TABLE_PREFIX}" \
- f"{self.table_suffix}"
+ vector_type_large = self.db_config.vector_column_type.format(
+ size=test_utils.VECTOR_SIZE)
+ vector_type_small = self.db_config.vector_column_type.format(size=2)
+ metadata_type = self.db_config.metadata_column_type
+ id_type = self.db_config.id_column_type
- # Create test table
with self.engine.connect() as connection:
- connection.execute(
- text(
- f"""
+ default_table_sql = f"""
CREATE TABLE {self.default_table_name} (
- id TEXT PRIMARY KEY,
- embedding VECTOR({test_utils.VECTOR_SIZE}),
+ id {id_type} PRIMARY KEY,
+ embedding {vector_type_large},
content TEXT,
- metadata JSONB
+ metadata {metadata_type}
)
- """))
- connection.commit()
- _LOGGER = logging.getLogger(__name__)
- _LOGGER.info("Created table %s", self.default_table_name)
+ """
+ connection.execute(text(default_table_sql))
- def tearDown(self):
- # Drop test table
- with self.engine.connect() as connection:
- connection.execute(
- text(f"DROP TABLE IF EXISTS {self.default_table_name}"))
+ custom_table_sql = f"""
+ CREATE TABLE {self.custom_table_name} (
+ custom_id {id_type} PRIMARY KEY,
+ embedding_vec {vector_type_small},
+ content_col TEXT,
+ metadata {metadata_type}
+ )
+ """
+ connection.execute(text(custom_table_sql))
+
+ if self.db_config.database_type == "postgresql":
+ metadata_conflicts_sql = f"""
+ CREATE TABLE {self.metadata_conflicts_table} (
+ id {id_type},
+ source TEXT,
+ timestamp TIMESTAMP,
+ content TEXT,
+ embedding {vector_type_small},
+ PRIMARY KEY (id),
+ UNIQUE (source, timestamp)
+ )
+ """
+ elif self.db_config.database_type == "mysql":
+ metadata_conflicts_sql = f"""
+ CREATE TABLE {self.metadata_conflicts_table} (
+ id {id_type},
+ source TEXT,
+ timestamp TIMESTAMP,
+ content TEXT,
+ embedding {vector_type_small},
+ PRIMARY KEY (id),
+ UNIQUE KEY unique_source_timestamp (source(255), timestamp)
+ )
+ """
+ connection.execute(text(metadata_conflicts_sql))
connection.commit()
- _LOGGER = logging.getLogger(__name__)
- _LOGGER.info("Dropped table %s", self.default_table_name)
-
- @classmethod
- def tearDownClass(cls):
- if hasattr(cls, 'connector'):
- cls.connector.close()
- if hasattr(cls, 'engine'):
- cls.engine.dispose()
-
- def test_language_connector(self):
- """Test language connector."""
- self.skip_if_dataflow_runner()
- connection_config = LanguageConnectorConfig(
- username=self.username,
- password=self.password,
- database_name=self.database,
- instance_name=self.instance_uri)
- writer_config = CloudSQLPostgresVectorWriterConfig(
- connection_config=connection_config,
table_name=self.default_table_name)
+ def create_writer_config(
+ self,
+ table_name: Optional[str] = None,
+ column_specs=None,
+ conflict_resolution=None):
+ if not self.connection_config:
+ raise ValueError(
+ "Connection not initialized. Call setup_connection() first.")
- # Create test chunks
- num_records = 150
- sample_size = min(500, num_records // 2)
- chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records)
+ table_name = table_name or self.default_table_name
- self.write_test_pipeline.not_use_test_runner_api = True
+ kwargs = {
+ 'connection_config': self.connection_config,
+ 'table_name': table_name,
+ }
- with self.write_test_pipeline as p:
- _ = (
- p
- | beam.Create(chunks)
- | VectorDatabaseWriteTransform(writer_config))
+ if column_specs is not None:
+ kwargs['column_specs'] = column_specs
+ if conflict_resolution is not None:
+ kwargs['conflict_resolution'] = conflict_resolution
- self.read_test_pipeline.not_use_test_runner_api = True
- read_query = f"""
- SELECT
- CAST(id AS VARCHAR(255)),
- CAST(content AS VARCHAR(255)),
- CAST(embedding AS text),
- CAST(metadata AS text)
- FROM {self.default_table_name}
- """
+ return self.db_config.writer_config_class(**kwargs)
- with self.read_test_pipeline as p:
- rows = (
- p
- | ReadFromJdbc(
- table_name=self.default_table_name,
- driver_class_name="org.postgresql.Driver",
- jdbc_url=writer_config.connector_config.to_connection_config(
- ).jdbc_url,
- username=self.username,
- password=self.password,
- query=read_query,
- classpath=writer_config.connector_config.additional_jdbc_args()
- ['classpath']))
+ def cleanup(self):
+ if self.engine:
+ table_names = [
+ self.default_table_name,
+ self.custom_table_name,
+ self.metadata_conflicts_table
+ ]
+
+ try:
+ with self.engine.connect() as connection:
+ for table_name in table_names:
+ connection.execute(text(f"DROP TABLE IF EXISTS {table_name}"))
+ connection.commit()
+ _LOGGER.info(
+ "Dropped %s tables: %s",
+ self.db_config.database_type,
+ ', '.join(table_names))
+ except Exception as e:
+ _LOGGER.warning(
+ "Error dropping %s tables: %s", self.db_config.database_type, e)
+
+ if self.connector:
+ try:
+ self.connector.close()
+ except Exception as e:
+ _LOGGER.warning("Error closing connector: %s", e)
+
+ if self.engine:
+ try:
+ self.engine.dispose()
+ except Exception as e:
+ _LOGGER.warning("Error disposing engine: %s", e)
+
+
+class PipelineVerificationHelper:
+ """Helper class for common pipeline verification patterns."""
+ @staticmethod
+ def build_jdbc_params(helper: DatabaseTestHelper, table_name: str) -> dict:
+ """Build JDBC parameters dictionary for ReadFromJdbc."""
+ writer_config = helper.create_writer_config(table_name)
+
+ return {
+ 'table_name': table_name,
+ 'driver_class_name': helper.db_config.jdbc_driver,
+ 'jdbc_url': writer_config.connector_config.to_connection_config().
+ jdbc_url,
+ 'username': helper.db_config.username,
+ 'password': helper.connection_config.password,
+ 'query': helper.get_read_query(table_name),
+ 'classpath': writer_config.connector_config.additional_jdbc_args()
+ ['classpath']
+ }
+ @staticmethod
+ def verify_standard_operations(
+ pipeline, jdbc_params: dict, expected_chunks: List[Chunk]):
+ num_records = len(expected_chunks)
+ sample_size = min(500, num_records // 2)
+
+ with pipeline as p:
+ rows = (p | ReadFromJdbc(**jdbc_params))
+
+ # Count verification
count_result = rows | "Count All" >> beam.combiners.Count.Globally()
assert_that(count_result, equal_to([num_records]), label='count_check')
+ # Hash verification
chunks = (rows | "To Chunks" >> beam.Map(test_utils.row_to_chunk))
chunk_hashes = chunks | "Hash Chunks" >> beam.CombineGlobally(
test_utils.HashingFn())
- assert_that(
- chunk_hashes,
- equal_to([test_utils.generate_expected_hash(num_records)]),
- label='hash_check')
+ expected_hash = test_utils.generate_expected_hash(num_records)
+ assert_that(chunk_hashes, equal_to([expected_hash]), label='hash_check')
- # Sample validation
+ # Sample validation - first N
first_n = (
chunks
| "Key on Index" >> beam.Map(test_utils.key_on_id)
| f"Get First {sample_size}" >> beam.transforms.combiners.Top.Of(
sample_size, key=lambda x: x[0], reverse=True)
| "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs]))
- expected_first_n = test_utils.ChunkTestUtils.get_expected_values(
- 0, sample_size)
+ expected_first_n = expected_chunks[:sample_size]
assert_that(
first_n,
equal_to([expected_first_n]),
label=f"first_{sample_size}_check")
+ # Sample validation - last N
last_n = (
chunks
| "Key on Index 2" >> beam.Map(test_utils.key_on_id)
| f"Get Last {sample_size}" >> beam.transforms.combiners.Top.Of(
sample_size, key=lambda x: x[0])
| "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs]))
- expected_last_n = test_utils.ChunkTestUtils.get_expected_values(
- num_records - sample_size, num_records)[::-1]
+ expected_last_n = expected_chunks[-sample_size:][::-1]
assert_that(
last_n,
equal_to([expected_last_n]),
label=f"last_{sample_size}_check")
+# Database configurations
+POSTGRES_CONFIG = DatabaseTestConfig(
+ database_type="postgresql",
+ writer_config_class=CloudSQLPostgresVectorWriterConfig,
+ jdbc_driver="org.postgresql.Driver",
+ connector_module="pg8000",
+ table_prefix="python_rag_postgres_",
+ password_env_var="ALLOYDB_PASSWORD",
+ username="postgres",
+ database="postgres",
+ instance_uri="apache-beam-testing:us-central1:beam-integration-tests",
+ vector_column_type="VECTOR({size})",
+ metadata_column_type="JSONB",
+ common_module=postgres_common)
+
+MYSQL_CONFIG = DatabaseTestConfig(
+ database_type="mysql",
+ writer_config_class=CloudSQLMySQLVectorWriterConfig,
+ jdbc_driver="com.mysql.cj.jdbc.Driver",
+ connector_module="pymysql",
+ table_prefix="python_rag_mysql_",
+ password_env_var="ALLOYDB_PASSWORD",
+ username="mysql",
+ database="embeddings",
+ instance_uri="apache-beam-testing:us-central1:beam-integration-tests-mysql",
+ vector_column_type="VECTOR({size}) USING VARBINARY",
+ metadata_column_type="JSON",
+ common_module=mysql_common)
+
+
[email protected]_gcp_java_expansion_service
[email protected](
+ os.environ.get('EXPANSION_JARS'),
+ "EXPANSION_JARS environment var is not provided, "
+ "indicating that jars have not been built")
+class CloudSQLVectorWriterConfigTest(unittest.TestCase):
+ def setUp(self):
+ self.write_test_pipeline = TestPipeline(is_integration_test=True)
+ self.read_test_pipeline = TestPipeline(is_integration_test=True)
+ self.write_test_pipeline2 = TestPipeline(is_integration_test=True)
+ self.read_test_pipeline2 = TestPipeline(is_integration_test=True)
+
+ self.write_test_pipeline.not_use_test_runner_api = True
+ self.read_test_pipeline.not_use_test_runner_api = True
+ self.write_test_pipeline2.not_use_test_runner_api = True
+ self.read_test_pipeline2.not_use_test_runner_api = True
+ self._runner = type(self.read_test_pipeline.runner).__name__
+
+ self.db_helpers = {}
+ self.table_suffix = '%d%s' % (int(time.time()), secrets.token_hex(3))
+
+ # Set up database helpers
+ for config in [POSTGRES_CONFIG, MYSQL_CONFIG]:
+ helper = DatabaseTestHelper(config, self.table_suffix)
+ helper.setup_connection()
+ helper.create_all_tables()
+ self.db_helpers[config.database_type] = helper
+ _LOGGER.info("Successfully set up %s database", config.database_type)
+
+ def tearDown(self):
+ for helper in self.db_helpers.values():
+ helper.cleanup()
+
+ def skip_if_dataflow_runner(self):
+ if self._runner and "dataflowrunner" in self._runner.lower():
+ self.skipTest(
+ "Skipping some tests on Dataflow Runner to avoid bloat and timeouts")
+
+ @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)])
+ def test_default_config(self, db_config):
+ """Test basic write and read operations with default configuration.
+
+ This test validates the most basic CloudSQL vector database functionality:
+ - Default table schema: id (VARCHAR), content (TEXT), embedding (VECTOR),
+ metadata (JSON/JSONB)
+ - Default column specifications (no customization)
+ - Default conflict resolution (IGNORE on primary key conflicts)
+ - Write chunks to database and read them back
+ - Verify data integrity through count, hash, and sample validation
+ """
+ self.skip_if_dataflow_runner()
+
+ helper = self.db_helpers[db_config.database_type]
+ num_records = 150
+
+ # Create test data
+ test_chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records)
+
+ # Write test
+ writer_config = helper.create_writer_config()
+ self.write_test_pipeline.not_use_test_runner_api = True
+ with self.write_test_pipeline as p:
+ _ = (
+ p | beam.Create(test_chunks)
+ | VectorDatabaseWriteTransform(writer_config))
+
+ # Read and verify
+ self.read_test_pipeline.not_use_test_runner_api = True
+ jdbc_params = PipelineVerificationHelper.build_jdbc_params(
+ helper, helper.default_table_name)
+ PipelineVerificationHelper.verify_standard_operations(
+ self.read_test_pipeline, jdbc_params, test_chunks)
+
+ @parameterized.expand([
+ (POSTGRES_CONFIG, "UPDATE", ["embedding", "content"]),
+ (MYSQL_CONFIG, "UPDATE", ["embedding", "content"]),
+ (POSTGRES_CONFIG, "IGNORE", None),
+ (MYSQL_CONFIG, "IGNORE", None),
+ (POSTGRES_CONFIG, "UPDATE_ALL", None), # Default update fields
+ (MYSQL_CONFIG, "UPDATE_ALL", None),
+ ])
+ def test_conflict_resolution(self, db_config, action, update_fields):
+ """Test conflict resolution strategies when primary key conflicts occur.
+
+ This test validates different approaches to handling duplicate primary
+ keys:
+
+ UPDATE with specific fields:
+ - When duplicate ID encountered, update only specified fields (embedding,
+ content)
+ - Other fields (metadata) remain unchanged from original record
+
+ IGNORE:
+ - When duplicate ID encountered, keep original record unchanged
+
+ UPDATE_ALL (default update fields):
+ - When duplicate ID encountered, update ALL non-key fields
+ - This includes content, embedding, AND metadata
+
+ Scenario for all strategies:
+ 1. Insert initial records
+ 2. Insert records with same IDs but different content/embeddings
+ 3. Verify final state matches expected conflict resolution behavior
+ """
+ self.skip_if_dataflow_runner()
+
+ helper = self.db_helpers[db_config.database_type]
+ num_records = 20
+
+ common_module = db_config.common_module
+ if action == "IGNORE":
+ if db_config.database_type == "mysql":
+ conflict_resolution = common_module.ConflictResolution(
+ action="IGNORE", primary_key_field="id")
+ else:
+ conflict_resolution = None # Default behavior for PostgreSQL
+ elif action == "UPDATE":
+ if db_config.database_type == "postgresql":
+ conflict_resolution = common_module.ConflictResolution(
+ on_conflict_fields="id",
+ action="UPDATE",
+ update_fields=update_fields)
+ else:
+ conflict_resolution = common_module.ConflictResolution(
+ action="UPDATE", update_fields=update_fields)
+ else: # UPDATE_ALL
+ if db_config.database_type == "postgresql":
+ conflict_resolution = common_module.ConflictResolution(
+ on_conflict_fields="id", action="UPDATE")
+ else:
+ conflict_resolution = common_module.ConflictResolution(action="UPDATE")
+
+ initial_chunks = test_utils.ChunkTestUtils.get_expected_values(
+ 0, num_records)
+ writer_config = helper.create_writer_config(
+ conflict_resolution=conflict_resolution)
+
+ self.write_test_pipeline.not_use_test_runner_api = True
+ with self.write_test_pipeline as p:
+ _ = (
+ p | "Write Initial" >> beam.Create(initial_chunks)
+ | VectorDatabaseWriteTransform(writer_config))
+
+ # Write conflicting data
+ updated_chunks = test_utils.ChunkTestUtils.get_expected_values(
+ 0, num_records, content_prefix="Updated", seed_multiplier=2)
+
+ self.write_test_pipeline2.not_use_test_runner_api = True
+ with self.write_test_pipeline2 as p:
+ _ = (
+ p | "Write Conflicts" >> beam.Create(updated_chunks)
+ | VectorDatabaseWriteTransform(writer_config))
+
+ jdbc_params = PipelineVerificationHelper.build_jdbc_params(
+ helper, helper.default_table_name)
+ expected_chunks = updated_chunks if action != "IGNORE" else initial_chunks
+
+ self.read_test_pipeline.not_use_test_runner_api = True
+ with self.read_test_pipeline as p:
+ rows = (p | ReadFromJdbc(**jdbc_params))
+
+ count_result = rows | "Count All" >> beam.combiners.Count.Globally()
+ assert_that(count_result, equal_to([num_records]), label='count_check')
+
+ chunks = rows | "To Chunks" >> beam.Map(test_utils.row_to_chunk)
+ assert_that(chunks, equal_to(expected_chunks), label='chunks_check')
+
+ @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)])
+ def test_custom_column_names_and_value_functions(self, db_config):
+ """Test completely custom column specifications with custom value
+ extraction.
+
+ This test validates advanced customization of how chunk data is stored:
+
+ Custom column names:
+ - custom_id (instead of 'id')
+ - embedding_vec (instead of 'embedding')
+ - content_col (instead of 'content')
+
+ Custom value extraction functions:
+ - ID: Extract timestamp from metadata and prefix with "timestamp_"
+ - Content: Prefix content with its character length "10:actual_content"
+ - Embedding: Use custom embedding extraction function
+
+ This tests the flexibility to completely reshape how chunk data maps
+ to database columns, useful for integrating with existing database schemas
+ or applying business-specific transformations.
+ """
+ self.skip_if_dataflow_runner()
+
+ helper = self.db_helpers[db_config.database_type]
+ num_records = 20
+ common_module = db_config.common_module
+
+ test_chunks = [
+ Chunk(
+ id=str(i),
+ content=Content(text=f"content_{i}"),
+ embedding=Embedding(dense_embedding=[float(i), float(i + 1)]),
+ metadata={"timestamp": f"2024-02-02T{i:02d}:00:00"})
+ for i in range(num_records)
+ ]
+
+ chunk_embedding_fn = common_module.chunk_embedding_fn
+ specs = (
+ common_module.ColumnSpecsBuilder().add_custom_column_spec(
+ common_module.ColumnSpec.text(
+ column_name="custom_id",
+ value_fn=lambda chunk:
+ f"timestamp_{chunk.metadata.get('timestamp', '')}")
+ ).add_custom_column_spec(
+ common_module.ColumnSpec.vector(
+ column_name="embedding_vec",
+ value_fn=chunk_embedding_fn)).add_custom_column_spec(
+ common_module.ColumnSpec.text(
+ column_name="content_col",
+ value_fn=lambda chunk:
+ f"{len(chunk.content.text)}:{chunk.content.text}")).
+ with_metadata_spec().build())
+
+ def custom_row_to_chunk(row):
+ timestamp = row.custom_id.split('timestamp_')[1]
+ i = int(timestamp.split('T')[1][:2])
+
+ embedding_list = [
+ float(x) for x in row.embedding_vec.strip('[]').split(',')
+ ]
+
+ content = row.content_col.split(':', 1)[1]
+
+ return Chunk(
+ id=str(i),
+ content=Content(text=content),
+ embedding=Embedding(dense_embedding=embedding_list),
+ metadata=json.loads(row.metadata))
+
+ writer_config = helper.create_writer_config(helper.custom_table_name, specs)
+ self.write_test_pipeline.not_use_test_runner_api = True
+ with self.write_test_pipeline as p:
+ _ = (
+ p | beam.Create(test_chunks)
+ | VectorDatabaseWriteTransform(writer_config))
+
+ jdbc_params = PipelineVerificationHelper.build_jdbc_params(
+ helper, helper.custom_table_name)
+
+ self.read_test_pipeline.not_use_test_runner_api = True
+ with self.read_test_pipeline as p:
+ rows = (p | ReadFromJdbc(**jdbc_params))
+
+ count_result = rows | "Count All" >> beam.combiners.Count.Globally()
+ assert_that(count_result, equal_to([num_records]), label='count_check')
+
+ chunks = rows | "To Chunks" >> beam.Map(custom_row_to_chunk)
+ assert_that(chunks, equal_to(test_chunks), label='chunks_check')
+
+ @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)])
+ def test_custom_type_conversion_with_default_columns(self, db_config):
+ """Test custom type conversion and SQL typecasting with modified column
+ names.
+
+ This test validates data type handling and database-specific SQL features:
+
+ Type conversion:
+ - Convert string IDs to integers before storage
+ - Apply length-prefix transformation to content
+
+ SQL typecasting (database-specific):
+ - PostgreSQL: Use ::text typecast for converted integers
+ - MySQL: Rely on automatic type conversion (no explicit typecast)
+
+ Column name customization:
+ - Use custom names but with standard spec builders (not completely custom
+ functions)
+
+ This tests the ability to adapt data types for database constraints
+ while maintaining the standard chunk-to-database mapping logic.
+ """
+ self.skip_if_dataflow_runner()
+
+ helper = self.db_helpers[db_config.database_type]
+ num_records = 20
+ common_module = db_config.common_module
+
+ test_chunks = [
+ Chunk(
+ id=str(i),
+ content=Content(text=f"content_{i}"),
+ embedding=Embedding(dense_embedding=[float(i), float(i + 1)]),
+ metadata={"timestamp": f"2024-02-02T{i:02d}:00:00"})
+ for i in range(num_records)
+ ]
+
+ if db_config.database_type == "postgresql":
+ specs = (
+ common_module.ColumnSpecsBuilder().with_id_spec(
+ column_name="custom_id",
+ python_type=int,
+ convert_fn=lambda x: int(x),
+ sql_typecast="::text").with_content_spec(
+ column_name="content_col",
+ convert_fn=lambda x: f"{len(x)}:{x}" # Add length prefix
+ ).with_embedding_spec(
+ column_name="embedding_vec").with_metadata_spec().build())
+ else: # MySQL
+ specs = (
+ common_module.ColumnSpecsBuilder().with_id_spec(
+ column_name="custom_id",
+ python_type=int,
+ convert_fn=lambda x: int(x)).with_content_spec(
+ column_name="content_col",
+ convert_fn=lambda x: f"{len(x)}:{x}").with_embedding_spec(
+
column_name="embedding_vec").with_metadata_spec().build())
+
+ def type_conversion_row_to_chunk(row):
+ embedding_list = [
+ float(x) for x in row.embedding_vec.strip('[]').split(',')
+ ]
+
+ content = row.content_col.split(':', 1)[1]
+
+ return Chunk(
+ id=row.custom_id, # custom_id is the converted ID field
+ content=Content(text=content),
+ embedding=Embedding(dense_embedding=embedding_list),
+ metadata=json.loads(row.metadata))
+
+ writer_config = helper.create_writer_config(helper.custom_table_name,
specs)
+ self.write_test_pipeline.not_use_test_runner_api = True
+ with self.write_test_pipeline as p:
+ _ = (
+ p | beam.Create(test_chunks)
+ | VectorDatabaseWriteTransform(writer_config))
+
+ jdbc_params = PipelineVerificationHelper.build_jdbc_params(
+ helper, helper.custom_table_name)
+
+ self.read_test_pipeline.not_use_test_runner_api = True
+ with self.read_test_pipeline as p:
+ rows = (p | ReadFromJdbc(**jdbc_params))
+
+ count_result = rows | "Count All" >> beam.combiners.Count.Globally()
+ assert_that(count_result, equal_to([num_records]), label='count_check')
+
+ chunks = rows | "To Chunks" >> beam.Map(type_conversion_row_to_chunk)
+ assert_that(chunks, equal_to(test_chunks), label='chunks_check')
+
+ @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)])
+ def test_default_id_embedding_specs(self, db_config):
+ """Test minimal schema with only ID and embedding columns.
+
+ This test validates writing to a minimal vector database schema:
+ - Only the id and embedding fields are written
+ - content and metadata columns are not part of the column specs
+ - Reading back tolerates the missing content/metadata fields
+
+ Use case: When you only need vector similarity search without storing
+ the original content or metadata (perhaps stored elsewhere).
+
+ Validation:
+ - Chunks are written with content/metadata, but only id and embedding
+ are persisted
+ - Reading back produces chunks with null content and empty metadata
+ - Embeddings round-trip unchanged through the id/embedding-only table
+ """
+ self.skip_if_dataflow_runner()
+
+ helper = self.db_helpers[db_config.database_type]
+ num_records = 20
+ common_module = db_config.common_module
+
+ specs = (
+
common_module.ColumnSpecsBuilder().with_id_spec().with_embedding_spec().
+ build())  # only id + embedding are written; content/metadata omitted
+
+ writer_config = helper.create_writer_config(column_specs=specs)  # no table name passed; rows are read back from helper.default_table_name
+
+ test_chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records)  # inputs still carry content and metadata
+
+ with self.write_test_pipeline as p:
+ _ = (
+ p | beam.Create(test_chunks)
+ | VectorDatabaseWriteTransform(writer_config))
+
+ expected_chunks = test_utils.ChunkTestUtils.get_expected_values(
+ 0, num_records)
+ for chunk in expected_chunks:
+ chunk.content.text = None # Content column not included in schema
+ chunk.metadata = {} # Metadata column not included in schema
+
+ jdbc_params = PipelineVerificationHelper.build_jdbc_params(
+ helper, helper.default_table_name)
+ if db_config.database_type == "postgresql":  # explicit SELECT of only the two stored columns, cast to text
+ jdbc_params['query'] = f"""
+ SELECT
+ CAST(id AS VARCHAR(255)),
+ CAST(embedding AS text)
+ FROM {helper.default_table_name}
+ ORDER BY id
+ """
+ elif db_config.database_type == "mysql":  # MySQL renders the vector via vector_to_string()
+ jdbc_params['query'] = f"""
+ SELECT
+ CAST(id AS CHAR(255)) as id,
+ vector_to_string(embedding) as embedding
+ FROM {helper.default_table_name}
+ """
+
+ with self.read_test_pipeline as p:
+ rows = (p | ReadFromJdbc(**jdbc_params))
+ chunks = rows | "To Chunks" >> beam.Map(test_utils.row_to_chunk)  # assumes row_to_chunk tolerates absent content/metadata columns - confirm
+ assert_that(chunks, equal_to(expected_chunks), label='chunks_check')
+
@parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)])
def test_metadata_field_extraction(self, db_config):
  """Verify that individual metadata keys can be written to their own columns.

  Instead of serializing the whole metadata dict into a single JSON column,
  the ``source`` and ``timestamp`` metadata entries are mapped onto dedicated
  columns. PostgreSQL casts the timestamp with ``::timestamp``; MySQL uses
  its default conversion. Reading the rows back must reproduce the original
  chunks, including the metadata values.
  """
  self.skip_if_dataflow_runner()

  db_helper = self.db_helpers[db_config.database_type]
  record_count = 20

  # id, embedding and content columns are shared by both databases.
  builder = (
      db_config.common_module.ColumnSpecsBuilder().with_id_spec().
      with_embedding_spec().with_content_spec())
  if db_config.database_type == "postgresql":
    builder = builder.add_metadata_field(
        field="source",
        column_name="source",
        python_type=str,
        sql_typecast=None).add_metadata_field(
            field="timestamp", python_type=str, sql_typecast="::timestamp")
  else:
    builder = builder.add_metadata_field(
        field="source", column_name="source",
        python_type=str).add_metadata_field(
            field="timestamp", python_type=str)
  specs = builder.build()

  writer_config = db_helper.create_writer_config(
      db_helper.metadata_conflicts_table, specs, conflict_resolution=None)

  test_chunks = [
      Chunk(
          id=str(n),
          content=Content(text=f"content_{n}"),
          embedding=Embedding(dense_embedding=[float(n), float(n + 1)]),
          metadata={
              "source": f"source_{n % 3}",
              "timestamp": f"2024-02-02T{n:02d}:00:00",
          }) for n in range(record_count)
  ]

  self.write_test_pipeline.not_use_test_runner_api = True
  with self.write_test_pipeline as p:
    _ = (
        p
        | beam.Create(test_chunks)
        | VectorDatabaseWriteTransform(writer_config))

  def row_to_chunk(row):
    vector = [float(v) for v in row.embedding.strip('[]').split(',')]
    # Timestamps may come back with a space separator; restore the ISO "T"
    # used when writing (a no-op when there is no space).
    iso_timestamp = row.timestamp.replace(' ', 'T')
    return Chunk(
        id=row.id,
        content=Content(text=row.content),
        embedding=Embedding(dense_embedding=vector),
        metadata={
            "source": row.source, "timestamp": iso_timestamp
        })

  jdbc_params = PipelineVerificationHelper.build_jdbc_params(
      db_helper, db_helper.metadata_conflicts_table)

  self.read_test_pipeline.not_use_test_runner_api = True
  with self.read_test_pipeline as p:
    chunks = (
        p
        | ReadFromJdbc(**jdbc_params)
        | "To Chunks" >> beam.Map(row_to_chunk))
    assert_that(chunks, equal_to(test_chunks), label='chunks_check')
+
@parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)])
def test_composite_unique_constraint_conflicts(self, db_config):
  """Exercise UPDATE-on-conflict where the unique key spans two columns.

  The conflicting key is not the primary key but the (source, timestamp)
  pair:

  1. Write chunks with distinct (source, timestamp) pairs.
  2. Write a second batch that reuses every pair, but with new ids, new
     content and doubled embeddings.
  3. Expect the original rows (original ids) to now hold the updated
     content and embeddings, with no extra rows added.

  Unlike test_conflict_resolution, the conflict here is not on the primary
  key column alone.
  """
  self.skip_if_dataflow_runner()

  db_helper = self.db_helpers[db_config.database_type]
  record_count = 5
  common = db_config.common_module
  db_type = db_config.database_type

  if db_type == "postgresql":
    specs = (
        common.ColumnSpecsBuilder().with_id_spec().with_embedding_spec().
        with_content_spec().add_metadata_field(
            field="source",
            column_name="source",
            python_type=str,
            sql_typecast=None).add_metadata_field(
                field="timestamp",
                python_type=str,
                sql_typecast="::timestamp").build())
    # PostgreSQL needs the conflict target spelled out.
    conflict_resolution = common.ConflictResolution(
        on_conflict_fields=["source", "timestamp"],
        action="UPDATE",
        update_fields=["embedding", "content"])
  elif db_type == "mysql":
    specs = (
        common.ColumnSpecsBuilder().with_id_spec().with_embedding_spec().
        with_content_spec().add_metadata_field(
            field="source", column_name="source",
            python_type=str).add_metadata_field(
                field="timestamp", python_type=str).build())
    # ON DUPLICATE KEY UPDATE picks up the table's unique key by itself.
    conflict_resolution = common.ConflictResolution(
        action="UPDATE", update_fields=["embedding", "content"])

  writer_config = db_helper.create_writer_config(
      db_helper.metadata_conflicts_table, specs, conflict_resolution)

  def make_chunks(make_id, text_prefix, factor):
    """Chunks that all share source_A and have one timestamp per index."""
    return [
        Chunk(
            id=make_id(i),
            content=Content(text=f"{text_prefix}{i}"),
            embedding=Embedding(
                dense_embedding=[float(i) * factor, float(i + 1) * factor]),
            metadata={
                "source": "source_A",
                "timestamp": f"2024-02-02T{i:02d}:00:00",
            }) for i in range(record_count)
    ]

  initial_chunks = make_chunks(str, "content_", 1)
  conflicting_chunks = make_chunks(
      lambda i: f"new_{i}", "updated_content_", 2)
  # Original ids survive; content and embedding come from the second batch.
  expected_chunks = make_chunks(str, "updated_content_", 2)

  with self.write_test_pipeline as p:
    _ = (
        p
        | "Write Initial" >> beam.Create(initial_chunks)
        | VectorDatabaseWriteTransform(writer_config))

  with self.write_test_pipeline2 as p:
    _ = (
        p
        | "Write Conflicts" >> beam.Create(conflicting_chunks)
        | VectorDatabaseWriteTransform(writer_config))

  def row_to_chunk(row):
    vector = [float(v) for v in row.embedding.strip('[]').split(',')]
    # Restore the ISO "T" separator if the timestamp comes back with a space.
    iso_timestamp = row.timestamp.replace(' ', 'T')
    return Chunk(
        id=row.id,
        content=Content(text=row.content),
        embedding=Embedding(dense_embedding=vector),
        metadata={
            "source": row.source, "timestamp": iso_timestamp
        })

  jdbc_params = PipelineVerificationHelper.build_jdbc_params(
      db_helper, db_helper.metadata_conflicts_table)

  with self.read_test_pipeline as p:
    rows = p | ReadFromJdbc(**jdbc_params)

    count_result = rows | "Count All" >> beam.combiners.Count.Globally()
    assert_that(count_result, equal_to([record_count]), label='count_check')

    chunks = rows | "To Chunks" >> beam.Map(row_to_chunk)
    assert_that(chunks, equal_to(expected_chunks), label='chunks_check')
+
+
if __name__ == '__main__':
  # Show INFO-level pipeline logs when the suite is run directly.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/mysql.py
b/sdks/python/apache_beam/ml/rag/ingestion/mysql.py
new file mode 100644
index 00000000000..c64c083b6c9
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/ingestion/mysql.py
@@ -0,0 +1,268 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+from abc import ABC
+from abc import abstractmethod
+from typing import Callable
+from typing import List
+from typing import NamedTuple
+from typing import Optional
+
+import apache_beam as beam
+from apache_beam.coders import registry
+from apache_beam.coders.row_coder import RowCoder
+from apache_beam.io.jdbc import WriteToJdbc
+from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
+from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig
+from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
+from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpec
+from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpecsBuilder
+from apache_beam.ml.rag.ingestion.mysql_common import ConflictResolution
+from apache_beam.ml.rag.types import Chunk
+
# Module-scoped logger; build_insert logs the generated SQL through it.
_LOGGER = logging.getLogger(name=__name__)
+
+
class _ConflictResolutionStrategy(ABC):
  """Interface for rendering the conflict-handling tail of a MySQL INSERT."""
  @abstractmethod
  def get_conflict_clause(self, all_columns: List[str]) -> str:
    """Return the SQL to append after the INSERT's ``VALUES (...)`` list.

    Args:
      all_columns: Names of every column the INSERT statement writes.

    Returns:
      The clause text, or an empty string when nothing should be appended.
    """
    ...
+
+
class _NoConflictStrategy(_ConflictResolutionStrategy):
  """Plain INSERT: conflicts are not handled, so no clause is appended."""
  def get_conflict_clause(self, all_columns: List[str]) -> str:
    del all_columns  # Unused: there is nothing to render.
    return ""
+
+
class _UpdateStrategy(_ConflictResolutionStrategy):
  """Strategy for UPDATE action on conflict.

  Renders ``ON DUPLICATE KEY UPDATE col = VALUES(col), ...`` so that the
  chosen columns of the existing row are overwritten with the incoming
  values. NOTE: newer MySQL 8.0 releases deprecate ``VALUES()`` here in favor
  of row aliases; ``VALUES()`` is kept for compatibility with older servers.
  """
  def __init__(self, update_fields: Optional[List[str]] = None):
    """Args:
      update_fields: Columns to overwrite on conflict. If None or empty,
        every column written by the INSERT is updated.
    """
    self.update_fields = update_fields

  def get_conflict_clause(self, all_columns: List[str]) -> str:
    """Return the ``ON DUPLICATE KEY UPDATE`` clause.

    Args:
      all_columns: Columns written by the INSERT; used when no explicit
        ``update_fields`` were configured.

    Raises:
      ValueError: If there is no column to update. This used to be an
        ``assert``, which is silently skipped when Python runs with -O.
    """
    fields_to_update = self.update_fields or all_columns
    if not fields_to_update:
      raise ValueError(
          "ON DUPLICATE KEY UPDATE requires at least one column to update, "
          "but neither update_fields nor any table columns were provided.")

    assignments = ', '.join(
        f"{field} = VALUES({field})" for field in fields_to_update)
    return f"ON DUPLICATE KEY UPDATE {assignments}"
+
+
class _IgnoreStrategy(_ConflictResolutionStrategy):
  """Skip conflicting rows by issuing a no-op update of the primary key.

  MySQL has no ``DO NOTHING``; assigning the key to itself leaves the
  existing row untouched while still suppressing the duplicate-key error.
  """
  def __init__(self, primary_key_field: str):
    self.primary_key_field = primary_key_field

  def get_conflict_clause(self, all_columns: List[str]) -> str:
    key = self.primary_key_field
    return "ON DUPLICATE KEY UPDATE {0} = {0}".format(key)
+
+
def _create_conflict_strategy(
    conflict_resolution: Optional[ConflictResolution]
) -> _ConflictResolutionStrategy:
  """Map a ConflictResolution to the strategy that renders its SQL clause.

  Args:
    conflict_resolution: The user's conflict configuration, or None for a
      plain INSERT.

  Returns:
    The matching strategy instance.

  Raises:
    ValueError: If the action is unknown, or IGNORE has no primary key field.
  """
  if conflict_resolution is None:
    return _NoConflictStrategy()

  action = conflict_resolution.action
  if action == "UPDATE":
    return _UpdateStrategy(conflict_resolution.update_fields)
  if action == "IGNORE":
    # ConflictResolution validates this in __post_init__, but the dataclass
    # can be mutated afterwards, so re-check here. An assert would be
    # stripped under -O.
    if conflict_resolution.primary_key_field is None:
      raise ValueError("primary_key_field is required when action='IGNORE'")
    return _IgnoreStrategy(conflict_resolution.primary_key_field)
  raise ValueError(f"Unknown conflict resolution {action}")
+
+
class _MySQLQueryBuilder:
  """Builds SQL queries for writing Chunks with Embeddings to MySQL.

  It also creates, and registers a RowCoder for, the NamedTuple record type
  that carries one Chunk's column values to JDBC.
  """
  def __init__(
      self,
      table_name: str,
      *,
      column_specs: List[ColumnSpec],
      conflict_resolution: Optional[ConflictResolution] = None):
    """Args:
      table_name: Target table; interpolated verbatim into the INSERT.
      column_specs: One spec per column written, in INSERT column order.
      conflict_resolution: Optional ON DUPLICATE KEY handling.

    Raises:
      ValueError: If two column specs share a column name, or the conflict
        resolution is invalid.
    """
    self.table_name = table_name

    self.column_specs = column_specs
    self.conflict_resolution_strategy = _create_conflict_strategy(
        conflict_resolution)

    # Single pass instead of list.count() per name, which is O(n^2).
    seen = set()
    duplicates = set()
    for spec in self.column_specs:
      if spec.column_name in seen:
        duplicates.add(spec.column_name)
      seen.add(spec.column_name)
    if duplicates:
      raise ValueError(f"Duplicate column names found: {duplicates}")

    fields = [(col.column_name, col.python_type) for col in self.column_specs]
    type_name = self._record_type_name(table_name)
    self.record_type = NamedTuple(type_name, fields)  # type: ignore

    registry.register_coder(self.record_type, RowCoder)

  @staticmethod
  def _record_type_name(table_name: str) -> str:
    """Derive a valid Python identifier to name the record NamedTuple.

    NamedTuple rejects type names that are not identifiers. Schema-qualified
    (``db.embeddings``) or quoted (```my-table```) table names would
    otherwise fail. Names that are already valid are left unchanged.
    """
    type_name = f"VectorRecord_{table_name}"
    if type_name.isidentifier():
      return type_name
    safe = ''.join(
        ch if ch.isascii() and ch.isalnum() else '_' for ch in table_name)
    return f"VectorRecord_{safe}"

  def build_insert(self) -> str:
    """Return the INSERT statement with ``?`` bind markers and conflict clause.
    """
    fields = [col.column_name for col in self.column_specs]
    placeholders = [col.placeholder for col in self.column_specs]

    # Build base query
    query = f"""
      INSERT INTO {self.table_name}
      ({', '.join(fields)})
      VALUES ({', '.join(placeholders)})
    """
    conflict_clause = self.conflict_resolution_strategy.get_conflict_clause(
        all_columns=fields)
    # Only append when there is something to append (no dangling space).
    if conflict_clause:
      query += f" {conflict_clause}"

    _LOGGER.info("MySQL Query with placeholders %s", query)
    return query

  def create_converter(self) -> Callable[[Chunk], NamedTuple]:
    """Creates a function to convert Chunks to records."""
    def convert(chunk: Chunk) -> self.record_type:  # type: ignore
      return self.record_type(
          **{col.column_name: col.value_fn(chunk)
             for col in self.column_specs})  # type: ignore

    return convert
+
+
class MySQLVectorWriterConfig(VectorDatabaseWriteConfig):
  """Configuration for writing Chunk embeddings to MySQL via JDBC."""
  def __init__(
      self,
      connection_config: ConnectionConfig,
      table_name: str,
      *,
      # pylint: disable=dangerous-default-value
      write_config: WriteConfig = WriteConfig(),
      column_specs: Optional[List[ColumnSpec]] = (
          ColumnSpecsBuilder.with_defaults().build()),
      conflict_resolution: Optional[ConflictResolution] = None):
    """Configuration for writing vectors to MySQL using jdbc.

    Supports flexible schema configuration through column specifications and
    conflict resolution strategies with MySQL-specific syntax.

    Args:
      connection_config:
        :class:`~apache_beam.ml.rag.ingestion.jdbc_common.ConnectionConfig`.
      table_name: Target table name.
      write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control
        batch sizes, autosharding, etc.
      column_specs:
        Use :class:`~.mysql_common.ColumnSpecsBuilder` to configure how
        embeddings and metadata are written to the database schema. If None,
        uses the default Chunk schema (id, embedding, content, metadata) with
        MySQL vector functions.
      conflict_resolution: Optional
        :class:`~.mysql_common.ConflictResolution` strategy for handling
        insert conflicts, rendered as ``ON DUPLICATE KEY UPDATE``.
        None by default, meaning errors are thrown when attempting to insert
        duplicates.

    Examples:
      Simple case with default schema:

      >>> config = MySQLVectorWriterConfig(
      ...     connection_config=ConnectionConfig(...),
      ...     table_name='embeddings'
      ... )

      Custom schema with metadata fields and a MySQL function placeholder:

      >>> specs = (ColumnSpecsBuilder()
      ...          .with_id_spec(column_name="my_id_column")
      ...          .with_embedding_spec(column_name="embedding_vec")
      ...          .add_metadata_field(
      ...              field="source", python_type=str, column_name="src")
      ...          .add_custom_column_spec(
      ...              ColumnSpec(
      ...                  column_name="created_at",
      ...                  python_type=str,
      ...                  value_fn=lambda c: c.metadata.get("timestamp"),
      ...                  placeholder="STR_TO_DATE(?, '%Y-%m-%d %H:%i:%s')"))
      ...          .build())

      Minimal schema (only ID + embedding written) with conflict handling:

      >>> specs = (ColumnSpecsBuilder()
      ...          .with_id_spec()
      ...          .with_embedding_spec()
      ...          .build())
      >>> config = MySQLVectorWriterConfig(
      ...     connection_config=ConnectionConfig(...),
      ...     table_name='embeddings',
      ...     column_specs=specs,
      ...     conflict_resolution=ConflictResolution(
      ...         action="UPDATE",
      ...         update_fields=["embedding"]
      ...     )
      ... )

      Using MySQL JSON functions:

      >>> specs = (ColumnSpecsBuilder()
      ...          .with_id_spec()
      ...          .with_embedding_spec()
      ...          .add_custom_column_spec(
      ...              ColumnSpec(
      ...                  column_name="metadata_json",
      ...                  python_type=str,
      ...                  value_fn=lambda c: json.dumps(c.metadata),
      ...                  placeholder="CAST(? AS JSON)"))
      ...          .build())
    """
    # Honor the documented "None means default schema" contract.
    if column_specs is None:
      column_specs = ColumnSpecsBuilder.with_defaults().build()
    self.connection_config = connection_config
    self.write_config = write_config
    # NamedTuple is created and registered here during pipeline construction
    self.query_builder = _MySQLQueryBuilder(
        table_name,
        column_specs=column_specs,
        conflict_resolution=conflict_resolution)

  def create_write_transform(self) -> beam.PTransform:
    """Return the PTransform that writes Chunks to the configured table."""
    return _WriteToMySQLVectorDatabase(self)
+
+
class _WriteToMySQLVectorDatabase(beam.PTransform):
  """Implementation of MySQL vector database write.

  Converts each Chunk into the config's NamedTuple record and writes the
  records with :class:`~apache_beam.io.jdbc.WriteToJdbc` using the
  ``com.mysql.cj.jdbc.Driver`` driver class.
  """
  def __init__(self, config: MySQLVectorWriterConfig):
    # Run PTransform's own initializer so base-class state is set up.
    super().__init__()
    self.config = config
    self.query_builder = config.query_builder
    self.connection_config = config.connection_config
    self.write_config = config.write_config

  def expand(self, pcoll: beam.PCollection[Chunk]):
    """Convert Chunks to records and write them to MySQL.

    Args:
      pcoll: Chunks to write.

    Returns:
      The output of the WriteToJdbc transform.
    """
    return (
        pcoll
        | "Convert to Records" >> beam.Map(
            self.query_builder.create_converter())
        | "Write to MySQL" >> WriteToJdbc(
            table_name=self.query_builder.table_name,
            driver_class_name="com.mysql.cj.jdbc.Driver",
            jdbc_url=self.connection_config.jdbc_url,
            username=self.connection_config.username,
            password=self.connection_config.password,
            statement=self.query_builder.build_insert(),
            connection_properties=self.connection_config.connection_properties,
            connection_init_sqls=self.connection_config.connection_init_sqls,
            autosharding=self.write_config.autosharding,
            max_connections=self.write_config.max_connections,
            write_batch_size=self.write_config.write_batch_size,
            **self.connection_config.additional_jdbc_args))
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py
b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py
new file mode 100644
index 00000000000..c1ee703a5f2
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py
@@ -0,0 +1,433 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from dataclasses import dataclass
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Optional
+from typing import Type
+
+from apache_beam.ml.rag.types import Chunk
+
+
def chunk_embedding_fn(chunk: Chunk) -> str:
  """Render a chunk's dense embedding in MySQL's vector text syntax.

  Values are joined with commas (no spaces) and wrapped in square brackets,
  so ``[1.0, 2.0]`` becomes ``'[1.0,2.0]'``. This is the default value
  function for :meth:`ColumnSpec.vector`, whose placeholder is
  ``string_to_vector(?)``.

  Args:
    chunk: The Chunk whose ``embedding.dense_embedding`` is serialized.

  Returns:
    The bracketed, comma-separated vector string.

  Raises:
    ValueError: If the chunk has no embedding or no dense embedding.
  """
  embedding = chunk.embedding
  if embedding is None or embedding.dense_embedding is None:
    raise ValueError(f'Expected chunk to contain embedding. {chunk}')
  body = ','.join(map(str, embedding.dense_embedding))
  return f'[{body}]'
+
+
@dataclass
class ColumnSpec:
  """Describes how a single MySQL column is populated from a Chunk.

  For every Chunk that is written:

  1. ``value_fn`` is called on the Chunk to produce the column value.
  2. The value is stored in a NamedTuple field typed as ``python_type``.
  3. The INSERT binds it through ``placeholder`` (a bare ``?`` by default).

  Attributes:
    column_name: Name of the target column in the database table.
    python_type: Type of the NamedTuple field holding the value; it must be
      encodable by :class:`~apache_beam.coders.row_coder.RowCoder`.
    value_fn: Callable mapping a Chunk to a value of ``python_type``.
    placeholder: SQL expression containing the ``?`` bind marker. Use it to
      wrap the value in a function or cast, e.g. ``"string_to_vector(?)"``
      for vector columns. Defaults to ``'?'``.

  Examples:

    A text column using plain JDBC type mapping:

    >>> ColumnSpec.text(
    ...     column_name="content",
    ...     value_fn=lambda chunk: chunk.content.text
    ... )
    ... # Results in: INSERT INTO table (content) VALUES (?)

    A timestamp pulled from metadata:

    >>> ColumnSpec(
    ...     column_name="created_at",
    ...     python_type=str,
    ...     value_fn=lambda chunk: chunk.metadata.get("timestamp")
    ... )
    ... # Results in: INSERT INTO table (created_at) VALUES (?)

  Factory Methods:
    text: ``str`` column.
    integer: ``int`` column.
    float: ``float`` column.
    vector: ``str`` column bound through ``string_to_vector(?)``.
    json: ``str`` column for JSON text.
  """
  column_name: str
  python_type: Type
  value_fn: Callable[[Chunk], Any]
  placeholder: str = '?'

  @classmethod
  def text(
      cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec':
    """Create a text column specification."""
    return cls(column_name=column_name, python_type=str, value_fn=value_fn)

  @classmethod
  def integer(
      cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec':
    """Create an integer column specification."""
    return cls(column_name=column_name, python_type=int, value_fn=value_fn)

  @classmethod
  def float(
      cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec':
    """Create a float column specification."""
    # Inside a method body, `float` resolves to the builtin, not this method.
    return cls(column_name=column_name, python_type=float, value_fn=value_fn)

  @classmethod
  def vector(
      cls,
      column_name: str,
      value_fn: Callable[[Chunk], Any] = chunk_embedding_fn) -> 'ColumnSpec':
    """Create a vector column bound through ``string_to_vector()``."""
    return cls(
        column_name=column_name,
        python_type=str,
        value_fn=value_fn,
        placeholder="string_to_vector(?)")

  @classmethod
  def json(
      cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec':
    """Create a JSON column specification (value_fn must return text)."""
    return cls(column_name=column_name, python_type=str, value_fn=value_fn)
+
+
def embedding_to_string(embedding: List[float]) -> str:
  """Serialize a list of floats as a MySQL vector literal.

  For example, ``[0.5, 1.5]`` becomes ``'[0.5,1.5]'``.
  """
  return '[{}]'.format(','.join(map(str, embedding)))
+
+
class ColumnSpecsBuilder:
  """Builder for lists of :class:`.ColumnSpec` with chainable methods.

  Each ``with_*`` / ``add_*`` method appends one spec and returns the
  builder. :meth:`build` returns a copy of the specs accumulated so far.
  """
  def __init__(self):
    self._specs: List[ColumnSpec] = []

  @staticmethod
  def with_defaults() -> 'ColumnSpecsBuilder':
    """Return a new builder holding the default id, embedding, content and
    metadata column specifications, in that order."""
    return (
        ColumnSpecsBuilder().with_id_spec().with_embedding_spec().
        with_content_spec().with_metadata_spec())

  def with_id_spec(
      self,
      column_name: str = "id",
      python_type: Type = str,
      convert_fn: Optional[Callable[[str],
                                    Any]] = None) -> 'ColumnSpecsBuilder':
    """Add ID :class:`.ColumnSpec` with optional type and conversion.

    Args:
      column_name: Name for the ID column (defaults to "id")
      python_type: Python type for the column (defaults to str)
      convert_fn: Optional function to convert the chunk ID.
        If None, uses ID as-is

    Returns:
      Self for method chaining

    Example:
      >>> builder.with_id_spec(
      ...     column_name="doc_id",
      ...     python_type=int,
      ...     convert_fn=lambda id: int(id.split('_')[1])
      ... )
    """
    def value_fn(chunk: Chunk) -> Any:
      value = chunk.id
      return convert_fn(value) if convert_fn is not None else value

    self._specs.append(
        ColumnSpec(
            column_name=column_name, python_type=python_type,
            value_fn=value_fn))
    return self

  def with_content_spec(
      self,
      column_name: str = "content",
      python_type: Type = str,
      convert_fn: Optional[Callable[[str],
                                    Any]] = None) -> 'ColumnSpecsBuilder':
    """Add content :class:`.ColumnSpec` with optional type and conversion.

    Args:
      column_name: Name for the content column (defaults to "content")
      python_type: Python type for the column (defaults to str)
      convert_fn: Optional function to convert the content text.
        If None, uses content text as-is

    Returns:
      Self for method chaining

    Raises:
      ValueError: At write time, if a chunk has no content text.

    Example:
      >>> builder.with_content_spec(
      ...     column_name="content_length",
      ...     python_type=int,
      ...     convert_fn=len  # Store content length instead of content
      ... )
    """
    def value_fn(chunk: Chunk) -> Any:
      if chunk.content.text is None:
        raise ValueError(f'Expected chunk to contain content. {chunk}')
      value = chunk.content.text
      return convert_fn(value) if convert_fn is not None else value

    self._specs.append(
        ColumnSpec(
            column_name=column_name, python_type=python_type,
            value_fn=value_fn))
    return self

  def with_metadata_spec(
      self,
      column_name: str = "metadata",
      python_type: Type = str,
      convert_fn: Optional[Callable[[Dict[str, Any]], Any]] = None
  ) -> 'ColumnSpecsBuilder':
    """Add metadata :class:`.ColumnSpec` with optional type and conversion.

    Args:
      column_name: Name for the metadata column (defaults to "metadata")
      python_type: Python type for the column (defaults to str)
      convert_fn: Optional function to convert the metadata dictionary.
        If None and python_type is str, the dict is serialized with
        ``json.dumps``; otherwise the dict is passed through unchanged.

    Returns:
      Self for method chaining

    Example:
      >>> builder.with_metadata_spec(
      ...     column_name="meta_tags",
      ...     python_type=str,
      ...     convert_fn=lambda meta: ','.join(meta.keys())
      ... )
    """
    def value_fn(chunk: Chunk) -> Any:
      if convert_fn is not None:
        return convert_fn(chunk.metadata)
      return json.dumps(
          chunk.metadata) if python_type is str else chunk.metadata

    self._specs.append(
        ColumnSpec(
            column_name=column_name, python_type=python_type,
            value_fn=value_fn))
    return self

  def with_embedding_spec(
      self,
      column_name: str = "embedding",
      convert_fn: Optional[Callable[[List[float]],
                                    Any]] = embedding_to_string
  ) -> 'ColumnSpecsBuilder':
    """Add embedding :class:`.ColumnSpec` with optional conversion.

    The column is bound through ``string_to_vector(?)``
    (see :meth:`ColumnSpec.vector`).

    Args:
      column_name: Name for the embedding column (defaults to "embedding")
      convert_fn: Optional function to convert the dense embedding values.
        Defaults to :func:`embedding_to_string`; passing None also selects
        that default MySQL vector format.

    Returns:
      Self for method chaining

    Raises:
      ValueError: At write time, if a chunk has no dense embedding.

    Example:
      >>> builder.with_embedding_spec(
      ...     column_name="embedding_vector",
      ...     convert_fn=lambda values: '[' + ','.join(f"{x:.4f}"
      ...         for x in values) + ']'
      ... )
    """
    # Previously None was documented but crashed when the pipeline ran.
    to_vector = embedding_to_string if convert_fn is None else convert_fn

    def value_fn(chunk: Chunk) -> Any:
      if chunk.embedding is None or chunk.embedding.dense_embedding is None:
        raise ValueError(f'Expected chunk to contain embedding. {chunk}')
      return to_vector(chunk.embedding.dense_embedding)

    self._specs.append(
        ColumnSpec.vector(column_name=column_name, value_fn=value_fn))
    return self

  def add_metadata_field(
      self,
      field: str,
      python_type: Type,
      column_name: Optional[str] = None,
      convert_fn: Optional[Callable[[Any], Any]] = None,
      default: Any = None) -> 'ColumnSpecsBuilder':
    """Add a :class:`.ColumnSpec` that extracts and converts a field from
    chunk metadata.

    Args:
      field: Key to extract from chunk metadata
      python_type: Python type for the column (e.g. str, int, float)
      column_name: Name for the column (defaults to metadata field name)
      convert_fn: Optional function to convert the extracted value to
        desired type. If None, value is used as-is. It is not called when
        the value (after applying ``default``) is None.
      default: Default value if field is missing from metadata

    Returns:
      Self for chaining

    Examples:
      Simple string field:
      >>> builder.add_metadata_field("source", str)

      Integer with default:
      >>> builder.add_metadata_field(
      ...     field="count",
      ...     python_type=int,
      ...     column_name="item_count",
      ...     default=0
      ... )

      Float with conversion and default:
      >>> builder.add_metadata_field(
      ...     field="confidence",
      ...     python_type=float,
      ...     convert_fn=lambda x: round(float(x), 2),
      ...     default=0.0
      ... )

      Timestamp with conversion:
      >>> builder.add_metadata_field(
      ...     field="created_at",
      ...     python_type=str,
      ...     convert_fn=lambda ts: ts.replace('T', ' ')
      ... )
    """
    name = column_name or field

    def value_fn(chunk: Chunk) -> Any:
      value = chunk.metadata.get(field, default)
      if value is not None and convert_fn is not None:
        value = convert_fn(value)
      return value

    spec = ColumnSpec(
        column_name=name, python_type=python_type, value_fn=value_fn)

    self._specs.append(spec)
    return self

  def add_custom_column_spec(self, spec: ColumnSpec) -> 'ColumnSpecsBuilder':
    """Add a custom :class:`.ColumnSpec` to the builder.

    Use this method when you need complete control over the
    :class:`.ColumnSpec`, including custom value extraction, type handling
    or a non-default ``placeholder``.

    Args:
      spec: A :class:`.ColumnSpec` instance defining the column name, type,
        value extraction, and optional MySQL function placeholder.

    Returns:
      Self for method chaining

    Examples:
      Custom text column from chunk metadata:
      >>> builder.add_custom_column_spec(
      ...     ColumnSpec.text(
      ...         column_name="source_and_id",
      ...         value_fn=lambda chunk:
      ...             f"{chunk.metadata.get('source')}_{chunk.id}"
      ...     )
      ... )
    """
    self._specs.append(spec)
    return self

  def build(self) -> List[ColumnSpec]:
    """Return a copy of the accumulated column specifications."""
    return self._specs.copy()
+
+
@dataclass
class ConflictResolution:
  """Specification for how to handle conflicts during insert.

  Configures conflict handling behavior when inserting records that may
  violate unique constraints, using MySQL's ON DUPLICATE KEY UPDATE syntax.

  MySQL automatically detects conflicts based on PRIMARY KEY or UNIQUE
  constraints defined on the table, so no conflict target is specified.

  Attributes:
    action: How to handle conflicts - either "UPDATE" or "IGNORE".
      UPDATE: Updates the existing record with the new values.
      IGNORE: Skips conflicting records (uses a no-op update).
    update_fields: Optional list of fields to update on conflict. If None
      (or empty), every inserted column is updated (UPDATE action only).
    primary_key_field: Required for IGNORE action. The primary key field
      name to use for the no-op update.

  Raises:
    ValueError: If ``action`` is not "UPDATE" or "IGNORE", or if "IGNORE" is
      requested without ``primary_key_field``.

  Examples:
    Update all fields on conflict:
    >>> ConflictResolution(action="UPDATE")

    Update specific fields on conflict:
    >>> ConflictResolution(
    ...     action="UPDATE",
    ...     update_fields=["embedding", "content"]
    ... )

    Ignore conflicts (primary key column named explicitly, e.g. "id" or a
    custom key such as "custom_id"):
    >>> ConflictResolution(
    ...     action="IGNORE",
    ...     primary_key_field="id"
    ... )
  """
  action: Literal["UPDATE", "IGNORE"] = "UPDATE"
  update_fields: Optional[List[str]] = None
  primary_key_field: Optional[str] = None

  def __post_init__(self):
    """Validate configuration after initialization."""
    # Literal is not enforced at runtime; fail at construction rather than
    # later when the INSERT statement is built.
    if self.action not in ("UPDATE", "IGNORE"):
      raise ValueError(
          f"action must be 'UPDATE' or 'IGNORE', got {self.action!r}")
    if self.action == "IGNORE" and self.primary_key_field is None:
      raise ValueError("primary_key_field is required when action='IGNORE'")
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py
b/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py
index cd30766a288..0373874c09d 100644
--- a/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py
+++ b/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py
@@ -18,21 +18,12 @@
import hashlib
import json
from typing import List
-from typing import NamedTuple
import apache_beam as beam
-from apache_beam.coders import registry
-from apache_beam.coders.row_coder import RowCoder
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import Embedding
-TestRow = NamedTuple(
- 'TestRow',
- [('id', str), ('embedding', List[float]), ('content', str),
- ('metadata', str)])
-registry.register_coder(TestRow, RowCoder)
-
# Presumably the dense-embedding dimension for test chunks; confirm at use.
VECTOR_SIZE: int = 768