This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new da4ce85ace8 Add `LLMSQLQueryOperator` and `@task.llm_sql` to common.ai 
provider (#62599)
da4ce85ace8 is described below

commit da4ce85ace8e9ef64da84c6b9c8f2e4cc9865bd4
Author: Kaxil Naik <[email protected]>
AuthorDate: Sat Feb 28 15:45:55 2026 +0000

    Add `LLMSQLQueryOperator` and `@task.llm_sql` to common.ai provider (#62599)
    
    SQL query generation from natural language, inheriting from `LLMOperator`
    for shared LLM connection handling, `agent_params`, and `system_prompt`.
    
    - Schema introspection via `db_conn_id` + `table_names` using DbApiHook
    - Defense-in-depth SQL safety: AST validation via sqlglot (allowlist +
      single-statement enforcement + system prompt instructions)
    - User-provided `system_prompt` appended to built-in SQL safety prompt
      for domain-specific guidance (e.g. "prefer CTEs over subqueries")
    - `agent_params` inherited from LLMOperator (retries, temperature, etc.)
    - Generate-only mode: returns SQL string, does not execute
    
    Co-authored-by: GPK <[email protected]>
---
 dev/breeze/tests/test_selective_checks.py          |   2 +-
 docs/spelling_wordlist.txt                         |   3 +
 providers/common/ai/docs/index.rst                 |   1 +
 providers/common/ai/docs/operators/llm_sql.rst     |  87 +++++++
 providers/common/ai/provider.yaml                  |   4 +
 providers/common/ai/pyproject.toml                 |   9 +
 .../providers/common/ai/decorators/llm_sql.py      | 126 +++++++++++
 .../common/ai/example_dags/example_llm_sql.py      | 102 +++++++++
 .../providers/common/ai/get_provider_info.py       |  16 +-
 .../providers/common/ai/operators/llm_sql.py       | 217 ++++++++++++++++++
 .../airflow/providers/common/ai/utils/__init__.py  |  16 ++
 .../providers/common/ai/utils/sql_validation.py    | 100 ++++++++
 .../unit/common/ai/decorators/test_llm_sql.py      |  84 +++++++
 .../tests/unit/common/ai/operators/test_llm_sql.py | 252 +++++++++++++++++++++
 .../ai/tests/unit/common/ai/utils/__init__.py      |  16 ++
 .../unit/common/ai/utils/test_sql_validation.py    | 161 +++++++++++++
 .../src/airflow/providers/common/sql/hooks/sql.py  |  11 +
 .../sql/tests/unit/common/sql/hooks/test_sql.py    |  37 ++-
 18 files changed, 1239 insertions(+), 5 deletions(-)

diff --git a/dev/breeze/tests/test_selective_checks.py 
b/dev/breeze/tests/test_selective_checks.py
index bc5bb943a79..18e972c6ba1 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -2249,7 +2249,7 @@ def test_upgrade_to_newer_dependencies(
             
("providers/common/sql/src/airflow/providers/common/sql/common_sql_python.py",),
             {
                 "docs-list-as-string": "amazon apache.drill apache.druid 
apache.hive "
-                "apache.impala apache.pinot common.compat common.sql 
databricks elasticsearch "
+                "apache.impala apache.pinot common.ai common.compat common.sql 
databricks elasticsearch "
                 "exasol google jdbc microsoft.mssql mysql odbc openlineage "
                 "oracle pgvector postgres presto slack snowflake sqlite 
teradata trino vertica ydb",
             },
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 015f8d77d7b..e7bef2859af 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -502,6 +502,7 @@ del
 delim
 delitem
 deltalake
+denylist
 dep
 DependencyMixin
 deployable
@@ -1726,6 +1727,7 @@ sql
 sqla
 Sqlalchemy
 sqlalchemy
+sqlglot
 Sqlite
 sqlite
 sqlproxy
@@ -1803,6 +1805,7 @@ Subpath
 subpath
 subprocess
 subprocesses
+subqueries
 subquery
 SubscriberClient
 subscriptionId
diff --git a/providers/common/ai/docs/index.rst 
b/providers/common/ai/docs/index.rst
index dbd01f4b3d3..fffcfc494ba 100644
--- a/providers/common/ai/docs/index.rst
+++ b/providers/common/ai/docs/index.rst
@@ -121,6 +121,7 @@ You can install such cross-provider dependencies when 
installing from PyPI. For
 Dependent package                                                              
                                     Extra
 
==================================================================================================================
  =================
 `apache-airflow-providers-common-compat 
<https://airflow.apache.org/docs/apache-airflow-providers-common-compat>`_  
``common.compat``
+`apache-airflow-providers-common-sql 
<https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_        
``common.sql``
 
==================================================================================================================
  =================
 
 Downloading official packages
diff --git a/providers/common/ai/docs/operators/llm_sql.rst 
b/providers/common/ai/docs/operators/llm_sql.rst
new file mode 100644
index 00000000000..d2ccadb12bf
--- /dev/null
+++ b/providers/common/ai/docs/operators/llm_sql.rst
@@ -0,0 +1,87 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. _howto/operator:llm_sql_query:
+
+``LLMSQLQueryOperator``
+========================
+
+Use 
:class:`~airflow.providers.common.ai.operators.llm_sql.LLMSQLQueryOperator` to 
generate
+SQL queries from natural language using an LLM.
+
+The operator generates SQL but does not execute it. The generated query is 
returned
+as XCom and can be passed to ``SQLExecuteQueryOperator`` or used in downstream 
tasks.
+
+.. seealso::
+    :ref:`Connection configuration <howto/connection:pydantic_ai>`
+
+Basic Usage
+-----------
+
+Provide a natural language ``prompt`` and the operator generates a SQL query:
+
+.. exampleinclude:: 
/../../ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
+    :language: python
+    :start-after: [START howto_operator_llm_sql_basic]
+    :end-before: [END howto_operator_llm_sql_basic]
+
+With Schema Introspection
+-------------------------
+
+Use ``db_conn_id`` and ``table_names`` to automatically include database schema
+in the LLM's context. This produces more accurate queries because the LLM knows
+the actual column names and types:
+
+.. exampleinclude:: 
/../../ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
+    :language: python
+    :start-after: [START howto_operator_llm_sql_schema]
+    :end-before: [END howto_operator_llm_sql_schema]
+
+TaskFlow Decorator
+------------------
+
+The ``@task.llm_sql`` decorator lets you write a function that returns the
+prompt. The decorator handles LLM connection, schema introspection, SQL 
generation,
+and safety validation:
+
+.. exampleinclude:: 
/../../ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
+    :language: python
+    :start-after: [START howto_decorator_llm_sql]
+    :end-before: [END howto_decorator_llm_sql]
+
+Dynamic Task Mapping
+--------------------
+
+Generate SQL for multiple prompts in parallel using ``expand()``:
+
+.. exampleinclude:: 
/../../ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
+    :language: python
+    :start-after: [START howto_operator_llm_sql_expand]
+    :end-before: [END howto_operator_llm_sql_expand]
+
+SQL Safety Validation
+---------------------
+
+By default, the operator validates generated SQL using an allowlist approach:
+
+- Only ``SELECT``, ``UNION``, ``INTERSECT``, and ``EXCEPT`` statements are 
allowed.
+- Multi-statement SQL (semicolon-separated) is rejected.
+- Disallowed statements (``INSERT``, ``UPDATE``, ``DELETE``, ``DROP``, etc.) 
raise
+  :class:`~airflow.providers.common.ai.utils.sql_validation.SQLSafetyError`.
+
+You can disable validation with ``validate_sql=False`` or customize the allowed
+statement types with ``allowed_sql_types``.
diff --git a/providers/common/ai/provider.yaml 
b/providers/common/ai/provider.yaml
index e2c4ca8fe60..7ef0acd7c3c 100644
--- a/providers/common/ai/provider.yaml
+++ b/providers/common/ai/provider.yaml
@@ -33,6 +33,7 @@ integrations:
     external-doc-url: 
https://airflow.apache.org/docs/apache-airflow-providers-common-ai/
     how-to-guide:
       - /docs/apache-airflow-providers-common-ai/operators/llm.rst
+      - /docs/apache-airflow-providers-common-ai/operators/llm_sql.rst
     tags: [ai]
   - integration-name: Pydantic AI
     external-doc-url: https://ai.pydantic.dev/
@@ -61,7 +62,10 @@ operators:
   - integration-name: Common AI
     python-modules:
       - airflow.providers.common.ai.operators.llm
+      - airflow.providers.common.ai.operators.llm_sql
 
 task-decorators:
   - class-name: airflow.providers.common.ai.decorators.llm.llm_task
     name: llm
+  - class-name: airflow.providers.common.ai.decorators.llm_sql.llm_sql_task
+    name: llm_sql
diff --git a/providers/common/ai/pyproject.toml 
b/providers/common/ai/pyproject.toml
index ae88717446a..c80c500a50a 100644
--- a/providers/common/ai/pyproject.toml
+++ b/providers/common/ai/pyproject.toml
@@ -72,6 +72,13 @@ dependencies = [
 "common.compat" = [
     "apache-airflow-providers-common-compat"
 ]
+"sql" = [
+    "apache-airflow-providers-common-sql",
+    "sqlglot>=26.0.0",
+]
+"common.sql" = [
+    "apache-airflow-providers-common-sql"
+]
 
 [dependency-groups]
 dev = [
@@ -79,7 +86,9 @@ dev = [
     "apache-airflow-task-sdk",
     "apache-airflow-devel-common",
     "apache-airflow-providers-common-compat",
+    "apache-airflow-providers-common-sql",
     # Additional devel dependencies (do not remove this line and add extra 
development dependencies)
+    "sqlglot>=26.0.0",
 ]
 
 # To build docs:
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_sql.py 
b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_sql.py
new file mode 100644
index 00000000000..25fa57b5d14
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/decorators/llm_sql.py
@@ -0,0 +1,126 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+TaskFlow decorator for LLM SQL generation.
+
+The user writes a function that **returns the prompt**. The decorator handles
+the LLM call, schema introspection, and safety validation. The decorated task's
+XCom output is the generated SQL string.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Collection, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar
+
+from airflow.providers.common.ai.operators.llm_sql import LLMSQLQueryOperator
+from airflow.providers.common.compat.sdk import (
+    DecoratedOperator,
+    TaskDecorator,
+    context_merge,
+    task_decorator_factory,
+)
+from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
+from airflow.utils.operator_helpers import determine_kwargs
+
+if TYPE_CHECKING:
+    from airflow.sdk import Context
+
+
+class _LLMSQLDecoratedOperator(DecoratedOperator, LLMSQLQueryOperator):
+    """
+    Wraps a callable that returns a prompt for LLM SQL generation.
+
+    The user function is called at execution time to produce the prompt string.
+    All other parameters (``llm_conn_id``, ``db_conn_id``, ``table_names``, 
etc.)
+    are passed through to 
:class:`~airflow.providers.common.ai.operators.llm_sql.LLMSQLQueryOperator`.
+
+    :param python_callable: A reference to a callable that returns the prompt 
string.
+    :param op_args: Positional arguments for the callable.
+    :param op_kwargs: Keyword arguments for the callable.
+    """
+
+    template_fields: Sequence[str] = (
+        *DecoratedOperator.template_fields,
+        *LLMSQLQueryOperator.template_fields,
+    )
+    template_fields_renderers: ClassVar[dict[str, str]] = {
+        **DecoratedOperator.template_fields_renderers,
+    }
+
+    custom_operator_name: str = "@task.llm_sql"
+
+    def __init__(
+        self,
+        *,
+        python_callable: Callable,
+        op_args: Collection[Any] | None = None,
+        op_kwargs: Mapping[str, Any] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            python_callable=python_callable,
+            op_args=op_args,
+            op_kwargs=op_kwargs,
+            prompt=SET_DURING_EXECUTION,
+            **kwargs,
+        )
+
+    def execute(self, context: Context) -> Any:
+        context_merge(context, self.op_kwargs)
+        kwargs = determine_kwargs(self.python_callable, self.op_args, context)
+
+        self.prompt = self.python_callable(*self.op_args, **kwargs)
+
+        if not isinstance(self.prompt, str) or not self.prompt.strip():
+            raise TypeError("The returned value from the @task.llm_sql 
callable must be a non-empty string.")
+
+        self.render_template_fields(context)
+        # Call LLMSQLQueryOperator.execute directly, not super().execute(),
+        # because we need to skip DecoratedOperator.execute — the callable
+        # invocation is already handled above.
+        return LLMSQLQueryOperator.execute(self, context)
+
+
+def llm_sql_task(
+    python_callable: Callable | None = None,
+    **kwargs,
+) -> TaskDecorator:
+    """
+    Wrap a function that returns a natural language prompt into an LLM SQL 
task.
+
+    The function body constructs the prompt (can use Airflow context, XCom, 
etc.).
+    The decorator handles: LLM connection, schema introspection, SQL 
generation,
+    and safety validation.
+
+    Usage::
+
+        @task.llm_sql(
+            llm_conn_id="openai_default",
+            db_conn_id="postgres_default",
+            table_names=["customers", "orders"],
+        )
+        def build_query(ds=None):
+            return f"Find top 10 customers by revenue in {ds}"
+
+    :param python_callable: Function to decorate.
+    """
+    return task_decorator_factory(
+        python_callable=python_callable,
+        decorated_operator_class=_LLMSQLDecoratedOperator,
+        **kwargs,
+    )
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
 
b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
new file mode 100644
index 00000000000..2a7e52f5b6a
--- /dev/null
+++ 
b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Example DAGs demonstrating LLMSQLQueryOperator usage."""
+
+from __future__ import annotations
+
+from airflow.providers.common.ai.operators.llm_sql import LLMSQLQueryOperator
+from airflow.providers.common.compat.sdk import dag, task
+
+
+# [START howto_operator_llm_sql_basic]
+@dag
+def example_llm_sql_basic():
+    LLMSQLQueryOperator(
+        task_id="generate_sql",
+        prompt="Find the top 10 customers by total revenue",
+        llm_conn_id="pydantic_ai_default",
+        schema_context=(
+            "Table: customers\n"
+            "Columns: id INT, name TEXT, email TEXT\n\n"
+            "Table: orders\n"
+            "Columns: id INT, customer_id INT, total DECIMAL, created_at 
TIMESTAMP"
+        ),
+    )
+
+
+# [END howto_operator_llm_sql_basic]
+
+example_llm_sql_basic()
+
+
+# [START howto_operator_llm_sql_schema]
+@dag
+def example_llm_sql_schema_introspection():
+    LLMSQLQueryOperator(
+        task_id="generate_sql",
+        prompt="Calculate monthly revenue for 2024",
+        llm_conn_id="pydantic_ai_default",
+        db_conn_id="postgres_default",
+        table_names=["orders", "customers"],
+        dialect="postgres",
+    )
+
+
+# [END howto_operator_llm_sql_schema]
+
+example_llm_sql_schema_introspection()
+
+
+# [START howto_decorator_llm_sql]
+@dag
+def example_llm_sql_decorator():
+    @task.llm_sql(
+        llm_conn_id="pydantic_ai_default",
+        schema_context="Table: users\nColumns: id INT, name TEXT, signup_date 
DATE",
+    )
+    def build_churn_query(ds=None):
+        return f"Find users who signed up before {ds} and have no orders"
+
+    build_churn_query()
+
+
+# [END howto_decorator_llm_sql]
+
+example_llm_sql_decorator()
+
+
+# [START howto_operator_llm_sql_expand]
+@dag
+def example_llm_sql_expand():
+    LLMSQLQueryOperator.partial(
+        task_id="generate_sql",
+        llm_conn_id="pydantic_ai_default",
+        schema_context=(
+            "Table: orders\nColumns: id INT, customer_id INT, total DECIMAL, 
created_at TIMESTAMP"
+        ),
+    ).expand(
+        prompt=[
+            "Total revenue by month",
+            "Top 10 customers by order count",
+            "Average order value by day of week",
+        ]
+    )
+
+
+# [END howto_operator_llm_sql_expand]
+
+example_llm_sql_expand()
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py 
b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
index f7ae5317401..091f4b849ac 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
@@ -30,7 +30,10 @@ def get_provider_info():
             {
                 "integration-name": "Common AI",
-                "external-doc-url": 
"https://airflow.apache.org/docs/apache-airflow-providers-common-ai/",
-                "how-to-guide": 
["/docs/apache-airflow-providers-common-ai/operators/llm.rst"],
+                "how-to-guide": [
+                    
"/docs/apache-airflow-providers-common-ai/operators/llm.rst",
+                    
"/docs/apache-airflow-providers-common-ai/operators/llm_sql.rst",
+                ],
                 "tags": ["ai"],
             },
             {
@@ -60,9 +63,16 @@ def get_provider_info():
             }
         ],
         "operators": [
-            {"integration-name": "Common AI", "python-modules": 
["airflow.providers.common.ai.operators.llm"]}
+            {
+                "integration-name": "Common AI",
+                "python-modules": [
+                    "airflow.providers.common.ai.operators.llm",
+                    "airflow.providers.common.ai.operators.llm_sql",
+                ],
+            }
         ],
         "task-decorators": [
-            {"class-name": 
"airflow.providers.common.ai.decorators.llm.llm_task", "name": "llm"}
+            {"class-name": 
"airflow.providers.common.ai.decorators.llm.llm_task", "name": "llm"},
+            {"class-name": 
"airflow.providers.common.ai.decorators.llm_sql.llm_sql_task", "name": 
"llm_sql"},
         ],
     }
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py 
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
new file mode 100644
index 00000000000..4501b4c1c63
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
@@ -0,0 +1,217 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Operator for generating SQL queries from natural language using LLMs."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+try:
+    from airflow.providers.common.ai.utils.sql_validation import (
+        DEFAULT_ALLOWED_TYPES,
+        validate_sql as _validate_sql,
+    )
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.common.compat.sdk import BaseHook
+
+if TYPE_CHECKING:
+    from sqlglot import exp
+
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+    from airflow.sdk import Context
+
+# SQLAlchemy dialect_name → sqlglot dialect mapping for names that differ.
+_SQLALCHEMY_TO_SQLGLOT_DIALECT: dict[str, str] = {
+    "postgresql": "postgres",
+    "mssql": "tsql",
+}
+
+
+class LLMSQLQueryOperator(LLMOperator):
+    """
+    Generate SQL queries from natural language using an LLM.
+
+    Inherits from 
:class:`~airflow.providers.common.ai.operators.llm.LLMOperator`
+    for LLM access and optionally uses a
+    :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook`
+    for schema introspection. The operator generates SQL but does not execute 
it —
+    the generated SQL is returned as XCom and can be passed to
+    ``SQLExecuteQueryOperator`` or used in downstream tasks.
+
+    When ``system_prompt`` is provided, it is appended to the built-in SQL 
safety
+    instructions — use it for domain-specific guidance (e.g. "prefer CTEs over
+    subqueries", "always use LEFT JOINs").
+
+    :param prompt: Natural language description of the desired query.
+    :param llm_conn_id: Connection ID for the LLM provider.
+    :param model_id: Model identifier (e.g. ``"openai:gpt-4o"``).
+        Overrides the model stored in the connection's extra field.
+    :param system_prompt: Additional instructions appended to the built-in SQL
+        safety prompt. Use for domain-specific guidance.
+    :param agent_params: Additional keyword arguments passed to the pydantic-ai
+        ``Agent`` constructor (e.g. ``retries``, ``model_settings``).
+    :param db_conn_id: Connection ID for database schema introspection.
+        The connection must resolve to a ``DbApiHook``.
+    :param table_names: Tables to include in the LLM's schema context.
+        Used with ``db_conn_id`` for automatic introspection.
+    :param schema_context: Manual schema context string. When provided,
+        this is used instead of ``db_conn_id`` introspection.
+    :param validate_sql: Whether to validate generated SQL via AST parsing.
+        Default ``True`` (safe by default).
+    :param allowed_sql_types: SQL statement types to allow.
+        Default: ``(Select, Union, Intersect, Except)``.
+    :param dialect: SQL dialect for parsing (``postgres``, ``mysql``, etc.).
+        Auto-detected from the database hook if not set.
+    """
+
+    template_fields: Sequence[str] = (
+        *LLMOperator.template_fields,
+        "db_conn_id",
+        "table_names",
+        "schema_context",
+    )
+
+    def __init__(
+        self,
+        *,
+        db_conn_id: str | None = None,
+        table_names: list[str] | None = None,
+        schema_context: str | None = None,
+        validate_sql: bool = True,
+        allowed_sql_types: tuple[type[exp.Expression], ...] = 
DEFAULT_ALLOWED_TYPES,
+        dialect: str | None = None,
+        **kwargs: Any,
+    ) -> None:
+        kwargs.pop("output_type", None)  # SQL operator always returns str
+        super().__init__(**kwargs)
+        self.db_conn_id = db_conn_id
+        self.table_names = table_names
+        self.schema_context = schema_context
+        self.validate_sql = validate_sql
+        self.allowed_sql_types = allowed_sql_types
+        self.dialect = dialect
+
+    @cached_property
+    def db_hook(self) -> DbApiHook | None:
+        """Return DbApiHook for the configured database connection, or None."""
+        if not self.db_conn_id:
+            return None
+        from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+        connection = BaseHook.get_connection(self.db_conn_id)
+        hook = connection.get_hook()
+        if not isinstance(hook, DbApiHook):
+            raise ValueError(
+                f"Connection {self.db_conn_id!r} does not provide a DbApiHook. 
Got {type(hook).__name__}."
+            )
+        return hook
+
+    def execute(self, context: Context) -> str:
+        schema_info = self._get_schema_context()
+        full_system_prompt = self._build_system_prompt(schema_info)
+
+        agent = self.llm_hook.create_agent(
+            output_type=str, instructions=full_system_prompt, 
**self.agent_params
+        )
+        result = agent.run_sync(self.prompt)
+        sql = self._strip_llm_output(result.output)
+
+        if self.validate_sql:
+            _validate_sql(sql, allowed_types=self.allowed_sql_types, 
dialect=self._resolved_dialect)
+
+        self.log.info("Generated SQL:\n%s", sql)
+        return sql
+
+    @staticmethod
+    def _strip_llm_output(raw: str) -> str:
+        """Strip whitespace and markdown code fences from LLM output."""
+        text = raw.strip()
+        if text.startswith("```"):
+            lines = text.split("\n")
+            # Remove opening fence (```sql, ```, etc.) and closing fence
+            if len(lines) >= 2:
+                end = -1 if lines[-1].strip().startswith("```") else len(lines)
+                text = "\n".join(lines[1:end]).strip()
+        return text
+
+    def _get_schema_context(self) -> str:
+        """Return schema context from manual override or database 
introspection."""
+        if self.schema_context:
+            return self.schema_context
+        if self.db_hook and self.table_names:
+            return self._introspect_schemas()
+        return ""
+
+    def _introspect_schemas(self) -> str:
+        """Build schema context by introspecting tables via the database 
hook."""
+        parts: list[str] = []
+        for table in self.table_names or []:
+            columns = self.db_hook.get_table_schema(table)  # type: 
ignore[union-attr]
+            if not columns:
+                self.log.warning("Table %r returned no columns — it may not 
exist.", table)
+                continue
+            col_info = ", ".join(f"{c['name']} {c['type']}" for c in columns)
+            parts.append(f"Table: {table}\nColumns: {col_info}")
+        if not parts and self.table_names:
+            raise ValueError(
+                f"None of the requested tables ({self.table_names}) returned 
schema information. "
+                "Check that the table names are correct and the database 
connection has access."
+            )
+        return "\n\n".join(parts)
+
+    def _build_system_prompt(self, schema_info: str) -> str:
+        """Construct the system prompt for the LLM."""
+        dialect_label = self._resolved_dialect or "SQL"
+        prompt = (
+            f"You are a {dialect_label} expert. "
+            "Generate a single SQL query based on the user's request.\n"
+            "Return ONLY the SQL query, no explanation or markdown.\n"
+        )
+        if schema_info:
+            prompt += f"\nAvailable schema:\n{schema_info}\n"
+        prompt += (
+            "\nRules:\n"
+            "- Generate only SELECT queries (including CTEs, JOINs, 
subqueries, UNION)\n"
+            "- Never generate data modification statements "
+            "(INSERT, UPDATE, DELETE, DROP, etc.)\n"
+            "- Use proper syntax for the specified dialect\n"
+        )
+        if self.system_prompt:
+            prompt += f"\nAdditional instructions:\n{self.system_prompt}\n"
+        return prompt
+
+    @cached_property
+    def _resolved_dialect(self) -> str | None:
+        """
+        Resolve the SQL dialect from explicit parameter or database hook.
+
+        Normalizes SQLAlchemy dialect names to sqlglot equivalents
+        (e.g. ``postgresql`` → ``postgres``).
+        """
+        raw = self.dialect
+        if not raw and self.db_hook and hasattr(self.db_hook, "dialect_name"):
+            raw = self.db_hook.dialect_name
+        if raw:
+            return _SQLALCHEMY_TO_SQLGLOT_DIALECT.get(raw, raw)
+        return None
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/utils/__init__.py 
b/providers/common/ai/src/airflow/providers/common/ai/utils/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/utils/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py 
b/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py
new file mode 100644
index 00000000000..22361949f3b
--- /dev/null
+++ 
b/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py
@@ -0,0 +1,100 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+SQL safety validation for LLM-generated queries.
+
+Uses an allowlist approach: only explicitly permitted statement types pass.
+This is safer than a denylist because new/unexpected statement types
+(INSERT, UPDATE, MERGE, TRUNCATE, COPY, etc.) are blocked by default.
+"""
+
+from __future__ import annotations
+
+import sqlglot
+from sqlglot import exp
+from sqlglot.errors import ErrorLevel
+
# Allowlist: only these top-level statement types pass validation by default.
# - Select: plain queries and CTE-wrapped queries (WITH ... AS ... SELECT is
#   parsed as Select with a `with` clause property — still a Select node at
#   the top level)
# - Union/Intersect/Except: set operations on SELECT results
DEFAULT_ALLOWED_TYPES: tuple[type[exp.Expression], ...] = (
    exp.Select,
    exp.Union,
    exp.Intersect,
    exp.Except,
)


class SQLSafetyError(Exception):
    """Generated SQL failed safety validation."""


def validate_sql(
    sql: str,
    *,
    allowed_types: tuple[type[exp.Expression], ...] | None = None,
    dialect: str | None = None,
    allow_multiple_statements: bool = False,
) -> list[exp.Expression]:
    """
    Parse SQL and verify all statements are in the allowed types list.

    By default, only a single SELECT-family statement is allowed. Multi-statement
    SQL (separated by semicolons) is rejected unless ``allow_multiple_statements=True``,
    because multi-statement inputs can hide dangerous operations after a benign SELECT.

    Returns parsed statements on success, raises :class:`SQLSafetyError` on violation.

    :param sql: SQL string to validate.
    :param allowed_types: Tuple of sqlglot expression types to permit.
        Defaults to ``(Select, Union, Intersect, Except)``. An explicitly
        passed empty tuple permits nothing (every statement is rejected).
    :param dialect: SQL dialect for parsing (``postgres``, ``mysql``, etc.).
    :param allow_multiple_statements: Whether to allow multiple semicolon-separated
        statements. Default ``False``.
    :return: List of parsed sqlglot Expression objects.
    :raises SQLSafetyError: If the SQL is empty, contains disallowed statement types,
        or has multiple statements when not permitted.
    """
    if not sql or not sql.strip():
        raise SQLSafetyError("Empty SQL input.")

    # Bug fix: use an `is None` check rather than `allowed_types or DEFAULT...`.
    # With `or`, an explicitly passed empty tuple (caller intent: "allow
    # nothing") silently fell back to the permissive defaults.
    types = DEFAULT_ALLOWED_TYPES if allowed_types is None else allowed_types

    try:
        # ErrorLevel.RAISE makes malformed SQL fail loudly instead of being
        # partially parsed and possibly slipping through the allowlist.
        statements = sqlglot.parse(sql, dialect=dialect, error_level=ErrorLevel.RAISE)
    except sqlglot.errors.ParseError as e:
        raise SQLSafetyError(f"SQL parse error: {e}") from e

    # sqlglot.parse can return [None] for empty input
    parsed: list[exp.Expression] = [s for s in statements if s is not None]
    if not parsed:
        raise SQLSafetyError("Empty SQL input.")

    if not allow_multiple_statements and len(parsed) > 1:
        raise SQLSafetyError(
            f"Multiple statements detected ({len(parsed)}). Only single statements are allowed by default."
        )

    # Allowlist check: anything not explicitly permitted is rejected,
    # including statement types introduced by future sqlglot versions.
    for stmt in parsed:
        if not isinstance(stmt, types):
            allowed_names = ", ".join(t.__name__ for t in types)
            raise SQLSafetyError(
                f"Statement type '{type(stmt).__name__}' is not allowed. Allowed types: {allowed_names}"
            )

    return parsed
diff --git 
a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py 
b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py
new file mode 100644
index 00000000000..849e9c1e1fd
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.decorators.llm_sql import 
_LLMSQLDecoratedOperator
+
+
class TestLLMSQLDecoratedOperator:
    """Unit tests for the ``@task.llm_sql`` decorated-operator wrapper."""

    def test_custom_operator_name(self):
        # The decorator advertises itself under the @task.llm_sql name.
        assert _LLMSQLDecoratedOperator.custom_operator_name == "@task.llm_sql"

    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
    def test_execute_calls_callable_and_uses_result_as_prompt(self, mock_hook_cls):
        """The user's callable return value becomes the LLM prompt."""
        # Stub out the agent so no real LLM call happens; spec= keeps the
        # mocks honest about which attributes the operator may touch.
        mock_agent = MagicMock(spec=["run_sync"])
        mock_result = MagicMock(spec=["output"])
        mock_result.output = "SELECT 1"
        mock_agent.run_sync.return_value = mock_result
        mock_hook_cls.return_value.create_agent.return_value = mock_agent

        def my_prompt_fn():
            return "Get all users"

        op = _LLMSQLDecoratedOperator(task_id="test", python_callable=my_prompt_fn, llm_conn_id="my_llm")
        result = op.execute(context={})

        assert result == "SELECT 1"
        assert op.prompt == "Get all users"
        mock_agent.run_sync.assert_called_once_with("Get all users")

    @pytest.mark.parametrize(
        "return_value",
        [42, "", "   ", None],
        ids=["non-string", "empty", "whitespace", "none"],
    )
    def test_execute_raises_on_invalid_prompt(self, return_value):
        """TypeError when the callable returns a non-string or blank string."""
        op = _LLMSQLDecoratedOperator(
            task_id="test",
            python_callable=lambda: return_value,
            llm_conn_id="my_llm",
        )
        with pytest.raises(TypeError, match="non-empty string"):
            op.execute(context={})

    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
    def test_execute_merges_op_kwargs_into_callable(self, mock_hook_cls):
        """op_kwargs are resolved by the callable to build the prompt."""
        mock_agent = MagicMock(spec=["run_sync"])
        mock_result = MagicMock(spec=["output"])
        mock_result.output = "SELECT 1"
        mock_agent.run_sync.return_value = mock_result
        mock_hook_cls.return_value.create_agent.return_value = mock_agent

        def my_prompt_fn(table_name):
            return f"Get all rows from {table_name}"

        op = _LLMSQLDecoratedOperator(
            task_id="test",
            python_callable=my_prompt_fn,
            llm_conn_id="my_llm",
            op_kwargs={"table_name": "users"},
        )
        op.execute(context={"task_instance": MagicMock()})

        assert op.prompt == "Get all rows from users"
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py 
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
new file mode 100644
index 00000000000..97d943c14c0
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
@@ -0,0 +1,252 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, PropertyMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.operators.llm_sql import LLMSQLQueryOperator
+from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
+
+
+def _make_mock_agent(output: str):
+    """Create a mock agent that returns the given output string."""
+    mock_result = MagicMock(spec=["output"])
+    mock_result.output = output
+    mock_agent = MagicMock(spec=["run_sync"])
+    mock_agent.run_sync.return_value = mock_result
+    return mock_agent
+
+
class TestStripLLMOutput:
    """Markdown code-fence stripping applied to raw LLM output."""

    @pytest.mark.parametrize(
        ("raw", "expected"),
        (
            pytest.param("SELECT 1", "SELECT 1", id="plain_sql"),
            pytest.param("  SELECT 1  ", "SELECT 1", id="leading_trailing_whitespace"),
            pytest.param("```sql\nSELECT 1\n```", "SELECT 1", id="sql_code_fence"),
            pytest.param("```\nSELECT 1\n```", "SELECT 1", id="bare_code_fence"),
            pytest.param("```SQL\nSELECT 1\n```", "SELECT 1", id="uppercase_language_tag"),
            pytest.param(
                "```sql\nSELECT id\nFROM users\nWHERE active\n```",
                "SELECT id\nFROM users\nWHERE active",
                id="multiline_query",
            ),
            pytest.param(
                "```sql\nSELECT 1\n",
                "SELECT 1",
                id="missing_closing_fence",
            ),
        ),
    )
    def test_strip_llm_output(self, raw, expected):
        # The helper is invoked on the class directly; no operator instance needed.
        assert LLMSQLQueryOperator._strip_llm_output(raw) == expected
+
+
class TestLLMSQLQueryOperator:
    """Core operator behaviour: inheritance, templating, execution, validation."""

    def test_inherits_from_llm_operator(self):
        from airflow.providers.common.ai.operators.llm import LLMOperator

        assert issubclass(LLMSQLQueryOperator, LLMOperator)

    def test_template_fields_include_parent_and_sql_specific(self):
        # Parent LLMOperator fields plus the SQL-specific additions.
        expected = {
            "prompt",
            "llm_conn_id",
            "model_id",
            "system_prompt",
            "agent_params",
            "db_conn_id",
            "table_names",
            "schema_context",
        }
        assert set(LLMSQLQueryOperator.template_fields) == expected

    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
    def test_execute_with_schema_context(self, mock_hook_cls):
        """Operator uses schema_context and returns generated SQL."""
        mock_agent = _make_mock_agent("SELECT id, name FROM users WHERE active = true")
        mock_hook_cls.return_value.create_agent.return_value = mock_agent

        op = LLMSQLQueryOperator(
            task_id="test",
            prompt="Get active users",
            llm_conn_id="my_llm",
            schema_context="Table: users\nColumns: id INT, name TEXT, active BOOLEAN",
        )
        result = op.execute(context=MagicMock())

        assert result == "SELECT id, name FROM users WHERE active = true"
        mock_agent.run_sync.assert_called_once_with("Get active users")

    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
    def test_execute_validation_blocks_unsafe_sql(self, mock_hook_cls):
        """Validation catches unsafe SQL generated by the LLM."""
        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent("DROP TABLE users")

        op = LLMSQLQueryOperator(task_id="test", prompt="Delete everything", llm_conn_id="my_llm")

        with pytest.raises(SQLSafetyError, match="not allowed"):
            op.execute(context=MagicMock())

    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
    def test_execute_validation_disabled(self, mock_hook_cls):
        """When validate_sql=False, unsafe SQL is returned without checks."""
        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent("DROP TABLE users")

        op = LLMSQLQueryOperator(task_id="test", prompt="Drop it", llm_conn_id="my_llm", validate_sql=False)
        result = op.execute(context=MagicMock())

        assert result == "DROP TABLE users"

    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
    def test_execute_passes_agent_params(self, mock_hook_cls):
        """agent_params inherited from LLMOperator are unpacked into create_agent."""
        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent("SELECT 1")

        op = LLMSQLQueryOperator(
            task_id="test",
            prompt="test",
            llm_conn_id="my_llm",
            agent_params={"retries": 3, "model_settings": {"temperature": 0}},
        )
        op.execute(context=MagicMock())

        # Each agent_params entry must appear as a keyword arg to create_agent.
        create_agent_call = mock_hook_cls.return_value.create_agent.call_args
        assert create_agent_call[1]["retries"] == 3
        assert create_agent_call[1]["model_settings"] == {"temperature": 0}

    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
    def test_system_prompt_appended_to_sql_instructions(self, mock_hook_cls):
        """User-provided system_prompt is appended to built-in SQL safety prompt."""
        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent("SELECT 1")

        op = LLMSQLQueryOperator(
            task_id="test",
            prompt="test",
            llm_conn_id="my_llm",
            system_prompt="Always use LEFT JOINs.",
        )
        op.execute(context=MagicMock())

        instructions = mock_hook_cls.return_value.create_agent.call_args[1]["instructions"]
        assert "Always use LEFT JOINs." in instructions
        # Built-in SQL safety prompt should still be present
        assert "Generate only SELECT queries" in instructions
        assert "Never generate data modification" in instructions
+
+
class TestLLMSQLQueryOperatorSchemaIntrospection:
    """Schema discovery via db_conn_id/table_names and its schema_context override."""

    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", autospec=True)
    def test_introspect_schemas_via_db_hook(self, mock_hook_cls):
        """db_conn_id + table_names triggers schema introspection."""
        mock_agent = _make_mock_agent("SELECT id FROM users")
        mock_hook_cls.return_value.create_agent.return_value = mock_agent

        mock_db_hook = MagicMock(spec=["get_table_schema", "dialect_name"])
        mock_db_hook.get_table_schema.return_value = [
            {"name": "id", "type": "INTEGER"},
            {"name": "name", "type": "VARCHAR"},
        ]
        mock_db_hook.dialect_name = "postgresql"

        op = LLMSQLQueryOperator(
            task_id="test",
            prompt="Get user IDs",
            llm_conn_id="my_llm",
            db_conn_id="pg_default",
            table_names=["users"],
        )

        # db_hook is a property, so it must be patched on the type.
        with patch.object(type(op), "db_hook", new_callable=PropertyMock, return_value=mock_db_hook):
            result = op.execute(context=MagicMock())

        assert result == "SELECT id FROM users"
        mock_db_hook.get_table_schema.assert_called_once_with("users")

        # Verify the system prompt contains the schema info
        instructions = mock_hook_cls.return_value.create_agent.call_args[1]["instructions"]
        assert "users" in instructions
        assert "id INTEGER" in instructions

    def test_introspect_raises_when_no_tables_found(self):
        """Raise ValueError when all requested tables return empty columns."""
        mock_db_hook = MagicMock(spec=["get_table_schema", "dialect_name"])
        mock_db_hook.get_table_schema.return_value = []

        op = LLMSQLQueryOperator(
            task_id="test",
            prompt="test",
            llm_conn_id="my_llm",
            db_conn_id="pg_default",
            table_names=["nonexistent_table"],
        )

        with patch.object(type(op), "db_hook", new_callable=PropertyMock, return_value=mock_db_hook):
            with pytest.raises(ValueError, match="None of the requested tables"):
                op._introspect_schemas()

    def test_schema_context_overrides_introspection(self):
        """schema_context takes priority over db_conn_id introspection."""
        op = LLMSQLQueryOperator(
            task_id="test",
            prompt="test",
            llm_conn_id="my_llm",
            db_conn_id="pg_default",
            table_names=["users"],
            schema_context="My custom schema info",
        )
        assert op._get_schema_context() == "My custom schema info"
+
+
class TestLLMSQLQueryOperatorDialect:
    """Resolution of the SQL dialect used for parsing/validation."""

    def test_resolved_dialect_from_param(self):
        # Explicit dialect parameter is used as-is.
        op = LLMSQLQueryOperator(task_id="test", prompt="test", llm_conn_id="my_llm", dialect="mysql")
        assert op._resolved_dialect == "mysql"

    def test_resolved_dialect_from_db_hook_normalized(self):
        """SQLAlchemy's 'postgresql' is normalized to sqlglot's 'postgres'."""
        mock_db_hook = MagicMock(spec=["dialect_name"])
        mock_db_hook.dialect_name = "postgresql"

        op = LLMSQLQueryOperator(task_id="test", prompt="test", llm_conn_id="my_llm", db_conn_id="pg_default")
        with patch.object(type(op), "db_hook", new_callable=PropertyMock, return_value=mock_db_hook):
            assert op._resolved_dialect == "postgres"

    def test_resolved_dialect_none_when_nothing_set(self):
        # No dialect param and no db hook → nothing to resolve.
        op = LLMSQLQueryOperator(task_id="test", prompt="test", llm_conn_id="my_llm")
        assert op._resolved_dialect is None
+
+
class TestLLMSQLQueryOperatorDbHook:
    """Lazy construction and validation of the operator's database hook."""

    @patch("airflow.providers.common.ai.operators.llm_sql.BaseHook", autospec=True)
    def test_db_hook_returns_none_without_conn_id(self, mock_base_hook):
        # Without db_conn_id the property is None and no connection lookup happens.
        op = LLMSQLQueryOperator(task_id="test", prompt="test", llm_conn_id="my_llm")
        assert op.db_hook is None
        mock_base_hook.get_connection.assert_not_called()

    @patch("airflow.providers.common.ai.operators.llm_sql.BaseHook", autospec=True)
    def test_db_hook_raises_for_non_dbapi_hook(self, mock_base_hook):
        mock_conn = MagicMock(spec=["get_hook"])
        mock_conn.get_hook.return_value = MagicMock()  # Not a DbApiHook
        mock_base_hook.get_connection.return_value = mock_conn

        op = LLMSQLQueryOperator(task_id="test", prompt="test", llm_conn_id="my_llm", db_conn_id="bad_conn")

        with pytest.raises(ValueError, match="does not provide a DbApiHook"):
            _ = op.db_hook
diff --git a/providers/common/ai/tests/unit/common/ai/utils/__init__.py 
b/providers/common/ai/tests/unit/common/ai/utils/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/utils/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git 
a/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py 
b/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py
new file mode 100644
index 00000000000..6c6b218cadb
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py
@@ -0,0 +1,161 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+from sqlglot import exp
+
+from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError, 
validate_sql
+
+
class TestValidateSQLAllowed:
    """SELECT-family statements that must pass validation with default settings."""

    def test_simple_select(self):
        parsed = validate_sql("SELECT 1")
        assert len(parsed) == 1
        assert isinstance(parsed[0], exp.Select)

    def test_select_from_table(self):
        parsed = validate_sql("SELECT id, name FROM users WHERE active = true")
        assert len(parsed) == 1
        assert isinstance(parsed[0], exp.Select)

    def test_select_with_join(self):
        query = "SELECT u.name, o.total FROM users u JOIN orders o ON u.id = o.user_id"
        assert len(validate_sql(query)) == 1

    def test_select_with_cte(self):
        query = "WITH top_users AS (SELECT id FROM users LIMIT 10) SELECT * FROM top_users"
        parsed = validate_sql(query)
        assert len(parsed) == 1
        assert isinstance(parsed[0], exp.Select)

    def test_select_with_subquery(self):
        query = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)"
        assert len(validate_sql(query)) == 1

    def test_union(self):
        parsed = validate_sql("SELECT 1 UNION SELECT 2")
        assert len(parsed) == 1
        assert isinstance(parsed[0], exp.Union)

    def test_union_all(self):
        parsed = validate_sql("SELECT 1 UNION ALL SELECT 2")
        assert len(parsed) == 1
        assert isinstance(parsed[0], exp.Union)

    def test_intersect(self):
        parsed = validate_sql("SELECT 1 INTERSECT SELECT 1")
        assert len(parsed) == 1
        assert isinstance(parsed[0], exp.Intersect)

    def test_except(self):
        parsed = validate_sql("SELECT 1 EXCEPT SELECT 2")
        assert len(parsed) == 1
        assert isinstance(parsed[0], exp.Except)
+
+
class TestValidateSQLBlocked:
    """Statements that should be blocked with default settings."""

    def test_insert_blocked(self):
        with pytest.raises(SQLSafetyError, match="Insert.*not allowed"):
            validate_sql("INSERT INTO users (name) VALUES ('test')")

    def test_update_blocked(self):
        with pytest.raises(SQLSafetyError, match="Update.*not allowed"):
            validate_sql("UPDATE users SET name = 'test' WHERE id = 1")

    def test_delete_blocked(self):
        with pytest.raises(SQLSafetyError, match="Delete.*not allowed"):
            validate_sql("DELETE FROM users WHERE id = 1")

    def test_drop_blocked(self):
        with pytest.raises(SQLSafetyError, match="Drop.*not allowed"):
            validate_sql("DROP TABLE users")

    def test_create_blocked(self):
        with pytest.raises(SQLSafetyError, match="Create.*not allowed"):
            validate_sql("CREATE TABLE test (id INT)")

    def test_alter_blocked(self):
        with pytest.raises(SQLSafetyError, match="Alter.*not allowed"):
            validate_sql("ALTER TABLE users ADD COLUMN email TEXT")

    def test_truncate_blocked(self):
        # NOTE(review): looser match than the other tests — presumably because
        # the sqlglot node name for TRUNCATE varies across versions; confirm.
        with pytest.raises(SQLSafetyError, match="not allowed"):
            validate_sql("TRUNCATE TABLE users")
+
+
class TestValidateSQLMultiStatement:
    """Multi-statement SQL should be blocked by default."""

    def test_multiple_statements_blocked_by_default(self):
        with pytest.raises(SQLSafetyError, match="Multiple statements detected"):
            validate_sql("SELECT 1; SELECT 2")

    def test_multiple_statements_allowed_when_opted_in(self):
        result = validate_sql("SELECT 1; SELECT 2", allow_multiple_statements=True)
        assert len(result) == 2

    def test_dangerous_hidden_after_select(self):
        """Multi-statement blocks even if first statement is safe."""
        with pytest.raises(SQLSafetyError, match="Multiple statements"):
            validate_sql("SELECT 1; DROP TABLE users")

    def test_multi_statement_still_validates_types(self):
        """Even when multi-statement is allowed, types are still checked."""
        with pytest.raises(SQLSafetyError, match="Drop.*not allowed"):
            validate_sql("SELECT 1; DROP TABLE users", allow_multiple_statements=True)
+
+
class TestValidateSQLEdgeCases:
    """Edge cases and error handling."""

    def test_empty_string_raises(self):
        with pytest.raises(SQLSafetyError, match="Empty SQL"):
            validate_sql("")

    def test_whitespace_only_raises(self):
        with pytest.raises(SQLSafetyError, match="Empty SQL"):
            validate_sql("   \n\t  ")

    def test_malformed_sql_raises(self):
        with pytest.raises(SQLSafetyError, match="SQL parse error"):
            validate_sql("NOT VALID SQL AT ALL }{][")

    def test_dialect_parameter(self):
        result = validate_sql("SELECT 1", dialect="postgres")
        assert len(result) == 1

    def test_custom_allowed_types(self):
        """Allow INSERT when explicitly opted in."""
        result = validate_sql(
            "INSERT INTO users (name) VALUES ('test')",
            allowed_types=(exp.Insert,),
        )
        assert len(result) == 1

    def test_custom_allowed_types_still_blocks_others(self):
        """Custom types don't allow everything."""
        with pytest.raises(SQLSafetyError, match="Select.*not allowed"):
            validate_sql("SELECT 1", allowed_types=(exp.Insert,))

    def test_select_with_trailing_semicolon(self):
        """Trailing semicolon should not cause multi-statement error."""
        result = validate_sql("SELECT 1;")
        assert len(result) == 1
diff --git a/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py 
b/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py
index b890f63ed79..7a02f42c62e 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py
@@ -345,6 +345,17 @@ class DbApiHook(BaseHook):
             )
         return inspect(self.get_sqlalchemy_engine())
 
+    def get_table_schema(self, table_name: str, schema: str | None = None) -> 
list[dict[str, str]]:
+        """
+        Return column names and types for a table using SQLAlchemy Inspector.
+
+        :param table_name: Name of the table.
+        :param schema: Optional schema/namespace name.
+        :return: List of dicts with ``name`` and ``type`` keys.
+        """
+        columns = self.inspector.get_columns(table_name, schema=schema)
+        return [{"name": col["name"], "type": str(col["type"])} for col in 
columns]
+
     @cached_property
     def dialect_name(self) -> str:
         if make_url is not None:
diff --git a/providers/common/sql/tests/unit/common/sql/hooks/test_sql.py 
b/providers/common/sql/tests/unit/common/sql/hooks/test_sql.py
index e0c9c854a09..d88cc3abbc8 100644
--- a/providers/common/sql/tests/unit/common/sql/hooks/test_sql.py
+++ b/providers/common/sql/tests/unit/common/sql/hooks/test_sql.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 
 import inspect
 import logging
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, PropertyMock, patch
 
 import pandas as pd
 import polars as pl
@@ -334,6 +334,41 @@ class TestDbApiHook:
             assert isinstance(df, expected_type)
 
 
class TestDbApiHookGetTableSchema:
    """Tests for ``DbApiHook.get_table_schema`` (SQLAlchemy inspector-backed)."""

    @pytest.mark.db_test
    def test_get_table_schema(self):
        dbapi_hook = mock_db_hook(DbApiHook)
        mock_inspector = MagicMock()
        mock_inspector.get_columns.return_value = [
            {"name": "id", "type": "INTEGER", "nullable": True},
            {"name": "name", "type": "VARCHAR(255)", "nullable": False},
        ]
        # `inspector` is a property, so patch it on the type to avoid
        # creating a real SQLAlchemy engine.
        with patch.object(
            type(dbapi_hook), "inspector", new_callable=PropertyMock, return_value=mock_inspector
        ):
            result = dbapi_hook.get_table_schema("users")

        # Extra inspector keys (e.g. "nullable") are dropped from the result.
        assert result == [
            {"name": "id", "type": "INTEGER"},
            {"name": "name", "type": "VARCHAR(255)"},
        ]
        mock_inspector.get_columns.assert_called_once_with("users", schema=None)

    @pytest.mark.db_test
    def test_get_table_schema_with_schema(self):
        dbapi_hook = mock_db_hook(DbApiHook)
        mock_inspector = MagicMock()
        mock_inspector.get_columns.return_value = [
            {"name": "col1", "type": "TEXT"},
        ]
        with patch.object(
            type(dbapi_hook), "inspector", new_callable=PropertyMock, return_value=mock_inspector
        ):
            dbapi_hook.get_table_schema("my_table", schema="my_schema")

        # The schema argument must be forwarded to the inspector.
        mock_inspector.get_columns.assert_called_once_with("my_table", schema="my_schema")
+
+
 def test_inspector_is_cached():
     """inspector should return the same object on repeated access (not create 
N engines)."""
     hook = DBApiHookForTests(conn_id=DEFAULT_CONN_ID)

Reply via email to