This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 28fd2b2c9df Mysql embeddings (#35393)
28fd2b2c9df is described below

commit 28fd2b2c9dfd7286c16e8da4504e2dc7f3605984
Author: claudevdm <[email protected]>
AuthorDate: Mon Jul 7 16:54:10 2025 -0400

    Mysql embeddings (#35393)
    
    * Add MySQL vector writer.
    
    * Trigger tests again.
    
    * Comments.
    
    * Fix lints etc.
    
    * Comment.
    
    * Fix typo
    
    * Lint fix.
    
    * Update sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py
    
    Co-authored-by: Danny McCormick <[email protected]>
    
    ---------
    
    Co-authored-by: Claude <[email protected]>
    Co-authored-by: Danny McCormick <[email protected]>
---
 .../schemaio-expansion-service/build.gradle        |    2 +
 .../apache_beam/ml/rag/ingestion/cloudsql.py       |   61 +-
 .../ml/rag/ingestion/cloudsql_it_test.py           | 1056 +++++++++++++++++---
 sdks/python/apache_beam/ml/rag/ingestion/mysql.py  |  268 +++++
 .../apache_beam/ml/rag/ingestion/mysql_common.py   |  433 ++++++++
 .../apache_beam/ml/rag/ingestion/test_utils.py     |    9 -
 6 files changed, 1679 insertions(+), 150 deletions(-)

diff --git a/sdks/java/extensions/schemaio-expansion-service/build.gradle b/sdks/java/extensions/schemaio-expansion-service/build.gradle
index 15873d58e61..12ee92a9e10 100644
--- a/sdks/java/extensions/schemaio-expansion-service/build.gradle
+++ b/sdks/java/extensions/schemaio-expansion-service/build.gradle
@@ -64,6 +64,8 @@ dependencies {
     permitUnusedDeclared 'com.google.cloud:alloydb-jdbc-connector:1.2.0'
     implementation 'com.google.cloud.sql:postgres-socket-factory:1.25.0'
     permitUnusedDeclared 'com.google.cloud.sql:postgres-socket-factory:1.25.0'
+    implementation 'com.google.cloud.sql:mysql-socket-factory-connector-j-8:1.25.0'
+    permitUnusedDeclared 'com.google.cloud.sql:mysql-socket-factory-connector-j-8:1.25.0'
     testImplementation library.java.junit
     testImplementation library.java.mockito_core
     runtimeOnly ("org.xerial:sqlite-jdbc:3.49.1.0")
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py
index 69ead961a76..d3710a7f70a 100644
--- a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py
+++ b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py
@@ -21,12 +21,12 @@ from typing import Dict
 from typing import List
 from typing import Optional
 
+from apache_beam.ml.rag.ingestion import mysql
+from apache_beam.ml.rag.ingestion import mysql_common
+from apache_beam.ml.rag.ingestion import postgres
+from apache_beam.ml.rag.ingestion import postgres_common
 from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig
 from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
-from apache_beam.ml.rag.ingestion.postgres import ColumnSpecsBuilder
-from apache_beam.ml.rag.ingestion.postgres import PostgresVectorWriterConfig
-from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec
-from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution
 
 
 @dataclass
@@ -138,7 +138,7 @@ class _PostgresConnectorConfig(LanguageConnectorConfig):
     return cls(**asdict(config))
 
 
-class CloudSQLPostgresVectorWriterConfig(PostgresVectorWriterConfig):
+class CloudSQLPostgresVectorWriterConfig(postgres.PostgresVectorWriterConfig):
   def __init__(
       self,
       connection_config: LanguageConnectorConfig,
@@ -146,10 +146,11 @@ class CloudSQLPostgresVectorWriterConfig(PostgresVectorWriterConfig):
       *,
       # pylint: disable=dangerous-default-value
       write_config: WriteConfig = WriteConfig(),
-      column_specs: List[ColumnSpec] = ColumnSpecsBuilder.with_defaults().build(
-      ),
-      conflict_resolution: Optional[ConflictResolution] = ConflictResolution(
-          on_conflict_fields=[], action='IGNORE')):
+      column_specs: List[postgres_common.ColumnSpec] = postgres_common.
+      ColumnSpecsBuilder.with_defaults().build(),
+      conflict_resolution: Optional[
+          postgres_common.ConflictResolution] = postgres_common.
+      ConflictResolution(on_conflict_fields=[], action='IGNORE')):
     """Configuration for writing vectors to ClouSQL Postgres.
     
     Supports flexible schema configuration through column specifications and
@@ -218,3 +219,45 @@ class CloudSQLPostgresVectorWriterConfig(PostgresVectorWriterConfig):
         table_name=table_name,
         column_specs=column_specs,
         conflict_resolution=conflict_resolution)
+
+
+@dataclass
+class _MySQLConnectorConfig(LanguageConnectorConfig):
+  def to_jdbc_url(self) -> str:
+    """Convert options to a properly formatted MySQL JDBC URL."""
+    return self._build_jdbc_url(
+        socketFactory="com.google.cloud.sql.mysql.SocketFactory",
+        database_type="mysql")
+
+  def additional_jdbc_args(self) -> Dict[str, List[Any]]:
+    return {
+        'classpath': [
+            "mysql:mysql-connector-java:8.0.22",
+            "com.google.cloud.sql:mysql-socket-factory-connector-j-8:1.25.0"
+        ]
+    }
+
+  @classmethod
+  def from_base_config(cls, config: LanguageConnectorConfig):
+    return cls(**asdict(config))
+
+
+class CloudSQLMySQLVectorWriterConfig(mysql.MySQLVectorWriterConfig):
+  def __init__(
+      self,
+      connection_config: LanguageConnectorConfig,
+      table_name: str,
+      *,
+      write_config: WriteConfig = WriteConfig(),
+      # pylint: disable=dangerous-default-value
+      column_specs: List[mysql_common.ColumnSpec] = mysql_common.
+      ColumnSpecsBuilder.with_defaults().build(),
+      conflict_resolution: Optional[mysql_common.ConflictResolution] = None):
+    self.connector_config = _MySQLConnectorConfig.from_base_config(
+        connection_config)
+    super().__init__(
+        connection_config=self.connector_config.to_connection_config(),
+        write_config=write_config,
+        table_name=table_name,
+        column_specs=column_specs,
+        conflict_resolution=conflict_resolution)
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py
index 959e4cadb13..7ae49ba5182 100644
--- a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py
+++ b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py
@@ -15,209 +15,1001 @@
 # limitations under the License.
 #
 
+import json
 import logging
 import os
 import secrets
 import time
 import unittest
+from dataclasses import dataclass
+from typing import Any
+from typing import List
+from typing import Literal
+from typing import Optional
 
 import pytest
 import sqlalchemy
 from google.cloud.sql.connector import Connector
+from parameterized import parameterized
 from sqlalchemy import text
 
 import apache_beam as beam
 from apache_beam.io.jdbc import ReadFromJdbc
+from apache_beam.ml.rag.ingestion import mysql_common
+from apache_beam.ml.rag.ingestion import postgres_common
 from apache_beam.ml.rag.ingestion import test_utils
 from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteTransform
+from apache_beam.ml.rag.ingestion.cloudsql import CloudSQLMySQLVectorWriterConfig
 from apache_beam.ml.rag.ingestion.cloudsql import CloudSQLPostgresVectorWriterConfig
 from apache_beam.ml.rag.ingestion.cloudsql import LanguageConnectorConfig
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.types import Content
+from apache_beam.ml.rag.types import Embedding
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 
+_LOGGER = logging.getLogger(__name__)
+
+
+@dataclass
+class DatabaseTestConfig:
+  """Database-specific test configuration."""
+  database_type: Literal["postgresql", "mysql"]
+  writer_config_class: type
+  jdbc_driver: str
+  connector_module: Literal["pg8000", "pymysql"]
+  table_prefix: str
+
+  password_env_var: str
+  username: str
+  database: str
+  instance_uri: str
+
+  vector_column_type: str
+  metadata_column_type: str
+  common_module: Any
+  id_column_type: str = "VARCHAR(255)"
+
+
+class DatabaseTestHelper:
+  """Helper class to manage database setup, connections, and operations."""
+  def __init__(self, db_config: DatabaseTestConfig, table_suffix: str):
+    self.db_config = db_config
+    self.table_suffix = table_suffix
+    self.connector = None
+    self.engine = None
+    self.connection_config = None
+
+    self.default_table_name = f"{db_config.table_prefix}{table_suffix}"
+    self.custom_table_name = f"{db_config.table_prefix}_custom_{table_suffix}"
+    self.metadata_conflicts_table = f"{db_config.table_prefix}_meta_conf_" \
+      f"{table_suffix}"
+
+    self._setup_read_queries()
+
+  def _setup_read_queries(self):
+    if self.db_config.database_type == "postgresql":
+      self.read_queries = {
+          self.default_table_name: f"""
+                    SELECT 
+                        CAST(id AS VARCHAR(255)),
+                        CAST(content AS VARCHAR(255)),
+                        CAST(embedding AS text),
+                        CAST(metadata AS text)
+                    FROM {self.default_table_name}
+                """,
+          self.custom_table_name: f"""
+                    SELECT 
+                        CAST(custom_id AS VARCHAR(255)),
+                        CAST(embedding_vec AS text),
+                        CAST(content_col AS VARCHAR(255)),
+                        CAST(metadata AS text)
+                    FROM {self.custom_table_name}
+                    ORDER BY custom_id
+                """,
+          self.metadata_conflicts_table: f"""
+                    SELECT 
+                        CAST(id AS VARCHAR(255)),
+                        CAST(embedding AS text),
+                        CAST(content AS VARCHAR(255)),
+                        CAST(source AS VARCHAR(255)),
+                        CAST(timestamp AS VARCHAR(255))
+                    FROM {self.metadata_conflicts_table}
+                    ORDER BY timestamp, id
+                """
+      }
+    elif self.db_config.database_type == "mysql":
+      self.read_queries = {
+          self.default_table_name: f"""
+                    SELECT 
+                        CAST(id AS CHAR(255)) as id,
+                        CAST(content AS CHAR(255)) as content,
+                        vector_to_string(embedding) as embedding,
+                        CAST(metadata AS CHAR(10000)) as metadata
+                    FROM {self.default_table_name}
+                """,
+          self.custom_table_name: f"""
+                    SELECT 
+                        CAST(custom_id AS CHAR(255)) as custom_id,
+                        vector_to_string(embedding_vec) as embedding_vec,
+                        CAST(content_col AS CHAR(255)) as content_col,
+                        CAST(metadata AS CHAR(10000)) as metadata
+                    FROM {self.custom_table_name}
+                    ORDER BY custom_id
+                """,
+          self.metadata_conflicts_table: f"""
+                    SELECT 
+                        CAST(id AS CHAR(255)) as id,
+                        vector_to_string(embedding) as embedding,
+                        CAST(content AS CHAR(255)) as content,
+                        CAST(source AS CHAR(255)) as source,
+                        CAST(timestamp AS CHAR(255)) as timestamp
+                    FROM {self.metadata_conflicts_table}
+                    ORDER BY timestamp, id
+                """
+      }
+
+  def get_read_query(self, table_name: str) -> str:
+    if table_name not in self.read_queries:
+      raise ValueError(f"No read query defined for table: {table_name}")
+    return self.read_queries[table_name]
+
+  def setup_connection(self):
+    """Set up database connection and engine."""
+    if not os.environ.get(self.db_config.password_env_var):
+      raise ValueError("Password environment variable not set.")
+    password = os.environ.get(self.db_config.password_env_var)
+
+    self.connection_config = LanguageConnectorConfig(
+        username=self.db_config.username,
+        password=password,
+        database_name=self.db_config.database,
+        instance_name=self.db_config.instance_uri)
+
+    self.connector = Connector(refresh_strategy="LAZY")
 
-@pytest.mark.uses_gcp_java_expansion_service
-@unittest.skipUnless(
-    os.environ.get('EXPANSION_JARS'),
-    "EXPANSION_JARS environment var is not provided, "
-    "indicating that jars have not been built")
-@unittest.skipUnless(
-    os.environ.get('ALLOYDB_PASSWORD'),
-    "ALLOYDB_PASSWORD environment var is not provided")
-class CloudSQLPostgresVectorWriterConfigTest(unittest.TestCase):
-  POSTGRES_TABLE_PREFIX = 'python_rag_postgres_'
-
-  @classmethod
-  def _create_engine(cls):
-    """Create SQLAlchemy engine using Cloud SQL connector."""
     def getconn():
-      conn = cls.connector.connect(
-          cls.instance_uri,
-          "pg8000",
-          user=cls.username,
-          password=cls.password,
-          db=cls.database,
+      return self.connector.connect(
+          self.db_config.instance_uri,
+          self.db_config.connector_module,
+          user=self.db_config.username,
+          password=password,
+          db=self.db_config.database,
       )
-      return conn
-
-    engine = sqlalchemy.create_engine(
-        "postgresql+pg8000://",
-        creator=getconn,
-    )
-    return engine
-
-  @classmethod
-  def setUpClass(cls):
-    cls.database = os.environ.get('POSTGRES_DATABASE', 'postgres')
-    cls.username = os.environ.get('POSTGRES_USERNAME', 'postgres')
-    if not os.environ.get('ALLOYDB_PASSWORD'):
-      raise ValueError('ALLOYDB_PASSWORD env not set')
-    cls.password = os.environ.get('ALLOYDB_PASSWORD')
-    cls.instance_uri = os.environ.get(
-        'POSTGRES_INSTANCE_URI',
-        'apache-beam-testing:us-central1:beam-integration-tests')
-
-    # Create unique table name suffix
-    cls.table_suffix = '%d%s' % (int(time.time()), secrets.token_hex(3))
-
-    # Setup database connection
-    cls.connector = Connector(refresh_strategy="LAZY")
-    cls.engine = cls._create_engine()
 
-  def skip_if_dataflow_runner(self):
-    if self._runner and "dataflowrunner" in self._runner.lower():
-      self.skipTest(
-          "Skipping some tests on Dataflow Runner to avoid bloat and timeouts")
+    dialect = "postgresql+pg8000" \
+      if self.db_config.database_type == "postgresql" else "mysql+pymysql"
+    self.engine = sqlalchemy.create_engine(f"{dialect}://", creator=getconn)
 
-  def setUp(self):
-    self.write_test_pipeline = TestPipeline(is_integration_test=True)
-    self.read_test_pipeline = TestPipeline(is_integration_test=True)
-    self._runner = type(self.read_test_pipeline.runner).__name__
+  def create_all_tables(self):
+    if not self.engine:
+      raise ValueError("Engine not initialized. Call setup_connection() first.")
 
-    self.default_table_name = f"{self.POSTGRES_TABLE_PREFIX}" \
-      f"{self.table_suffix}"
+    vector_type_large = self.db_config.vector_column_type.format(
+        size=test_utils.VECTOR_SIZE)
+    vector_type_small = self.db_config.vector_column_type.format(size=2)
+    metadata_type = self.db_config.metadata_column_type
+    id_type = self.db_config.id_column_type
 
-    # Create test table
     with self.engine.connect() as connection:
-      connection.execute(
-          text(
-              f"""
+      default_table_sql = f"""
                 CREATE TABLE {self.default_table_name} (
-                    id TEXT PRIMARY KEY,
-                    embedding VECTOR({test_utils.VECTOR_SIZE}),
+                    id {id_type} PRIMARY KEY,
+                    embedding {vector_type_large},
                     content TEXT,
-                    metadata JSONB
+                    metadata {metadata_type}
                 )
-            """))
-      connection.commit()
-    _LOGGER = logging.getLogger(__name__)
-    _LOGGER.info("Created table %s", self.default_table_name)
+            """
+      connection.execute(text(default_table_sql))
 
-  def tearDown(self):
-    # Drop test table
-    with self.engine.connect() as connection:
-      connection.execute(
-          text(f"DROP TABLE IF EXISTS {self.default_table_name}"))
+      custom_table_sql = f"""
+                CREATE TABLE {self.custom_table_name} (
+                    custom_id {id_type} PRIMARY KEY,
+                    embedding_vec {vector_type_small},
+                    content_col TEXT,
+                    metadata {metadata_type}
+                )
+            """
+      connection.execute(text(custom_table_sql))
+
+      if self.db_config.database_type == "postgresql":
+        metadata_conflicts_sql = f"""
+                    CREATE TABLE {self.metadata_conflicts_table} (
+                        id {id_type},
+                        source TEXT,
+                        timestamp TIMESTAMP,
+                        content TEXT,
+                        embedding {vector_type_small},
+                        PRIMARY KEY (id),
+                        UNIQUE (source, timestamp)
+                    )
+                """
+      elif self.db_config.database_type == "mysql":
+        metadata_conflicts_sql = f"""
+                    CREATE TABLE {self.metadata_conflicts_table} (
+                        id {id_type},
+                        source TEXT,
+                        timestamp TIMESTAMP,
+                        content TEXT,
+                        embedding {vector_type_small},
+                        PRIMARY KEY (id),
+                        UNIQUE KEY unique_source_timestamp (source(255), timestamp)
+                    )
+                """
+      connection.execute(text(metadata_conflicts_sql))
       connection.commit()
-    _LOGGER = logging.getLogger(__name__)
-    _LOGGER.info("Dropped table %s", self.default_table_name)
-
-  @classmethod
-  def tearDownClass(cls):
-    if hasattr(cls, 'connector'):
-      cls.connector.close()
-    if hasattr(cls, 'engine'):
-      cls.engine.dispose()
-
-  def test_language_connector(self):
-    """Test language connector."""
-    self.skip_if_dataflow_runner()
 
-    connection_config = LanguageConnectorConfig(
-        username=self.username,
-        password=self.password,
-        database_name=self.database,
-        instance_name=self.instance_uri)
-    writer_config = CloudSQLPostgresVectorWriterConfig(
-        connection_config=connection_config, table_name=self.default_table_name)
+  def create_writer_config(
+      self,
+      table_name: Optional[str] = None,
+      column_specs=None,
+      conflict_resolution=None):
+    if not self.connection_config:
+      raise ValueError(
+          "Connection not initialized. Call setup_connection() first.")
 
-    # Create test chunks
-    num_records = 150
-    sample_size = min(500, num_records // 2)
-    chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records)
+    table_name = table_name or self.default_table_name
 
-    self.write_test_pipeline.not_use_test_runner_api = True
+    kwargs = {
+        'connection_config': self.connection_config,
+        'table_name': table_name,
+    }
 
-    with self.write_test_pipeline as p:
-      _ = (
-          p
-          | beam.Create(chunks)
-          | VectorDatabaseWriteTransform(writer_config))
+    if column_specs is not None:
+      kwargs['column_specs'] = column_specs
+    if conflict_resolution is not None:
+      kwargs['conflict_resolution'] = conflict_resolution
 
-    self.read_test_pipeline.not_use_test_runner_api = True
-    read_query = f"""
-          SELECT 
-              CAST(id AS VARCHAR(255)),
-              CAST(content AS VARCHAR(255)),
-              CAST(embedding AS text),
-              CAST(metadata AS text)
-          FROM {self.default_table_name}
-          """
+    return self.db_config.writer_config_class(**kwargs)
 
-    with self.read_test_pipeline as p:
-      rows = (
-          p
-          | ReadFromJdbc(
-              table_name=self.default_table_name,
-              driver_class_name="org.postgresql.Driver",
-              jdbc_url=writer_config.connector_config.to_connection_config(
-              ).jdbc_url,
-              username=self.username,
-              password=self.password,
-              query=read_query,
-              classpath=writer_config.connector_config.additional_jdbc_args()
-              ['classpath']))
+  def cleanup(self):
+    if self.engine:
+      table_names = [
+          self.default_table_name,
+          self.custom_table_name,
+          self.metadata_conflicts_table
+      ]
+
+      try:
+        with self.engine.connect() as connection:
+          for table_name in table_names:
+            connection.execute(text(f"DROP TABLE IF EXISTS {table_name}"))
+          connection.commit()
+        _LOGGER.info(
+            "Dropped %s tables: %s",
+            self.db_config.database_type,
+            ', '.join(table_names))
+      except Exception as e:
+        _LOGGER.warning(
+            "Error dropping %s tables: %s", self.db_config.database_type, e)
+
+    if self.connector:
+      try:
+        self.connector.close()
+      except Exception as e:
+        _LOGGER.warning("Error closing connector: %s", e)
+
+    if self.engine:
+      try:
+        self.engine.dispose()
+      except Exception as e:
+        _LOGGER.warning("Error disposing engine: %s", e)
+
+
+class PipelineVerificationHelper:
+  """Helper class for common pipeline verification patterns."""
+  @staticmethod
+  def build_jdbc_params(helper: DatabaseTestHelper, table_name: str) -> dict:
+    """Build JDBC parameters dictionary for ReadFromJdbc."""
+    writer_config = helper.create_writer_config(table_name)
+
+    return {
+        'table_name': table_name,
+        'driver_class_name': helper.db_config.jdbc_driver,
+        'jdbc_url': writer_config.connector_config.to_connection_config().
+        jdbc_url,
+        'username': helper.db_config.username,
+        'password': helper.connection_config.password,
+        'query': helper.get_read_query(table_name),
+        'classpath': writer_config.connector_config.additional_jdbc_args()
+        ['classpath']
+    }
 
+  @staticmethod
+  def verify_standard_operations(
+      pipeline, jdbc_params: dict, expected_chunks: List[Chunk]):
+    num_records = len(expected_chunks)
+    sample_size = min(500, num_records // 2)
+
+    with pipeline as p:
+      rows = (p | ReadFromJdbc(**jdbc_params))
+
+      # Count verification
       count_result = rows | "Count All" >> beam.combiners.Count.Globally()
       assert_that(count_result, equal_to([num_records]), label='count_check')
 
+      # Hash verification
       chunks = (rows | "To Chunks" >> beam.Map(test_utils.row_to_chunk))
       chunk_hashes = chunks | "Hash Chunks" >> beam.CombineGlobally(
           test_utils.HashingFn())
-      assert_that(
-          chunk_hashes,
-          equal_to([test_utils.generate_expected_hash(num_records)]),
-          label='hash_check')
+      expected_hash = test_utils.generate_expected_hash(num_records)
+      assert_that(chunk_hashes, equal_to([expected_hash]), label='hash_check')
 
-      # Sample validation
+      # Sample validation - first N
       first_n = (
           chunks
           | "Key on Index" >> beam.Map(test_utils.key_on_id)
           | f"Get First {sample_size}" >> beam.transforms.combiners.Top.Of(
               sample_size, key=lambda x: x[0], reverse=True)
           | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs]))
-      expected_first_n = test_utils.ChunkTestUtils.get_expected_values(
-          0, sample_size)
+      expected_first_n = expected_chunks[:sample_size]
       assert_that(
           first_n,
           equal_to([expected_first_n]),
           label=f"first_{sample_size}_check")
 
+      # Sample validation - last N
       last_n = (
           chunks
           | "Key on Index 2" >> beam.Map(test_utils.key_on_id)
           | f"Get Last {sample_size}" >> beam.transforms.combiners.Top.Of(
               sample_size, key=lambda x: x[0])
           | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs]))
-      expected_last_n = test_utils.ChunkTestUtils.get_expected_values(
-          num_records - sample_size, num_records)[::-1]
+      expected_last_n = expected_chunks[-sample_size:][::-1]
       assert_that(
           last_n,
           equal_to([expected_last_n]),
           label=f"last_{sample_size}_check")
 
 
+# Database configurations
+POSTGRES_CONFIG = DatabaseTestConfig(
+    database_type="postgresql",
+    writer_config_class=CloudSQLPostgresVectorWriterConfig,
+    jdbc_driver="org.postgresql.Driver",
+    connector_module="pg8000",
+    table_prefix="python_rag_postgres_",
+    password_env_var="ALLOYDB_PASSWORD",
+    username="postgres",
+    database="postgres",
+    instance_uri="apache-beam-testing:us-central1:beam-integration-tests",
+    vector_column_type="VECTOR({size})",
+    metadata_column_type="JSONB",
+    common_module=postgres_common)
+
+MYSQL_CONFIG = DatabaseTestConfig(
+    database_type="mysql",
+    writer_config_class=CloudSQLMySQLVectorWriterConfig,
+    jdbc_driver="com.mysql.cj.jdbc.Driver",
+    connector_module="pymysql",
+    table_prefix="python_rag_mysql_",
+    password_env_var="ALLOYDB_PASSWORD",
+    username="mysql",
+    database="embeddings",
+    instance_uri="apache-beam-testing:us-central1:beam-integration-tests-mysql",
+    vector_column_type="VECTOR({size}) USING VARBINARY",
+    metadata_column_type="JSON",
+    common_module=mysql_common)
+
+
+@pytest.mark.uses_gcp_java_expansion_service
+@unittest.skipUnless(
+    os.environ.get('EXPANSION_JARS'),
+    "EXPANSION_JARS environment var is not provided, "
+    "indicating that jars have not been built")
+class CloudSQLVectorWriterConfigTest(unittest.TestCase):
+  def setUp(self):
+    self.write_test_pipeline = TestPipeline(is_integration_test=True)
+    self.read_test_pipeline = TestPipeline(is_integration_test=True)
+    self.write_test_pipeline2 = TestPipeline(is_integration_test=True)
+    self.read_test_pipeline2 = TestPipeline(is_integration_test=True)
+
+    self.write_test_pipeline.not_use_test_runner_api = True
+    self.read_test_pipeline.not_use_test_runner_api = True
+    self.write_test_pipeline2.not_use_test_runner_api = True
+    self.read_test_pipeline2.not_use_test_runner_api = True
+    self._runner = type(self.read_test_pipeline.runner).__name__
+
+    self.db_helpers = {}
+    self.table_suffix = '%d%s' % (int(time.time()), secrets.token_hex(3))
+
+    # Set up database helpers
+    for config in [POSTGRES_CONFIG, MYSQL_CONFIG]:
+      helper = DatabaseTestHelper(config, self.table_suffix)
+      helper.setup_connection()
+      helper.create_all_tables()
+      self.db_helpers[config.database_type] = helper
+      _LOGGER.info("Successfully set up %s database", config.database_type)
+
+  def tearDown(self):
+    for helper in self.db_helpers.values():
+      helper.cleanup()
+
+  def skip_if_dataflow_runner(self):
+    if self._runner and "dataflowrunner" in self._runner.lower():
+      self.skipTest(
+          "Skipping some tests on Dataflow Runner to avoid bloat and timeouts")
+
+  @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)])
+  def test_default_config(self, db_config):
+    """Test basic write and read operations with default configuration.
+      
+      This test validates the most basic CloudSQL vector database functionality:
+      - Default table schema: id (VARCHAR), content (TEXT), embedding (VECTOR),
+        metadata (JSON/JSONB) 
+      - Default column specifications (no customization)
+      - Default conflict resolution (IGNORE on primary key conflicts)
+      - Write chunks to database and read them back
+      - Verify data integrity through count, hash, and sample validation
+      """
+    self.skip_if_dataflow_runner()
+
+    helper = self.db_helpers[db_config.database_type]
+    num_records = 150
+
+    # Create test data
+    test_chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records)
+
+    # Write test
+    writer_config = helper.create_writer_config()
+    self.write_test_pipeline.not_use_test_runner_api = True
+    with self.write_test_pipeline as p:
+      _ = (
+          p | beam.Create(test_chunks)
+          | VectorDatabaseWriteTransform(writer_config))
+
+    # Read and verify
+    self.read_test_pipeline.not_use_test_runner_api = True
+    jdbc_params = PipelineVerificationHelper.build_jdbc_params(
+        helper, helper.default_table_name)
+    PipelineVerificationHelper.verify_standard_operations(
+        self.read_test_pipeline, jdbc_params, test_chunks)
+
+  @parameterized.expand([
+      (POSTGRES_CONFIG, "UPDATE", ["embedding", "content"]),
+      (MYSQL_CONFIG, "UPDATE", ["embedding", "content"]),
+      (POSTGRES_CONFIG, "IGNORE", None),
+      (MYSQL_CONFIG, "IGNORE", None),
+      (POSTGRES_CONFIG, "UPDATE_ALL", None),  # Default update fields
+      (MYSQL_CONFIG, "UPDATE_ALL", None),
+  ])
+  def test_conflict_resolution(self, db_config, action, update_fields):
+    """Test conflict resolution strategies when primary key conflicts occur.
+    
+      This test validates different approaches to handling duplicate primary
+      keys:
+      
+      UPDATE with specific fields:
+      - When duplicate ID encountered, update only specified fields (embedding,
+        content)
+      - Other fields (metadata) remain unchanged from original record
+      
+      IGNORE:
+      - When duplicate ID encountered, keep original record unchanged
+      
+      UPDATE_ALL (default update fields):
+      - When duplicate ID encountered, update ALL non-key fields
+      - This includes content, embedding, AND metadata
+      
+      Scenario for all strategies:
+      1. Insert initial records
+      2. Insert records with same IDs but different content/embeddings  
+      3. Verify final state matches expected conflict resolution behavior
+      """
+    self.skip_if_dataflow_runner()
+
+    helper = self.db_helpers[db_config.database_type]
+    num_records = 20
+
+    common_module = db_config.common_module
+    if action == "IGNORE":
+      if db_config.database_type == "mysql":
+        conflict_resolution = common_module.ConflictResolution(
+            action="IGNORE", primary_key_field="id")
+      else:
+        conflict_resolution = None  # Default behavior for PostgreSQL
+    elif action == "UPDATE":
+      if db_config.database_type == "postgresql":
+        conflict_resolution = common_module.ConflictResolution(
+            on_conflict_fields="id",
+            action="UPDATE",
+            update_fields=update_fields)
+      else:
+        conflict_resolution = common_module.ConflictResolution(
+            action="UPDATE", update_fields=update_fields)
+    else:  # UPDATE_ALL
+      if db_config.database_type == "postgresql":
+        conflict_resolution = common_module.ConflictResolution(
+            on_conflict_fields="id", action="UPDATE")
+      else:
+        conflict_resolution = common_module.ConflictResolution(action="UPDATE")
+
+    initial_chunks = test_utils.ChunkTestUtils.get_expected_values(
+        0, num_records)
+    writer_config = helper.create_writer_config(
+        conflict_resolution=conflict_resolution)
+
+    self.write_test_pipeline.not_use_test_runner_api = True
+    with self.write_test_pipeline as p:
+      _ = (
+          p | "Write Initial" >> beam.Create(initial_chunks)
+          | VectorDatabaseWriteTransform(writer_config))
+
+    # Write conflicting data
+    updated_chunks = test_utils.ChunkTestUtils.get_expected_values(
+        0, num_records, content_prefix="Updated", seed_multiplier=2)
+
+    self.write_test_pipeline2.not_use_test_runner_api = True
+    with self.write_test_pipeline2 as p:
+      _ = (
+          p | "Write Conflicts" >> beam.Create(updated_chunks)
+          | VectorDatabaseWriteTransform(writer_config))
+
+    jdbc_params = PipelineVerificationHelper.build_jdbc_params(
+        helper, helper.default_table_name)
+    expected_chunks = updated_chunks if action != "IGNORE" else initial_chunks
+
+    self.read_test_pipeline.not_use_test_runner_api = True
+    with self.read_test_pipeline as p:
+      rows = (p | ReadFromJdbc(**jdbc_params))
+
+      count_result = rows | "Count All" >> beam.combiners.Count.Globally()
+      assert_that(count_result, equal_to([num_records]), label='count_check')
+
+      chunks = rows | "To Chunks" >> beam.Map(test_utils.row_to_chunk)
+      assert_that(chunks, equal_to(expected_chunks), label='chunks_check')
+
+  @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)])
+  def test_custom_column_names_and_value_functions(self, db_config):
+    """Test completely custom column specifications with custom value
+      extraction.
+    
+      This test validates advanced customization of how chunk data is stored:
+      
+      Custom column names:
+      - custom_id (instead of 'id')
+      - embedding_vec (instead of 'embedding') 
+      - content_col (instead of 'content')
+      
+      Custom value extraction functions:
+      - ID: Extract timestamp from metadata and prefix with "timestamp_"
+      - Content: Prefix content with its character length "10:actual_content"
+      - Embedding: Use custom embedding extraction function
+      
+      This tests the flexibility to completely reshape how chunk data maps 
+      to database columns, useful for integrating with existing database 
schemas
+      or applying business-specific transformations.
+      """
+    self.skip_if_dataflow_runner()
+
+    helper = self.db_helpers[db_config.database_type]
+    num_records = 20
+    common_module = db_config.common_module
+
+    test_chunks = [
+        Chunk(
+            id=str(i),
+            content=Content(text=f"content_{i}"),
+            embedding=Embedding(dense_embedding=[float(i), float(i + 1)]),
+            metadata={"timestamp": f"2024-02-02T{i:02d}:00:00"})
+        for i in range(num_records)
+    ]
+
+    chunk_embedding_fn = common_module.chunk_embedding_fn
+    specs = (
+        common_module.ColumnSpecsBuilder().add_custom_column_spec(
+            common_module.ColumnSpec.text(
+                column_name="custom_id",
+                value_fn=lambda chunk:
+                f"timestamp_{chunk.metadata.get('timestamp', '')}")
+        ).add_custom_column_spec(
+            common_module.ColumnSpec.vector(
+                column_name="embedding_vec",
+                value_fn=chunk_embedding_fn)).add_custom_column_spec(
+                    common_module.ColumnSpec.text(
+                        column_name="content_col",
+                        value_fn=lambda chunk:
+                        f"{len(chunk.content.text)}:{chunk.content.text}")).
+        with_metadata_spec().build())
+
+    def custom_row_to_chunk(row):
+      timestamp = row.custom_id.split('timestamp_')[1]
+      i = int(timestamp.split('T')[1][:2])
+
+      embedding_list = [
+          float(x) for x in row.embedding_vec.strip('[]').split(',')
+      ]
+
+      content = row.content_col.split(':', 1)[1]
+
+      return Chunk(
+          id=str(i),
+          content=Content(text=content),
+          embedding=Embedding(dense_embedding=embedding_list),
+          metadata=json.loads(row.metadata))
+
+    writer_config = helper.create_writer_config(helper.custom_table_name, 
specs)
+    self.write_test_pipeline.not_use_test_runner_api = True
+    with self.write_test_pipeline as p:
+      _ = (
+          p | beam.Create(test_chunks)
+          | VectorDatabaseWriteTransform(writer_config))
+
+    jdbc_params = PipelineVerificationHelper.build_jdbc_params(
+        helper, helper.custom_table_name)
+
+    self.read_test_pipeline.not_use_test_runner_api = True
+    with self.read_test_pipeline as p:
+      rows = (p | ReadFromJdbc(**jdbc_params))
+
+      count_result = rows | "Count All" >> beam.combiners.Count.Globally()
+      assert_that(count_result, equal_to([num_records]), label='count_check')
+
+      chunks = rows | "To Chunks" >> beam.Map(custom_row_to_chunk)
+      assert_that(chunks, equal_to(test_chunks), label='chunks_check')
+
+  @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)])
+  def test_custom_type_conversion_with_default_columns(self, db_config):
+    """Test custom type conversion and SQL typecasting with modified column
+      names.
+    
+      This test validates data type handling and database-specific SQL 
features:
+      
+      Type conversion:
+      - Convert string IDs to integers before storage
+      - Apply length-prefix transformation to content
+      
+      SQL typecasting (database-specific):
+      - PostgreSQL: Use ::text typecast for converted integers
+      - MySQL: Rely on automatic type conversion (no explicit typecast)
+      
+      Column name customization:
+      - Use custom names but with standard spec builders (not completely custom
+        functions)
+      
+      This tests the ability to adapt data types for database constraints
+      while maintaining the standard chunk-to-database mapping logic.
+    """
+    self.skip_if_dataflow_runner()
+
+    helper = self.db_helpers[db_config.database_type]
+    num_records = 20
+    common_module = db_config.common_module
+
+    test_chunks = [
+        Chunk(
+            id=str(i),
+            content=Content(text=f"content_{i}"),
+            embedding=Embedding(dense_embedding=[float(i), float(i + 1)]),
+            metadata={"timestamp": f"2024-02-02T{i:02d}:00:00"})
+        for i in range(num_records)
+    ]
+
+    if db_config.database_type == "postgresql":
+      specs = (
+          common_module.ColumnSpecsBuilder().with_id_spec(
+              column_name="custom_id",
+              python_type=int,
+              convert_fn=lambda x: int(x),
+              sql_typecast="::text").with_content_spec(
+                  column_name="content_col",
+                  convert_fn=lambda x: f"{len(x)}:{x}"  # Add length prefix
+              ).with_embedding_spec(
+                  column_name="embedding_vec").with_metadata_spec().build())
+    else:  # MySQL
+      specs = (
+          common_module.ColumnSpecsBuilder().with_id_spec(
+              column_name="custom_id",
+              python_type=int,
+              convert_fn=lambda x: int(x)).with_content_spec(
+                  column_name="content_col",
+                  convert_fn=lambda x: f"{len(x)}:{x}").with_embedding_spec(
+                      
column_name="embedding_vec").with_metadata_spec().build())
+
+    def type_conversion_row_to_chunk(row):
+      embedding_list = [
+          float(x) for x in row.embedding_vec.strip('[]').split(',')
+      ]
+
+      content = row.content_col.split(':', 1)[1]
+
+      return Chunk(
+          id=row.custom_id,  # custom_id is the converted ID field
+          content=Content(text=content),
+          embedding=Embedding(dense_embedding=embedding_list),
+          metadata=json.loads(row.metadata))
+
+    writer_config = helper.create_writer_config(helper.custom_table_name, 
specs)
+    self.write_test_pipeline.not_use_test_runner_api = True
+    with self.write_test_pipeline as p:
+      _ = (
+          p | beam.Create(test_chunks)
+          | VectorDatabaseWriteTransform(writer_config))
+
+    jdbc_params = PipelineVerificationHelper.build_jdbc_params(
+        helper, helper.custom_table_name)
+
+    self.read_test_pipeline.not_use_test_runner_api = True
+    with self.read_test_pipeline as p:
+      rows = (p | ReadFromJdbc(**jdbc_params))
+
+      count_result = rows | "Count All" >> beam.combiners.Count.Globally()
+      assert_that(count_result, equal_to([num_records]), label='count_check')
+
+      chunks = rows | "To Chunks" >> beam.Map(type_conversion_row_to_chunk)
+      assert_that(chunks, equal_to(test_chunks), label='chunks_check')
+
+  @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)])
+  def test_default_id_embedding_specs(self, db_config):
+    """Test minimal schema with only ID and embedding columns.
+  
+      This test validates the ability to create a minimal vector database
+      schema:
+      - Only stores id and embedding fields
+      - content and metadata columns are excluded from the table
+      - Tests that the system correctly handles missing/null fields
+      
+      Use case: When you only need vector similarity search without storing 
+      the original content or metadata (perhaps stored elsewhere).
+      
+      Validation:
+      - Chunks written with content/metadata are stored with those fields as
+        null
+      - Reading back produces chunks with null content and empty metadata
+      - Vector similarity operations still work normally
+      """
+    self.skip_if_dataflow_runner()
+
+    helper = self.db_helpers[db_config.database_type]
+    num_records = 20
+    common_module = db_config.common_module
+
+    specs = (
+        
common_module.ColumnSpecsBuilder().with_id_spec().with_embedding_spec().
+        build())
+
+    writer_config = helper.create_writer_config(column_specs=specs)
+
+    test_chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records)
+
+    with self.write_test_pipeline as p:
+      _ = (
+          p | beam.Create(test_chunks)
+          | VectorDatabaseWriteTransform(writer_config))
+
+    expected_chunks = test_utils.ChunkTestUtils.get_expected_values(
+        0, num_records)
+    for chunk in expected_chunks:
+      chunk.content.text = None  # Content column not included in schema
+      chunk.metadata = {}  # Metadata column not included in schema
+
+    jdbc_params = PipelineVerificationHelper.build_jdbc_params(
+        helper, helper.default_table_name)
+    if db_config.database_type == "postgresql":
+      jdbc_params['query'] = f"""
+            SELECT 
+                CAST(id AS VARCHAR(255)),
+                CAST(embedding AS text)
+            FROM {helper.default_table_name}
+            ORDER BY id
+        """
+    elif db_config.database_type == "mysql":
+      jdbc_params['query'] = f"""
+            SELECT 
+                CAST(id AS CHAR(255)) as id,
+                vector_to_string(embedding) as embedding
+            FROM {helper.default_table_name}
+        """
+
+    with self.read_test_pipeline as p:
+      rows = (p | ReadFromJdbc(**jdbc_params))
+      chunks = rows | "To Chunks" >> beam.Map(test_utils.row_to_chunk)
+      assert_that(chunks, equal_to(expected_chunks), label='chunks_check')
+
+  @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)])
+  def test_metadata_field_extraction(self, db_config):
+    """Test extracting specific metadata fields into separate database columns.
+      
+      This test validates the ability to:
+      - Extract specific fields from the JSON metadata object 
+      - Map them to dedicated database columns (e.g., metadata.source -> source
+        column)
+      - Apply database-specific SQL typecasts (PostgreSQL ::timestamp vs MySQL
+        default)
+      - Store and retrieve the extracted fields correctly
+      
+      This is different from default metadata handling which stores the entire 
+      metadata object as JSON in a single column.
+      """
+    self.skip_if_dataflow_runner()
+
+    helper = self.db_helpers[db_config.database_type]
+    num_records = 20
+    common_module = db_config.common_module
+
+    if db_config.database_type == "postgresql":
+      specs = (
+          
common_module.ColumnSpecsBuilder().with_id_spec().with_embedding_spec(
+          ).with_content_spec().add_metadata_field(
+              field="source",
+              column_name="source",
+              python_type=str,
+              sql_typecast=None).add_metadata_field(
+                  field="timestamp",
+                  python_type=str,
+                  sql_typecast="::timestamp").build())
+    else:
+      specs = (
+          
common_module.ColumnSpecsBuilder().with_id_spec().with_embedding_spec(
+          ).with_content_spec().add_metadata_field(
+              field="source", column_name="source",
+              python_type=str).add_metadata_field(
+                  field="timestamp", python_type=str).build())
+
+    writer_config = helper.create_writer_config(
+        helper.metadata_conflicts_table, specs, conflict_resolution=None)
+
+    test_chunks = [
+        Chunk(
+            id=str(i),
+            content=Content(text=f"content_{i}"),
+            embedding=Embedding(dense_embedding=[float(i), float(i + 1)]),
+            metadata={
+                "source": f"source_{i % 3}",
+                "timestamp": f"2024-02-02T{i:02d}:00:00"
+            }) for i in range(num_records)
+    ]
+
+    self.write_test_pipeline.not_use_test_runner_api = True
+    with self.write_test_pipeline as p:
+      _ = (
+          p | beam.Create(test_chunks)
+          | VectorDatabaseWriteTransform(writer_config))
+
+    def metadata_row_to_chunk(row):
+      embedding_list = [float(x) for x in row.embedding.strip('[]').split(',')]
+      timestamp = row.timestamp.replace(
+          ' ', 'T') if ' ' in row.timestamp else row.timestamp
+      return Chunk(
+          id=row.id,
+          content=Content(text=row.content),
+          embedding=Embedding(dense_embedding=embedding_list),
+          metadata={
+              "source": row.source, "timestamp": timestamp
+          })
+
+    jdbc_params = PipelineVerificationHelper.build_jdbc_params(
+        helper, helper.metadata_conflicts_table)
+
+    self.read_test_pipeline.not_use_test_runner_api = True
+    with self.read_test_pipeline as p:
+      rows = (p | ReadFromJdbc(**jdbc_params))
+      chunks = rows | "To Chunks" >> beam.Map(metadata_row_to_chunk)
+      assert_that(chunks, equal_to(test_chunks), label='chunks_check')
+
+  @parameterized.expand([(POSTGRES_CONFIG), (MYSQL_CONFIG)])
+  def test_composite_unique_constraint_conflicts(self, db_config):
+    """Test conflict resolution when unique constraints span multiple columns.
+      
+      This test validates conflict resolution when the unique constraint is 
NOT 
+      on the primary key, but on a combination of other columns (source +
+      timestamp).
+      
+      Scenario:
+      1. Insert records with unique (source, timestamp) combinations
+      2. Attempt to insert records with same (source, timestamp) but different
+         IDs and content
+      3. Verify that conflict resolution (UPDATE) works correctly based on
+         composite key
+      
+      This is different from test_conflict_resolution which tests conflicts on 
+      the primary key field only.
+      """
+    self.skip_if_dataflow_runner()
+
+    helper = self.db_helpers[db_config.database_type]
+    num_records = 5
+    common_module = db_config.common_module
+
+    if db_config.database_type == "postgresql":
+      specs = (
+          
common_module.ColumnSpecsBuilder().with_id_spec().with_embedding_spec(
+          ).with_content_spec().add_metadata_field(
+              field="source",
+              column_name="source",
+              python_type=str,
+              sql_typecast=None).add_metadata_field(
+                  field="timestamp",
+                  python_type=str,
+                  sql_typecast="::timestamp").build())
+
+      conflict_resolution = common_module.ConflictResolution(
+          on_conflict_fields=["source", "timestamp"],
+          action="UPDATE",
+          update_fields=["embedding", "content"])
+    elif db_config.database_type == "mysql":
+      specs = (
+          
common_module.ColumnSpecsBuilder().with_id_spec().with_embedding_spec(
+          ).with_content_spec().add_metadata_field(
+              field="source", column_name="source",
+              python_type=str).add_metadata_field(
+                  field="timestamp", python_type=str).build())
+
+      # MySQL conflict resolution - detects unique constraint automatically
+      conflict_resolution = common_module.ConflictResolution(
+          action="UPDATE", update_fields=["embedding", "content"])
+
+    writer_config = helper.create_writer_config(
+        helper.metadata_conflicts_table, specs, conflict_resolution)
+
+    initial_chunks = [
+        Chunk(
+            id=str(i),
+            content=Content(text=f"content_{i}"),
+            embedding=Embedding(dense_embedding=[float(i), float(i + 1)]),
+            metadata={
+                "source": "source_A", "timestamp": f"2024-02-02T{i:02d}:00:00"
+            }) for i in range(num_records)
+    ]
+
+    with self.write_test_pipeline as p:
+      _ = (
+          p | "Write Initial" >> beam.Create(initial_chunks)
+          | VectorDatabaseWriteTransform(writer_config))
+
+    conflicting_chunks = [
+        Chunk(
+            id=f"new_{i}",
+            content=Content(text=f"updated_content_{i}"),
+            embedding=Embedding(
+                dense_embedding=[float(i) * 2, float(i + 1) * 2]),
+            metadata={
+                "source": "source_A", "timestamp": f"2024-02-02T{i:02d}:00:00"
+            }) for i in range(num_records)
+    ]
+
+    with self.write_test_pipeline2 as p:
+      _ = (
+          p | "Write Conflicts" >> beam.Create(conflicting_chunks)
+          | VectorDatabaseWriteTransform(writer_config))
+
+    expected_chunks = [
+        Chunk(
+            id=str(i),
+            content=Content(text=f"updated_content_{i}"),
+            embedding=Embedding(
+                dense_embedding=[float(i) * 2, float(i + 1) * 2]),
+            metadata={
+                "source": "source_A", "timestamp": f"2024-02-02T{i:02d}:00:00"
+            }) for i in range(num_records)
+    ]
+
+    def metadata_row_to_chunk(row):
+      embedding_list = [float(x) for x in row.embedding.strip('[]').split(',')]
+      timestamp = row.timestamp.replace(
+          ' ', 'T') if ' ' in row.timestamp else row.timestamp
+      return Chunk(
+          id=row.id,
+          content=Content(text=row.content),
+          embedding=Embedding(dense_embedding=embedding_list),
+          metadata={
+              "source": row.source, "timestamp": timestamp
+          })
+
+    jdbc_params = PipelineVerificationHelper.build_jdbc_params(
+        helper, helper.metadata_conflicts_table)
+
+    with self.read_test_pipeline as p:
+      rows = (p | ReadFromJdbc(**jdbc_params))
+
+      count_result = rows | "Count All" >> beam.combiners.Count.Globally()
+      assert_that(count_result, equal_to([num_records]), label='count_check')
+
+      chunks = rows | "To Chunks" >> beam.Map(metadata_row_to_chunk)
+      assert_that(chunks, equal_to(expected_chunks), label='chunks_check')
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/mysql.py 
b/sdks/python/apache_beam/ml/rag/ingestion/mysql.py
new file mode 100644
index 00000000000..c64c083b6c9
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/ingestion/mysql.py
@@ -0,0 +1,268 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import re
+from abc import ABC
+from abc import abstractmethod
+from collections import Counter
+from typing import Callable
+from typing import List
+from typing import NamedTuple
+from typing import Optional
+
+import apache_beam as beam
+from apache_beam.coders import registry
+from apache_beam.coders.row_coder import RowCoder
+from apache_beam.io.jdbc import WriteToJdbc
+from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
+from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig
+from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
+from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpec
+from apache_beam.ml.rag.ingestion.mysql_common import ColumnSpecsBuilder
+from apache_beam.ml.rag.ingestion.mysql_common import ConflictResolution
+from apache_beam.ml.rag.types import Chunk
+
+# Module logger; _MySQLQueryBuilder.build_insert logs the generated INSERT.
+_LOGGER = logging.getLogger(__name__)
+
+
+class _ConflictResolutionStrategy(ABC):
+  """Abstract base class for conflict resolution strategies."""
+  @abstractmethod
+  def get_conflict_clause(self, all_columns: List[str]) -> str:
+    """Generate the MySQL conflict clause."""
+    pass
+
+
+class _NoConflictStrategy(_ConflictResolutionStrategy):
+  """Strategy for when no conflict resolution is needed."""
+  def get_conflict_clause(self, all_columns: List[str]) -> str:
+    return ""
+
+
+class _UpdateStrategy(_ConflictResolutionStrategy):
+  """Strategy for UPDATE action on conflict."""
+  def __init__(self, update_fields: Optional[List[str]] = None):
+    self.update_fields = update_fields
+
+  def get_conflict_clause(self, all_columns: List[str]) -> str:
+    # Use provided fields or default to all columns
+    fields_to_update = self.update_fields or all_columns
+    assert len(fields_to_update) > 0
+
+    updates = [f"{field} = VALUES({field})" for field in fields_to_update]
+    return f"ON DUPLICATE KEY UPDATE {', '.join(updates)}"
+
+
+class _IgnoreStrategy(_ConflictResolutionStrategy):
+  """Strategy for IGNORE action on conflict."""
+  def __init__(self, primary_key_field: str):
+    self.primary_key_field = primary_key_field
+
+  def get_conflict_clause(self, all_columns: List[str]) -> str:
+    return f"ON DUPLICATE KEY UPDATE {self.primary_key_field}"\
+       f" = {self.primary_key_field}"
+
+
+def _create_conflict_strategy(
+    conflict_resolution: Optional[ConflictResolution]
+) -> _ConflictResolutionStrategy:
+  if conflict_resolution is None:
+    return _NoConflictStrategy()
+  if conflict_resolution.action == "UPDATE":
+    return _UpdateStrategy(conflict_resolution.update_fields)
+  if conflict_resolution.action == "IGNORE":
+    assert conflict_resolution.primary_key_field is not None
+    return _IgnoreStrategy(conflict_resolution.primary_key_field)
+  raise ValueError(f"Unknown conflict resolution {conflict_resolution.action}")
+
+
+class _MySQLQueryBuilder:
+  def __init__(
+      self,
+      table_name: str,
+      *,
+      column_specs: List[ColumnSpec],
+      conflict_resolution: Optional[ConflictResolution] = None):
+    """Builds SQL queries for writing Chunks with Embeddings to MySQL.
+    """
+    self.table_name = table_name
+
+    self.column_specs = column_specs
+    self.conflict_resolution_strategy = _create_conflict_strategy(
+        conflict_resolution)
+
+    names = [col.column_name for col in self.column_specs]
+    duplicates = set(name for name in names if names.count(name) > 1)
+    if duplicates:
+      raise ValueError(f"Duplicate column names found: {duplicates}")
+
+    fields = [(col.column_name, col.python_type) for col in self.column_specs]
+    type_name = f"VectorRecord_{table_name}"
+    self.record_type = NamedTuple(type_name, fields)  # type: ignore
+
+    registry.register_coder(self.record_type, RowCoder)
+
+  def build_insert(self) -> str:
+    fields = [col.column_name for col in self.column_specs]
+    placeholders = [col.placeholder for col in self.column_specs]
+
+    # Build base query
+    query = f"""
+        INSERT INTO {self.table_name}
+        ({', '.join(fields)})
+        VALUES ({', '.join(placeholders)})
+    """
+    conflict_clause = self.conflict_resolution_strategy.get_conflict_clause(
+        all_columns=fields)
+    query += f" {conflict_clause}"
+
+    _LOGGER.info("MySQL Query with placeholders %s", query)
+    return query
+
+  def create_converter(self) -> Callable[[Chunk], NamedTuple]:
+    """Creates a function to convert Chunks to records."""
+    def convert(chunk: Chunk) -> self.record_type:  # type: ignore
+      return self.record_type(
+          **{col.column_name: col.value_fn(chunk)
+             for col in self.column_specs})  # type: ignore
+
+    return convert
+
+
+class MySQLVectorWriterConfig(VectorDatabaseWriteConfig):
+  def __init__(
+      self,
+      connection_config: ConnectionConfig,
+      table_name: str,
+      *,
+      # pylint: disable=dangerous-default-value
+      write_config: WriteConfig = WriteConfig(),
+      column_specs: List[ColumnSpec] = 
ColumnSpecsBuilder.with_defaults().build(
+      ),
+      conflict_resolution: Optional[ConflictResolution] = None):
+    """Configuration for writing vectors to MySQL using jdbc.
+    
+    Supports flexible schema configuration through column specifications and
+    conflict resolution strategies with MySQL-specific syntax.
+
+    Args:
+        connection_config:
+          :class:`~apache_beam.ml.rag.ingestion.jdbc_common.ConnectionConfig`.
+        table_name: Target table name.
+        write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control
+          batch sizes, authosharding, etc.
+        column_specs:
+            Use :class:`~.mysql_common.ColumnSpecsBuilder` to configure how
+            embeddings and metadata are written to the database
+            schema. If None, uses default Chunk schema with MySQL vector
+            functions.
+        conflict_resolution: Optional
+            :class:`~.mysql_common.ConflictResolution`
+            strategy for handling insert conflicts. ON DUPLICATE KEY UPDATE.
+            None by default, meaning errors are thrown when attempting to 
insert
+            duplicates.
+    
+    Examples:
+        Simple case with default schema:
+
+        >>> config = MySQLVectorWriterConfig(
+        ...     connection_config=ConnectionConfig(...),
+        ...     table_name='embeddings'
+        ... )
+
+        Custom schema with metadata fields and MySQL functions:
+
+        >>> specs = (ColumnSpecsBuilder()
+        ...         .with_id_spec(column_name="my_id_column")
+        ...         .with_embedding_spec(
+        ...             column_name="embedding_vec",
+        ...             placeholder="string_to_vector(?)"
+        ...         )
+        ...         .add_metadata_field(field="source", column_name="src")
+        ...         .add_metadata_field(
+        ...             "timestamp",
+        ...             column_name="created_at",
+        ...             placeholder="STR_TO_DATE(?, '%Y-%m-%d %H:%i:%s')"
+        ...         )
+        ...         .build())
+
+        Minimal schema (only ID + embedding written):
+
+        >>> column_specs = (ColumnSpecsBuilder()
+        ...     .with_id_spec()
+        ...     .with_embedding_spec()
+        ...     .build())
+
+        >>> config = MySQLVectorWriterConfig(
+        ...     connection_config=ConnectionConfig(...),
+        ...     table_name='embeddings',
+        ...     column_specs=specs,
+        ...     conflict_resolution=ConflictResolution(
+        ...         on_conflict_fields=["id"],
+        ...         action="UPDATE",
+        ...         update_fields=["embedding", "content"]
+        ...     )
+        ... )
+
+        Using MySQL JSON functions:
+
+        >>> specs = (ColumnSpecsBuilder()
+        ...     .with_id_spec()
+        ...     .with_embedding_spec()
+        ...     .with_metadata_spec(
+        ...         column_name="metadata_json",
+        ...         placeholder="CAST(? AS JSON)"
+        ...     )
+        ...     .build())
+    """
+    self.connection_config = connection_config
+    self.write_config = write_config
+    # NamedTuple is created and registered here during pipeline construction
+    self.query_builder = _MySQLQueryBuilder(
+        table_name,
+        column_specs=column_specs,
+        conflict_resolution=conflict_resolution)
+
+  def create_write_transform(self) -> beam.PTransform:
+    return _WriteToMySQLVectorDatabase(self)
+
+
+class _WriteToMySQLVectorDatabase(beam.PTransform):
+  """Implementation of MySQL vector database write."""
+  def __init__(self, config: MySQLVectorWriterConfig):
+    self.config = config
+    self.query_builder = config.query_builder
+    self.connection_config = config.connection_config
+    self.write_config = config.write_config
+
+  def expand(self, pcoll: beam.PCollection[Chunk]):
+    return (
+        pcoll
+        |
+        "Convert to Records" >> beam.Map(self.query_builder.create_converter())
+        | "Write to MySQL" >> WriteToJdbc(
+            table_name=self.query_builder.table_name,
+            driver_class_name="com.mysql.cj.jdbc.Driver",
+            jdbc_url=self.connection_config.jdbc_url,
+            username=self.connection_config.username,
+            password=self.connection_config.password,
+            statement=self.query_builder.build_insert(),
+            connection_properties=self.connection_config.connection_properties,
+            connection_init_sqls=self.connection_config.connection_init_sqls,
+            autosharding=self.write_config.autosharding,
+            max_connections=self.write_config.max_connections,
+            write_batch_size=self.write_config.write_batch_size,
+            **self.connection_config.additional_jdbc_args))
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py 
b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py
new file mode 100644
index 00000000000..c1ee703a5f2
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/ingestion/mysql_common.py
@@ -0,0 +1,433 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from dataclasses import dataclass
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Optional
+from typing import Type
+
+from apache_beam.ml.rag.types import Chunk
+
+
+def chunk_embedding_fn(chunk: Chunk) -> str:
+  """Convert embedding to MySQL vector string format.
+    
+    Formats dense embedding as a MySQL-compatible vector string.
+    Example: [1.0, 2.0] -> '[1.0,2.0]'
+    
+    Args:
+        chunk: Input Chunk object.
+    
+    Returns:
+        str: MySQL vector string representation of the embedding.
+    
+    Raises:
+        ValueError: If chunk has no dense embedding.
+    """
+  if chunk.embedding is None or chunk.embedding.dense_embedding is None:
+    raise ValueError(f'Expected chunk to contain embedding. {chunk}')
+  return '[' + ','.join(str(x) for x in chunk.embedding.dense_embedding) + ']'
+
+
+@dataclass
+class ColumnSpec:
+  """Specification for mapping Chunk fields to MySQL columns for insertion.
+    
+    Defines how to extract and format values from Chunks into MySQL database
+    columns, handling the full pipeline from Python value to SQL insertion.
+
+    The insertion process works as follows:
+    - value_fn extracts a value from the Chunk and formats it as needed
+    - The value is stored in a NamedTuple field with the specified python_type
+    - During SQL insertion, the value is bound to a ? placeholder
+
+    Attributes:
+        column_name: The column name in the database table.
+        python_type: Python type for the NamedTuple field that will hold the
+            value. Must be compatible with 
+            :class:`~apache_beam.coders.row_coder.RowCoder`.
+        value_fn: Function to extract and format the value from a Chunk.
+            Takes a Chunk and returns a value of python_type.
+        placeholder: Optional placeholder to apply typecasts or functions to
+            value ? placeholder e.g. "string_to_vector(?)" for vector columns.
+    
+    Examples:
+
+        Basic text column (uses standard JDBC type mapping):
+
+        >>> ColumnSpec.text(
+        ...     column_name="content",
+        ...     value_fn=lambda chunk: chunk.content.text
+        ... )
+        ... # Results in: INSERT INTO table (content) VALUES (?)
+
+        Timestamp from metadata:
+
+        >>> ColumnSpec(
+        ...     column_name="created_at",
+        ...     python_type=str,
+        ...     value_fn=lambda chunk: chunk.metadata.get("timestamp")
+        ... )
+        ... # Results in: INSERT INTO table (created_at) VALUES (?)
+
+
+    Factory Methods:
+        text: Creates a text column specification.
+        integer: Creates an integer column specification.
+        float: Creates a float column specification.
+        vector: Creates a vector column specification with string_to_vector().
+        json: Creates a JSON column specification.
+    """
+  column_name: str
+  python_type: Type
+  value_fn: Callable[[Chunk], Any]
+  placeholder: str = '?'
+
+  @classmethod
+  def text(
+      cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec':
+    """Create a text column specification."""
+    return cls(column_name, str, value_fn)
+
+  @classmethod
+  def integer(
+      cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec':
+    """Create an integer column specification."""
+    return cls(column_name, int, value_fn)
+
+  @classmethod
+  def float(
+      cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec':
+    """Create a float column specification."""
+    return cls(column_name, float, value_fn)
+
+  @classmethod
+  def vector(
+      cls,
+      column_name: str,
+      value_fn: Callable[[Chunk], Any] = chunk_embedding_fn) -> 'ColumnSpec':
+    """Create a vector column specification with string_to_vector() 
function."""
+    return cls(column_name, str, value_fn, "string_to_vector(?)")
+
+  @classmethod
+  def json(
+      cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec':
+    """Create a JSON column specification."""
+    return cls(column_name, str, value_fn)
+
+
+def embedding_to_string(embedding: List[float]) -> str:
+  """Convert embedding to MySQL vector string format."""
+  return '[' + ','.join(str(x) for x in embedding) + ']'
+
+
+class ColumnSpecsBuilder:
+  """Builder for :class:`.ColumnSpec`'s with chainable methods."""
+  def __init__(self):
+    self._specs: List[ColumnSpec] = []
+
+  @staticmethod
+  def with_defaults() -> 'ColumnSpecsBuilder':
+    """Add all default column specifications."""
+    return (
+        ColumnSpecsBuilder().with_id_spec().with_embedding_spec().
+        with_content_spec().with_metadata_spec())
+
+  def with_id_spec(
+      self,
+      column_name: str = "id",
+      python_type: Type = str,
+      convert_fn: Optional[Callable[[str],
+                                    Any]] = None) -> 'ColumnSpecsBuilder':
+    """Add ID :class:`.ColumnSpec` with optional type and conversion.
+        
+        Args:
+            column_name: Name for the ID column (defaults to "id")
+            python_type: Python type for the column (defaults to str)
+            convert_fn: Optional function to convert the chunk ID
+                       If None, uses ID as-is
+        
+        Returns:
+            Self for method chaining
+        
+        Example:
+            >>> builder.with_id_spec(
+            ...     column_name="doc_id",
+            ...     python_type=int,
+            ...     convert_fn=lambda id: int(id.split('_')[1])
+            ... )
+        """
+    def value_fn(chunk: Chunk) -> Any:
+      value = chunk.id
+      return convert_fn(value) if convert_fn else value
+
+    self._specs.append(
+        ColumnSpec(
+            column_name=column_name, python_type=python_type,
+            value_fn=value_fn))
+    return self
+
+  def with_content_spec(
+      self,
+      column_name: str = "content",
+      python_type: Type = str,
+      convert_fn: Optional[Callable[[str],
+                                    Any]] = None) -> 'ColumnSpecsBuilder':
+    """Add content :class:`.ColumnSpec` with optional type and conversion.
+      
+      Args:
+          column_name: Name for the content column (defaults to "content")
+          python_type: Python type for the column (defaults to str)
+          convert_fn: Optional function to convert the content text
+                      If None, uses content text as-is
+      
+      Returns:
+          Self for method chaining
+      
+      Example:
+          >>> builder.with_content_spec(
+          ...     column_name="content_length",
+          ...     python_type=int,
+          ...     convert_fn=len  # Store content length instead of content
+          ... )
+      """
+    def value_fn(chunk: Chunk) -> Any:
+      if chunk.content.text is None:
+        raise ValueError(f'Expected chunk to contain content. {chunk}')
+      value = chunk.content.text
+      return convert_fn(value) if convert_fn else value
+
+    self._specs.append(
+        ColumnSpec(
+            column_name=column_name, python_type=python_type,
+            value_fn=value_fn))
+    return self
+
+  def with_metadata_spec(
+      self,
+      column_name: str = "metadata",
+      python_type: Type = str,
+      convert_fn: Optional[Callable[[Dict[str, Any]], Any]] = None
+  ) -> 'ColumnSpecsBuilder':
+    """Add metadata :class:`.ColumnSpec` with optional type and conversion.
+      
+      Args:
+          column_name: Name for the metadata column (defaults to "metadata")
+          python_type: Python type for the column (defaults to str)
+          convert_fn: Optional function to convert the metadata dictionary
+                      If None and python_type is str, converts to JSON string
+      
+      Returns:
+          Self for method chaining
+      
+      Example:
+          >>> builder.with_metadata_spec(
+          ...     column_name="meta_tags",
+          ...     python_type=str,
+          ...     convert_fn=lambda meta: ','.join(meta.keys())
+          ... )
+      """
+    def value_fn(chunk: Chunk) -> Any:
+      if convert_fn:
+        return convert_fn(chunk.metadata)
+      return json.dumps(
+          chunk.metadata) if python_type == str else chunk.metadata
+
+    self._specs.append(
+        ColumnSpec(
+            column_name=column_name, python_type=python_type,
+            value_fn=value_fn))
+    return self
+
+  def with_embedding_spec(
+      self,
+      column_name: str = "embedding",
+      convert_fn: Callable[[List[float]], Any] = embedding_to_string
+  ) -> 'ColumnSpecsBuilder':
+    """Add embedding :class:`.ColumnSpec` with optional conversion.
+      
+      Args:
+          column_name: Name for the embedding column (defaults to "embedding")
+          convert_fn: Optional function to convert the dense embedding values
+                      If None, uses default MySQL vector format
+      
+      Returns:
+          Self for method chaining
+      
+      Example:
+          >>> builder.with_embedding_spec(
+          ...     column_name="embedding_vector",
+          ...     convert_fn=lambda values: '[' + ','.join(f"{x:.4f}" 
+          ...       for x in values) + ']'
+          ... )
+      """
+    def value_fn(chunk: Chunk) -> Any:
+      if chunk.embedding is None or chunk.embedding.dense_embedding is None:
+        raise ValueError(f'Expected chunk to contain embedding. {chunk}')
+      values = chunk.embedding.dense_embedding
+      return convert_fn(values)
+
+    self._specs.append(
+        ColumnSpec.vector(column_name=column_name, value_fn=value_fn))
+    return self
+
+  def add_metadata_field(
+      self,
+      field: str,
+      python_type: Type,
+      column_name: Optional[str] = None,
+      convert_fn: Optional[Callable[[Any], Any]] = None,
+      default: Any = None) -> 'ColumnSpecsBuilder':
+    """Add a :class:`.ColumnSpec` that extracts and converts a field from
+        chunk metadata.
+
+        Args:
+            field: Key to extract from chunk metadata
+            python_type: Python type for the column (e.g. str, int, float)
+            column_name: Name for the column (defaults to metadata field name)
+            convert_fn: Optional function to convert the extracted value to
+                      desired type. If None, value is used as-is
+            default: Default value if field is missing from metadata
+        
+        Returns:
+            Self for chaining
+
+        Examples:
+            Simple string field:
+            >>> builder.add_metadata_field("source", str)
+
+            Integer with default:
+            >>> builder.add_metadata_field(
+            ...     field="count",
+            ...     python_type=int,
+            ...     column_name="item_count",
+            ...     default=0
+            ... )
+
+            Float with conversion and default:
+            >>> builder.add_metadata_field(
+            ...     field="confidence",
+            ...     python_type=float,
+            ...     convert_fn=lambda x: round(float(x), 2),
+            ...     default=0.0
+            ... )
+
+            Timestamp with conversion:
+            >>> builder.add_metadata_field(
+            ...     field="created_at",
+            ...     python_type=str,
+            ...     convert_fn=lambda ts: ts.replace('T', ' ')
+            ... )
+        """
+    name = column_name or field
+
+    def value_fn(chunk: Chunk) -> Any:
+      value = chunk.metadata.get(field, default)
+      if value is not None and convert_fn is not None:
+        value = convert_fn(value)
+      return value
+
+    spec = ColumnSpec(
+        column_name=name, python_type=python_type, value_fn=value_fn)
+
+    self._specs.append(spec)
+    return self
+
+  def add_custom_column_spec(self, spec: ColumnSpec) -> 'ColumnSpecsBuilder':
+    """Add a custom :class:`.ColumnSpec` to the builder.
+    
+    Use this method when you need complete control over the
+    :class:`.ColumnSpec`, including custom value extraction and type handling.
+    
+    Args:
+        spec: A :class:`.ColumnSpec` instance defining the column name, type,
+            value extraction, and optional MySQL function.
+    
+    Returns:
+        Self for method chaining
+    
+    Examples:
+        Custom text column from chunk metadata:
+        >>> builder.add_custom_column_spec(
+        ...     ColumnSpec.text(
+        ...         column_name="source_and_id",
+        ...         value_fn=lambda chunk: 
+        ...             f"{chunk.metadata.get('source')}_{chunk.id}"
+        ...     )
+        ... )
+    """
+    self._specs.append(spec)
+    return self
+
+  def build(self) -> List[ColumnSpec]:
+    """Build the final list of column specifications."""
+    return self._specs.copy()
+
+
+@dataclass
+class ConflictResolution:
+  """Specification for how to handle conflicts during insert.
+
+    Configures conflict handling behavior when inserting records that may
+    violate unique constraints using MySQL's ON DUPLICATE KEY UPDATE syntax.
+    
+    MySQL automatically detects conflicts based on PRIMARY KEY or UNIQUE 
+    constraints defined on the table.
+
+    Attributes:
+        action: How to handle conflicts - either "UPDATE" or "IGNORE".
+            UPDATE: Updates existing record with new values.
+            IGNORE: Skips conflicting records (uses no-op update).
+        update_fields: Optional list of fields to update on conflict. If None,
+            all fields are updated (for UPDATE action only).
+        primary_key_field: Required for IGNORE action. The primary key field
+            name to use for the no-op update.
+        
+    Examples:
+        Update all fields on conflict:
+        >>> ConflictResolution(action="UPDATE")
+        
+        Update specific fields on conflict:
+        >>> ConflictResolution(
+        ...     action="UPDATE",
+        ...     update_fields=["embedding", "content"]
+        ... )
+        
+        Ignore conflicts with explicit primary key:
+        >>> ConflictResolution(
+        ...     action="IGNORE", 
+        ...     primary_key_field="id"
+        ... )
+        
+        Ignore conflicts with custom primary key:
+        >>> ConflictResolution(
+        ...     action="IGNORE",
+        ...     primary_key_field="custom_id"
+        ... )
+    """
+  action: Literal["UPDATE", "IGNORE"] = "UPDATE"
+  update_fields: Optional[List[str]] = None
+  primary_key_field: Optional[str] = None
+
+  def __post_init__(self):
+    """Validate configuration after initialization."""
+    if self.action == "IGNORE" and self.primary_key_field is None:
+      raise ValueError("primary_key_field is required when action='IGNORE'")
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py b/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py
index cd30766a288..0373874c09d 100644
--- a/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py
+++ b/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py
@@ -18,21 +18,12 @@
 import hashlib
 import json
 from typing import List
-from typing import NamedTuple
 
 import apache_beam as beam
-from apache_beam.coders import registry
-from apache_beam.coders.row_coder import RowCoder
 from apache_beam.ml.rag.types import Chunk
 from apache_beam.ml.rag.types import Content
 from apache_beam.ml.rag.types import Embedding
 
-TestRow = NamedTuple(
-    'TestRow',
-    [('id', str), ('embedding', List[float]), ('content', str),
-     ('metadata', str)])
-registry.register_coder(TestRow, RowCoder)
-
 VECTOR_SIZE = 768
 
 

Reply via email to