kaxil commented on code in PR #62793: URL: https://github.com/apache/airflow/pull/62793#discussion_r2880805686
########## providers/common/ai/src/airflow/providers/common/ai/operators/llm_schema_compare.py: ########## @@ -0,0 +1,292 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Operator for cross-system schema drift detection powered by LLM reasoning.""" + +from __future__ import annotations + +import json +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import BaseModel, Field + +from airflow.providers.common.ai.operators.llm import LLMOperator +from airflow.providers.common.compat.sdk import BaseHook +from airflow.sdk.exceptions import AirflowException + +if TYPE_CHECKING: + from airflow.providers.common.sql.config import DataSourceConfig + from airflow.providers.common.sql.hooks.sql import DbApiHook + from airflow.sdk import Context + + +class SchemaMismatch(BaseModel): + """A single schema mismatch between data sources.""" + + source: str = Field(description="Source table") + target: str = Field(description="Target table") + column: str = Field(description="Column name where the mismatch was detected") + source_type: str = Field(description="Data type in the source system") + target_type: str = Field(description="Data type in the target 
system") + severity: Literal["critical", "warning", "info"] = Field(description="Mismatch severity") + description: str = Field(description="Human-readable description of the mismatch") + suggested_action: str = Field(description="Recommended action to resolve the mismatch") + migration_query: str = Field(description="Provide migration query to resolve the mismatch") + + +class SchemaCompareResult(BaseModel): + """Structured output from schema comparison.""" + + compatible: bool = Field(description="Whether the schemas are compatible for data loading") + mismatches: list[SchemaMismatch] = Field(default_factory=list) + summary: str = Field(description="High-level summary of the comparison") + + +class LLMSchemaCompareOperator(LLMOperator): + """ + Compare schemas across different database systems and detect drift using LLM reasoning. + + The LLM handles complex cross-system type mapping that simple equality checks + miss (e.g., ``varchar(255)`` vs ``string``, ``timestamp`` vs ``timestamptz``). + + Accepts data sources via two patterns: + + 1. **data_sources** — a list of + :class:`~airflow.providers.common.sql.config.DataSourceConfig` for each + system. If the connection resolves to a + :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook`, schema is + introspected via SQLAlchemy; otherwise DataFusion is used. + 2. **db_conn_ids + table_names** — shorthand for comparing the same table + across multiple database connections (all must resolve to ``DbApiHook``). + + :param prompt: Instructions for the LLM on what to compare and flag. + :param llm_conn_id: Connection ID for the LLM provider. + :param model_id: Model identifier (e.g. ``"openai:gpt-5"``). + :param system_prompt: Additional instructions appended to the built-in + schema comparison prompt. + :param agent_params: Extra keyword arguments for the pydantic-ai ``Agent``. + :param data_sources: List of DataSourceConfig objects, one per system. 
+ :param db_conn_ids: Connection IDs for databases to compare (used with + ``table_names``). + :param table_names: Tables to introspect from each ``db_conn_id``. + :param context_strategy: ``"basic"`` for column names and types only; + ``"full"`` to include primary keys, foreign keys, and indexes. + Default ``"full"``. + :param reasoning_mode: Strongly recommended — cross-system type mapping + benefits from step-by-step analysis. + """ + + template_fields: Sequence[str] = ( + *LLMOperator.template_fields, + "db_conn_ids", + "table_names", + ) + + def __init__( + self, + *, + data_sources: list[DataSourceConfig] | None = None, + db_conn_ids: list[str] | None = None, + table_names: list[str] | None = None, + context_strategy: Literal["basic", "full"] = "full", + **kwargs: Any, + ) -> None: + kwargs.pop("output_type", None) + super().__init__(**kwargs) + self.data_sources = data_sources or [] + self.db_conn_ids = db_conn_ids or [] + self.table_names = table_names or [] + self.context_strategy = context_strategy + + if not self.data_sources and not self.db_conn_ids: + raise ValueError("Provide at least one of 'data_sources' or 'db_conn_ids'.") + + if self.db_conn_ids and not self.table_names: + raise ValueError("'table_names' is required when using 'db_conn_ids'.") + + total_sources = len(self.db_conn_ids) + len(self.data_sources) + if total_sources < 2: + raise ValueError( + "Provide at-least two combinations of 'db_conn_ids' and 'table_names' or 'data_sources' " + "to compare." + ) + + @staticmethod + def _get_db_hook(conn_id: str) -> DbApiHook: + """Resolve a connection ID to a DbApiHook.""" + from airflow.providers.common.sql.hooks.sql import DbApiHook + + connection = BaseHook.get_connection(conn_id) + hook = connection.get_hook() + if not isinstance(hook, DbApiHook): + raise ValueError( + f"Connection {conn_id!r} does not provide a DbApiHook. Got {type(hook).__name__}." 
+ ) + return hook + + def _is_dbapi_connection(self, conn_id: str) -> bool: + """Check whether a connection resolves to a DbApiHook.""" + from airflow.providers.common.sql.hooks.sql import DbApiHook + + try: + connection = BaseHook.get_connection(conn_id) + hook = connection.get_hook() + return isinstance(hook, DbApiHook) + except (AirflowException, ValueError) as exc: + self.log.debug("Connection %s does not resolve to a DbApiHook: %s", conn_id, exc, exc_info=True) + return False + + def _introspect_db_schema(self, hook: DbApiHook, table_name: str) -> str: + """Introspect schema from a database connection via DbApiHook.""" + columns = hook.get_table_schema(table_name) + if not columns: + self.log.warning("Table %r returned no columns — it may not exist.", table_name) + return "" + + col_info = ", ".join(f"{c['name']} {c['type']}" for c in columns) + parts = [f"Columns: {col_info}"] + + if self.context_strategy == "full": + try: + pks = hook.dialect.get_primary_keys(table_name) + if pks: + parts.append(f"Primary Key: {', '.join(pks)}") + except NotImplementedError: + self.log.warning("primary key introspection not implemented for dialect", hook.dialect_name) Review Comment: Bug: These `self.log.warning()` calls pass extra args but the format strings have no `%s` placeholders. With stdlib `logging`, formatting the record (`msg % args`) raises `TypeError: not all arguments converted during string formatting`. The handler reports this as a `--- Logging error ---` traceback on stderr and drops the record. As a result, neither the warning nor the dialect name/exception ever reaches the task log. Same issue on lines 172, 182, 184, 194, 196.
```python # Before (broken — dialect_name is silently discarded): self.log.warning("primary key introspection not implemented for dialect", hook.dialect_name) self.log.warning("Could not retrieve PK", ex) # After: self.log.warning("Primary key introspection not implemented for dialect %s", hook.dialect_name) self.log.warning("Could not retrieve PK for %r: %s", table_name, ex) ``` ########## providers/common/ai/src/airflow/providers/common/ai/operators/llm_schema_compare.py: ########## @@ -0,0 +1,292 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Operator for cross-system schema drift detection powered by LLM reasoning.""" + +from __future__ import annotations + +import json +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import BaseModel, Field + +from airflow.providers.common.ai.operators.llm import LLMOperator +from airflow.providers.common.compat.sdk import BaseHook +from airflow.sdk.exceptions import AirflowException + +if TYPE_CHECKING: + from airflow.providers.common.sql.config import DataSourceConfig + from airflow.providers.common.sql.hooks.sql import DbApiHook + from airflow.sdk import Context + + +class SchemaMismatch(BaseModel): + """A single schema mismatch between data sources.""" + + source: str = Field(description="Source table") + target: str = Field(description="Target table") + column: str = Field(description="Column name where the mismatch was detected") + source_type: str = Field(description="Data type in the source system") + target_type: str = Field(description="Data type in the target system") + severity: Literal["critical", "warning", "info"] = Field(description="Mismatch severity") + description: str = Field(description="Human-readable description of the mismatch") + suggested_action: str = Field(description="Recommended action to resolve the mismatch") + migration_query: str = Field(description="Provide migration query to resolve the mismatch") + + +class SchemaCompareResult(BaseModel): + """Structured output from schema comparison.""" + + compatible: bool = Field(description="Whether the schemas are compatible for data loading") + mismatches: list[SchemaMismatch] = Field(default_factory=list) + summary: str = Field(description="High-level summary of the comparison") + + +class LLMSchemaCompareOperator(LLMOperator): + """ + Compare schemas across different database systems and detect drift using LLM reasoning. 
+ + The LLM handles complex cross-system type mapping that simple equality checks + miss (e.g., ``varchar(255)`` vs ``string``, ``timestamp`` vs ``timestamptz``). + + Accepts data sources via two patterns: + + 1. **data_sources** — a list of + :class:`~airflow.providers.common.sql.config.DataSourceConfig` for each + system. If the connection resolves to a + :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook`, schema is + introspected via SQLAlchemy; otherwise DataFusion is used. + 2. **db_conn_ids + table_names** — shorthand for comparing the same table + across multiple database connections (all must resolve to ``DbApiHook``). + + :param prompt: Instructions for the LLM on what to compare and flag. + :param llm_conn_id: Connection ID for the LLM provider. + :param model_id: Model identifier (e.g. ``"openai:gpt-5"``). + :param system_prompt: Additional instructions appended to the built-in + schema comparison prompt. + :param agent_params: Extra keyword arguments for the pydantic-ai ``Agent``. + :param data_sources: List of DataSourceConfig objects, one per system. + :param db_conn_ids: Connection IDs for databases to compare (used with + ``table_names``). + :param table_names: Tables to introspect from each ``db_conn_id``. + :param context_strategy: ``"basic"`` for column names and types only; + ``"full"`` to include primary keys, foreign keys, and indexes. + Default ``"full"``. + :param reasoning_mode: Strongly recommended — cross-system type mapping + benefits from step-by-step analysis. 
+ """ + + template_fields: Sequence[str] = ( + *LLMOperator.template_fields, + "db_conn_ids", + "table_names", + ) + + def __init__( + self, + *, + data_sources: list[DataSourceConfig] | None = None, + db_conn_ids: list[str] | None = None, + table_names: list[str] | None = None, + context_strategy: Literal["basic", "full"] = "full", + **kwargs: Any, + ) -> None: + kwargs.pop("output_type", None) + super().__init__(**kwargs) + self.data_sources = data_sources or [] + self.db_conn_ids = db_conn_ids or [] + self.table_names = table_names or [] + self.context_strategy = context_strategy + + if not self.data_sources and not self.db_conn_ids: + raise ValueError("Provide at least one of 'data_sources' or 'db_conn_ids'.") + + if self.db_conn_ids and not self.table_names: + raise ValueError("'table_names' is required when using 'db_conn_ids'.") + + total_sources = len(self.db_conn_ids) + len(self.data_sources) + if total_sources < 2: + raise ValueError( + "Provide at-least two combinations of 'db_conn_ids' and 'table_names' or 'data_sources' " + "to compare." + ) + + @staticmethod + def _get_db_hook(conn_id: str) -> DbApiHook: + """Resolve a connection ID to a DbApiHook.""" + from airflow.providers.common.sql.hooks.sql import DbApiHook + + connection = BaseHook.get_connection(conn_id) + hook = connection.get_hook() + if not isinstance(hook, DbApiHook): + raise ValueError( + f"Connection {conn_id!r} does not provide a DbApiHook. Got {type(hook).__name__}." 
+ ) + return hook + + def _is_dbapi_connection(self, conn_id: str) -> bool: + """Check whether a connection resolves to a DbApiHook.""" + from airflow.providers.common.sql.hooks.sql import DbApiHook + + try: + connection = BaseHook.get_connection(conn_id) + hook = connection.get_hook() + return isinstance(hook, DbApiHook) + except (AirflowException, ValueError) as exc: + self.log.debug("Connection %s does not resolve to a DbApiHook: %s", conn_id, exc, exc_info=True) + return False + + def _introspect_db_schema(self, hook: DbApiHook, table_name: str) -> str: + """Introspect schema from a database connection via DbApiHook.""" + columns = hook.get_table_schema(table_name) + if not columns: + self.log.warning("Table %r returned no columns — it may not exist.", table_name) + return "" + + col_info = ", ".join(f"{c['name']} {c['type']}" for c in columns) + parts = [f"Columns: {col_info}"] + + if self.context_strategy == "full": + try: + pks = hook.dialect.get_primary_keys(table_name) + if pks: + parts.append(f"Primary Key: {', '.join(pks)}") + except NotImplementedError: + self.log.warning("primary key introspection not implemented for dialect", hook.dialect_name) + except Exception as ex: + self.log.warning("Could not retrieve PK", ex) + + try: + fks = hook.inspector.get_foreign_keys(table_name) + for fk in fks: + cols = ", ".join(fk.get("constrained_columns", [])) + ref = fk.get("referred_table", "?") + ref_cols = ", ".join(fk.get("referred_columns", [])) + parts.append(f"Foreign Key: ({cols}) -> {ref}({ref_cols})") + except NotImplementedError: + self.log.warning("foreign key introspection not implemented for dialect", hook.dialect_name) + except Exception as ex: + self.log.warning("Could not retrieve FK", ex) + + try: + indexes = hook.inspector.get_indexes(table_name) + for idx in indexes: + column_names = [c for c in idx.get("column_names", []) if c is not None] + idx_cols = ", ".join(column_names) + unique = " UNIQUE" if idx.get("unique") else "" + 
parts.append(f"Index{unique}: {idx.get('name', '?')} ({idx_cols})") + except NotImplementedError: + self.log.warning("index introspection not implemented for dialect", hook.dialect_name) + except Exception as ex: + self.log.warning("Could not retrieve index", ex) + + return "\n".join(parts) + + if self.context_strategy == "basic": + return "\n".join(parts) + + raise ValueError(f"Invalid context_strategy: {self.context_strategy}") + + def _introspect_datasource_schema(self, ds_config: DataSourceConfig) -> str: + """Introspect schema from a DataSourceConfig, choosing DbApiHook or DataFusion.""" + if self._is_dbapi_connection(ds_config.conn_id): + hook = self._get_db_hook(ds_config.conn_id) + dialect_name = getattr(hook, "dialect_name", "unknown") + schema_text = self._introspect_db_schema(hook, ds_config.table_name) + return ( + f"Source: {ds_config.conn_id} ({dialect_name})\nTable: {ds_config.table_name}\n{schema_text}" + ) + + return self._introspect_schema_from_datafusion(ds_config) + + @cached_property + def _df_engine(self): + try: + from airflow.providers.common.sql.datafusion.engine import DataFusionEngine + except ImportError as e: + from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException + + raise AirflowOptionalProviderFeatureException(e) + engine = DataFusionEngine() + return engine + + def _introspect_schema_from_datafusion(self, ds_config: DataSourceConfig): + self._df_engine.register_datasource(ds_config) + schema_text = self._df_engine.get_schema(ds_config.table_name) + + return f"Source: {ds_config.conn_id} \nFormat: ({ds_config.format})\nTable: {ds_config.table_name}\nColumns: {schema_text}" + + def _build_schema_context(self) -> str: + """Collect schemas from all configured sources each clearly.""" + sections: list[str] = [] + + for conn_id in self.db_conn_ids: + hook = self._get_db_hook(conn_id) + dialect_name = getattr(hook, "dialect_name", "unknown") + for table in self.table_names: + schema_text = 
self._introspect_db_schema(hook, table) + if schema_text: + sections.append(f"Source: {dialect_name}\nTable: {table}\n{schema_text}") + + for ds_config in self.data_sources: + sections.append(self._introspect_datasource_schema(ds_config)) + + if not sections: + raise ValueError( + "No schema information could be retrieved from any of the configured sources. " + "Check that connection IDs, table names, and data source configs are correct." + ) + + return "\n\n".join(sections) + + def _build_system_prompt(self, schema_context: str) -> str: Review Comment: The `reasoning_mode` parameter was removed along with the type-equivalence hints and severity level definitions. The previous version prompted the LLM with cross-system type mappings (`varchar(n) / text / string` etc.) and severity definitions (`critical`, `warning`, `info`). Now the LLM gets none of that context. My original question was whether `reasoning_mode` should be a *flag* vs part of `system_prompt` — not whether the hints should be removed. These hints are valuable for getting good schema comparison results. Consider either: 1. Fold them into the default system prompt unconditionally (they're always useful for this operator) 2. Keep the flag to let users opt out if they want a simpler/cheaper prompt Also, the docstring at line 90 still documents `:param reasoning_mode:`, and the example DAGs still pass `reasoning_mode=True` (lines 56, 116 of example_dags). If the parameter is gone, those need cleanup too. ########## providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_schema_compare.py: ########## @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Example DAGs demonstrating LLMSchemaCompareOperator usage.""" + +from __future__ import annotations + +from airflow.providers.common.ai.operators.llm_schema_compare import LLMSchemaCompareOperator +from airflow.providers.common.compat.sdk import dag, task +from airflow.providers.common.sql.config import DataSourceConfig + + +# [START howto_operator_llm_schema_compare_basic] +@dag +def example_llm_schema_compare_basic(): + LLMSchemaCompareOperator( + task_id="detect_schema_drift", + prompt="Identify schema mismatches that would break data loading between systems", + llm_conn_id="pydantic_ai_default", + db_conn_ids=["postgres_default", "snowflake_default"], + table_names=["customers"], + ) + + +# [END howto_operator_llm_schema_compare_basic] + +example_llm_schema_compare_basic() + + +# [START howto_operator_llm_schema_compare_full] +@dag +def example_llm_schema_compare_full_context(): + LLMSchemaCompareOperator( + task_id="detect_schema_drift", + prompt=( + "Compare schemas and generate a migration plan. " + "Flag any differences that would break nightly ETL loads." + ), + llm_conn_id="pydantic_ai_default", + db_conn_ids=["postgres_source", "snowflake_target"], + table_names=["customers", "orders"], + context_strategy="full", + reasoning_mode=True, Review Comment: `reasoning_mode` is no longer an explicit parameter in `LLMSchemaCompareOperator.__init__`. 
This kwarg will pass through `**kwargs` to `LLMOperator.__init__()`. If the parent doesn't accept it, this example DAG will fail at parse time. Same issue on line 116. Either remove `reasoning_mode=True` from the examples, or re-add the parameter to the operator. ########## providers/common/ai/src/airflow/providers/common/ai/operators/llm_schema_compare.py: ########## @@ -0,0 +1,292 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Operator for cross-system schema drift detection powered by LLM reasoning.""" + +from __future__ import annotations + +import json +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import BaseModel, Field + +from airflow.providers.common.ai.operators.llm import LLMOperator +from airflow.providers.common.compat.sdk import BaseHook +from airflow.sdk.exceptions import AirflowException Review Comment: Direct import from `airflow.sdk.exceptions` in provider code. Providers should go through `airflow.providers.common.compat.sdk` to maintain Airflow 2.x compatibility (same reason `BaseHook` is imported from compat on line 29). 
If `AirflowException` isn't available in compat.sdk, consider catching a broader base like `Exception` with a specific message check, or adding it to the compat layer. ########## providers/common/ai/src/airflow/providers/common/ai/operators/llm_schema_compare.py: ########## @@ -0,0 +1,292 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Operator for cross-system schema drift detection powered by LLM reasoning.""" + +from __future__ import annotations + +import json +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import BaseModel, Field + +from airflow.providers.common.ai.operators.llm import LLMOperator +from airflow.providers.common.compat.sdk import BaseHook +from airflow.sdk.exceptions import AirflowException + +if TYPE_CHECKING: + from airflow.providers.common.sql.config import DataSourceConfig + from airflow.providers.common.sql.hooks.sql import DbApiHook + from airflow.sdk import Context + + +class SchemaMismatch(BaseModel): + """A single schema mismatch between data sources.""" + + source: str = Field(description="Source table") + target: str = Field(description="Target table") + column: str = Field(description="Column name where the mismatch was detected") + source_type: str = Field(description="Data type in the source system") + target_type: str = Field(description="Data type in the target system") + severity: Literal["critical", "warning", "info"] = Field(description="Mismatch severity") + description: str = Field(description="Human-readable description of the mismatch") + suggested_action: str = Field(description="Recommended action to resolve the mismatch") + migration_query: str = Field(description="Provide migration query to resolve the mismatch") + + +class SchemaCompareResult(BaseModel): + """Structured output from schema comparison.""" + + compatible: bool = Field(description="Whether the schemas are compatible for data loading") + mismatches: list[SchemaMismatch] = Field(default_factory=list) + summary: str = Field(description="High-level summary of the comparison") + + +class LLMSchemaCompareOperator(LLMOperator): + """ + Compare schemas across different database systems and detect drift using LLM reasoning. 
+ + The LLM handles complex cross-system type mapping that simple equality checks + miss (e.g., ``varchar(255)`` vs ``string``, ``timestamp`` vs ``timestamptz``). + + Accepts data sources via two patterns: + + 1. **data_sources** — a list of + :class:`~airflow.providers.common.sql.config.DataSourceConfig` for each + system. If the connection resolves to a + :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook`, schema is + introspected via SQLAlchemy; otherwise DataFusion is used. + 2. **db_conn_ids + table_names** — shorthand for comparing the same table + across multiple database connections (all must resolve to ``DbApiHook``). + + :param prompt: Instructions for the LLM on what to compare and flag. + :param llm_conn_id: Connection ID for the LLM provider. + :param model_id: Model identifier (e.g. ``"openai:gpt-5"``). + :param system_prompt: Additional instructions appended to the built-in + schema comparison prompt. + :param agent_params: Extra keyword arguments for the pydantic-ai ``Agent``. + :param data_sources: List of DataSourceConfig objects, one per system. + :param db_conn_ids: Connection IDs for databases to compare (used with + ``table_names``). + :param table_names: Tables to introspect from each ``db_conn_id``. + :param context_strategy: ``"basic"`` for column names and types only; + ``"full"`` to include primary keys, foreign keys, and indexes. + Default ``"full"``. + :param reasoning_mode: Strongly recommended — cross-system type mapping + benefits from step-by-step analysis. 
+ """ + + template_fields: Sequence[str] = ( + *LLMOperator.template_fields, + "db_conn_ids", + "table_names", + ) + + def __init__( + self, + *, + data_sources: list[DataSourceConfig] | None = None, + db_conn_ids: list[str] | None = None, + table_names: list[str] | None = None, + context_strategy: Literal["basic", "full"] = "full", + **kwargs: Any, + ) -> None: + kwargs.pop("output_type", None) + super().__init__(**kwargs) + self.data_sources = data_sources or [] + self.db_conn_ids = db_conn_ids or [] + self.table_names = table_names or [] + self.context_strategy = context_strategy + + if not self.data_sources and not self.db_conn_ids: + raise ValueError("Provide at least one of 'data_sources' or 'db_conn_ids'.") + + if self.db_conn_ids and not self.table_names: + raise ValueError("'table_names' is required when using 'db_conn_ids'.") + + total_sources = len(self.db_conn_ids) + len(self.data_sources) + if total_sources < 2: + raise ValueError( + "Provide at-least two combinations of 'db_conn_ids' and 'table_names' or 'data_sources' " + "to compare." + ) + + @staticmethod + def _get_db_hook(conn_id: str) -> DbApiHook: + """Resolve a connection ID to a DbApiHook.""" + from airflow.providers.common.sql.hooks.sql import DbApiHook + + connection = BaseHook.get_connection(conn_id) + hook = connection.get_hook() + if not isinstance(hook, DbApiHook): + raise ValueError( + f"Connection {conn_id!r} does not provide a DbApiHook. Got {type(hook).__name__}." 
+ ) + return hook + + def _is_dbapi_connection(self, conn_id: str) -> bool: + """Check whether a connection resolves to a DbApiHook.""" + from airflow.providers.common.sql.hooks.sql import DbApiHook + + try: + connection = BaseHook.get_connection(conn_id) + hook = connection.get_hook() + return isinstance(hook, DbApiHook) + except (AirflowException, ValueError) as exc: + self.log.debug("Connection %s does not resolve to a DbApiHook: %s", conn_id, exc, exc_info=True) + return False + + def _introspect_db_schema(self, hook: DbApiHook, table_name: str) -> str: + """Introspect schema from a database connection via DbApiHook.""" + columns = hook.get_table_schema(table_name) + if not columns: + self.log.warning("Table %r returned no columns — it may not exist.", table_name) + return "" + + col_info = ", ".join(f"{c['name']} {c['type']}" for c in columns) + parts = [f"Columns: {col_info}"] + + if self.context_strategy == "full": + try: + pks = hook.dialect.get_primary_keys(table_name) + if pks: + parts.append(f"Primary Key: {', '.join(pks)}") + except NotImplementedError: + self.log.warning("primary key introspection not implemented for dialect", hook.dialect_name) + except Exception as ex: + self.log.warning("Could not retrieve PK", ex) + + try: + fks = hook.inspector.get_foreign_keys(table_name) + for fk in fks: + cols = ", ".join(fk.get("constrained_columns", [])) + ref = fk.get("referred_table", "?") + ref_cols = ", ".join(fk.get("referred_columns", [])) + parts.append(f"Foreign Key: ({cols}) -> {ref}({ref_cols})") + except NotImplementedError: + self.log.warning("foreign key introspection not implemented for dialect", hook.dialect_name) + except Exception as ex: + self.log.warning("Could not retrieve FK", ex) + + try: + indexes = hook.inspector.get_indexes(table_name) + for idx in indexes: + column_names = [c for c in idx.get("column_names", []) if c is not None] + idx_cols = ", ".join(column_names) + unique = " UNIQUE" if idx.get("unique") else "" + 
parts.append(f"Index{unique}: {idx.get('name', '?')} ({idx_cols})") + except NotImplementedError: + self.log.warning("index introspection not implemented for dialect", hook.dialect_name) + except Exception as ex: + self.log.warning("Could not retrieve index", ex) + + return "\n".join(parts) + + if self.context_strategy == "basic": + return "\n".join(parts) + + raise ValueError(f"Invalid context_strategy: {self.context_strategy}") + + def _introspect_datasource_schema(self, ds_config: DataSourceConfig) -> str: + """Introspect schema from a DataSourceConfig, choosing DbApiHook or DataFusion.""" + if self._is_dbapi_connection(ds_config.conn_id): + hook = self._get_db_hook(ds_config.conn_id) + dialect_name = getattr(hook, "dialect_name", "unknown") + schema_text = self._introspect_db_schema(hook, ds_config.table_name) + return ( + f"Source: {ds_config.conn_id} ({dialect_name})\nTable: {ds_config.table_name}\n{schema_text}" + ) + + return self._introspect_schema_from_datafusion(ds_config) + + @cached_property + def _df_engine(self): + try: + from airflow.providers.common.sql.datafusion.engine import DataFusionEngine + except ImportError as e: + from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException + + raise AirflowOptionalProviderFeatureException(e) + engine = DataFusionEngine() + return engine + + def _introspect_schema_from_datafusion(self, ds_config: DataSourceConfig): + self._df_engine.register_datasource(ds_config) + schema_text = self._df_engine.get_schema(ds_config.table_name) + + return f"Source: {ds_config.conn_id} \nFormat: ({ds_config.format})\nTable: {ds_config.table_name}\nColumns: {schema_text}" + + def _build_schema_context(self) -> str: + """Collect schemas from all configured sources each clearly.""" + sections: list[str] = [] + + for conn_id in self.db_conn_ids: + hook = self._get_db_hook(conn_id) + dialect_name = getattr(hook, "dialect_name", "unknown") + for table in self.table_names: + schema_text = 
self._introspect_db_schema(hook, table) + if schema_text: + sections.append(f"Source: {dialect_name}\nTable: {table}\n{schema_text}") Review Comment: The `conn_id` was dropped from the Source label — previously `Source: {conn_id} ({dialect_name})`, now just `Source: {dialect_name}`. If a user compares two PostgreSQL databases (e.g., `postgres_source` vs `postgres_replica`), the LLM sees: ``` Source: postgresql Table: orders Columns: ... Source: postgresql Table: orders Columns: ... ``` It can't tell which is which. Meanwhile `_introspect_datasource_schema` (line 212) still includes the `conn_id` in its label. Should be consistent — include `conn_id` in both paths: ```python sections.append(f"Source: {conn_id} ({dialect_name})\nTable: {table}\n{schema_text}") ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
