gopidesupavan commented on code in PR #62793:
URL: https://github.com/apache/airflow/pull/62793#discussion_r2880702329


##########
providers/common/ai/src/airflow/providers/common/ai/operators/llm_schema_compare.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Operator for cross-system schema drift detection powered by LLM 
reasoning."""
+
+from __future__ import annotations
+
import json
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal

from pydantic import BaseModel, Field
+
+try:
+    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+
+from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.common.compat.sdk import BaseHook
+
+if TYPE_CHECKING:
+    from airflow.providers.common.sql.config import DataSourceConfig
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+    from airflow.sdk import Context
+
+
class SchemaMismatch(BaseModel):
    """A single schema mismatch between data sources.

    Instances are produced by the LLM as part of the structured
    :class:`SchemaCompareResult` output; the ``Field`` descriptions are surfaced
    in the JSON schema the model is asked to follow.
    """

    source: str = Field(description="Source table")
    target: str = Field(description="Target table")
    column: str = Field(description="Column name where the mismatch was detected")
    source_type: str = Field(description="Data type in the source system")
    target_type: str = Field(description="Data type in the target system")
    # Literal narrows the generated JSON schema to an enum, so the model cannot
    # emit a free-form severity; previously any string validated despite the
    # description promising exactly these three values.
    severity: Literal["critical", "warning", "info"] = Field(
        description="One of: critical, warning, info"
    )
    description: str = Field(description="Human-readable description of the mismatch")
    suggested_action: str = Field(description="Recommended action to resolve the mismatch")
    migration_query: str = Field(description="Provide migration query to resolve the mismatch")
+
+
class SchemaCompareResult(BaseModel):
    """Structured output from schema comparison.

    This is the ``output_type`` the operator forces on the LLM call: an overall
    compatibility verdict plus the individual mismatches found.
    """

    compatible: bool = Field(description="Whether the schemas are compatible for data loading")
    # Description added for consistency — every field here feeds the JSON schema
    # shown to the LLM, and this was the only one left undocumented.
    mismatches: list[SchemaMismatch] = Field(
        default_factory=list,
        description="All detected schema mismatches; empty when schemas are compatible",
    )
    summary: str = Field(description="High-level summary of the comparison")
+
+
+class LLMSchemaCompareOperator(LLMOperator):
+    """
+    Compare schemas across different database systems and detect drift using 
LLM reasoning.
+
+    The LLM handles complex cross-system type mapping that simple equality 
checks
+    miss (e.g., ``varchar(255)`` vs ``string``, ``timestamp`` vs 
``timestamptz``).
+
+    Accepts data sources via two patterns:
+
+    1. **data_sources** — a list of
+       :class:`~airflow.providers.common.sql.config.DataSourceConfig` for each
+       system. If the connection resolves to a
+       :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook`, schema is
+       introspected via SQLAlchemy; otherwise DataFusion is used.
+    2. **db_conn_ids + table_names** — shorthand for comparing the same table
+       across multiple database connections (all must resolve to 
``DbApiHook``).
+
+    :param prompt: Instructions for the LLM on what to compare and flag.
+    :param llm_conn_id: Connection ID for the LLM provider.
+    :param model_id: Model identifier (e.g. ``"openai:gpt-5"``).
+    :param system_prompt: Additional instructions appended to the built-in
+        schema comparison prompt.
+    :param agent_params: Extra keyword arguments for the pydantic-ai ``Agent``.
+    :param data_sources: List of DataSourceConfig objects, one per system.
+    :param db_conn_ids: Connection IDs for databases to compare (used with
+        ``table_names``).
+    :param table_names: Tables to introspect from each ``db_conn_id``.
+    :param context_strategy: ``"basic"`` for column names and types only;
+        ``"full"`` to include primary keys, foreign keys, and indexes.
+        Default ``"full"``.
+    :param reasoning_mode: Strongly recommended — cross-system type mapping
+    benefits from step-by-step analysis.
+    """
+
    # Extend the base LLM operator's templated fields with this operator's own
    # inputs so connection IDs and table names can be rendered at runtime.
    template_fields: Sequence[str] = (
        *LLMOperator.template_fields,
        "db_conn_ids",
        "table_names",
    )
+
+    def __init__(
+        self,
+        *,
+        data_sources: list[DataSourceConfig] | None = None,
+        db_conn_ids: list[str] | None = None,
+        table_names: list[str] | None = None,
+        context_strategy: str | None = "full",
+        reasoning_mode: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        kwargs.pop("output_type", None)
+        super().__init__(**kwargs)
+        self.data_sources = data_sources or []
+        self.db_conn_ids = db_conn_ids or []
+        self.table_names = table_names or []
+        self.context_strategy = context_strategy
+        self.reasoning_mode = reasoning_mode
+
+        if not self.data_sources and not self.db_conn_ids:
+            raise ValueError("Provide at least one of 'data_sources' or 
'db_conn_ids'.")
+
+        if self.db_conn_ids and not self.table_names:
+            raise ValueError("'table_names' is required when using 
'db_conn_ids'.")
+
+    @staticmethod
+    def _get_db_hook(conn_id: str) -> DbApiHook:
+        """Resolve a connection ID to a DbApiHook."""
+        from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+        connection = BaseHook.get_connection(conn_id)
+        hook = connection.get_hook()
+        if not isinstance(hook, DbApiHook):
+            raise ValueError(
+                f"Connection {conn_id!r} does not provide a DbApiHook. Got 
{type(hook).__name__}."
+            )
+        return hook
+
+    @staticmethod
+    def _is_dbapi_connection(conn_id: str) -> bool:
+        """Check whether a connection resolves to a DbApiHook."""
+        from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+        try:
+            connection = BaseHook.get_connection(conn_id)
+            hook = connection.get_hook()
+            return isinstance(hook, DbApiHook)
+        except Exception:

Review Comment:
   updated.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to