This is an automated email from the ASF dual-hosted git repository.

gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0ffb28ed0d1 Add ObjectStorage support to LLMSQLQueryOperator via DataFusion (#62640)
0ffb28ed0d1 is described below

commit 0ffb28ed0d1d21e3f7a232ebec5dbf6ff13772f8
Author: GPK <[email protected]>
AuthorDate: Sat Feb 28 22:11:31 2026 +0000

    Add ObjectStorage support to LLMSQLQueryOperator via DataFusion (#62640)
    
    * Add ObjectStorage support to LLMSQLQueryOperator via DataFusion
    
    * Move DataFusionEngine import to AirflowOptionalProviderFeatureException try block
    
    * Update deps
    
    * Fixup tests
---
 providers/common/ai/docs/operators/llm_sql.rst     |  47 ++++++
 providers/common/ai/pyproject.toml                 |   1 +
 .../common/ai/example_dags/example_llm_sql.py      |  24 +++
 .../providers/common/ai/operators/llm_sql.py       |  19 ++-
 .../tests/unit/common/ai/operators/test_llm_sql.py | 164 +++++++++++++++++++++
 .../providers/common/sql/datafusion/engine.py      |   5 +
 .../unit/common/sql/datafusion/test_engine.py      |  42 ++++++
 7 files changed, 301 insertions(+), 1 deletion(-)

diff --git a/providers/common/ai/docs/operators/llm_sql.rst b/providers/common/ai/docs/operators/llm_sql.rst
index d2ccadb12bf..7efe4aaaa7b 100644
--- a/providers/common/ai/docs/operators/llm_sql.rst
+++ b/providers/common/ai/docs/operators/llm_sql.rst
@@ -51,6 +51,53 @@ the actual column names and types:
     :start-after: [START howto_operator_llm_sql_schema]
     :end-before: [END howto_operator_llm_sql_schema]
 
+With Object Storage
+-------------------
+
+Use ``datasource_config`` to generate queries for data stored in object storage
+(e.g., S3, local filesystem) via `DataFusion <https://datafusion.apache.org/>`_.
+The operator uses :class:`~airflow.providers.common.sql.config.DataSourceConfig`
+to register the object storage source as a table so the LLM can include it in
+the schema context.
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
+    :language: python
+    :start-after: [START howto_operator_llm_sql_with_object_storage]
+    :end-before: [END howto_operator_llm_sql_with_object_storage]
+
+Once the SQL is generated, you can execute it against object storage data using
+:class:`~airflow.providers.common.sql.operators.analytics.AnalyticsOperator`.
+Chain the two operators so that the generated query flows into the analytics
+execution step:
+
+.. code-block:: python
+
+    from airflow.providers.common.ai.operators.llm_sql import 
LLMSQLQueryOperator
+    from airflow.providers.common.sql.config import DataSourceConfig
+    from airflow.providers.common.sql.operators.analytics import 
AnalyticsOperator
+
+    datasource_config = DataSourceConfig(
+        conn_id="aws_default",
+        table_name="sales_data",
+        uri="s3://my-bucket/data/sales/",
+        format="parquet",
+    )
+
+    generate_sql = LLMSQLQueryOperator(
+        task_id="generate_sql",
+        prompt="Find the top 5 products by total sales amount",
+        llm_conn_id="pydantic_ai_default",
+        datasource_config=datasource_config,
+    )
+
+    run_query = AnalyticsOperator(
+        task_id="run_query",
+        datasource_configs=[datasource_config],
+        queries=["{{ ti.xcom_pull(task_ids='generate_sql') }}"],
+    )
+
+    generate_sql >> run_query
+
 TaskFlow Decorator
 ------------------
 
diff --git a/providers/common/ai/pyproject.toml b/providers/common/ai/pyproject.toml
index c80c500a50a..b8726b0248b 100644
--- a/providers/common/ai/pyproject.toml
+++ b/providers/common/ai/pyproject.toml
@@ -89,6 +89,7 @@ dev = [
     "apache-airflow-providers-common-sql",
     # Additional devel dependencies (do not remove this line and add extra development dependencies)
     "sqlglot>=26.0.0",
+    "apache-airflow-providers-common-sql[datafusion]"
 ]
 
 # To build docs:
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
index 2a7e52f5b6a..77bb6b63dc6 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 
 from airflow.providers.common.ai.operators.llm_sql import LLMSQLQueryOperator
 from airflow.providers.common.compat.sdk import dag, task
+from airflow.providers.common.sql.config import DataSourceConfig
 
 
 # [START howto_operator_llm_sql_basic]
@@ -100,3 +101,26 @@ def example_llm_sql_expand():
 # [END howto_operator_llm_sql_expand]
 
 example_llm_sql_expand()
+
+
+# [START howto_operator_llm_sql_with_object_storage]
+@dag
+def example_llm_sql_with_object_storage():
+    datasource_config = DataSourceConfig(
+        conn_id="aws_default",
+        table_name="sales_data",
+        uri="s3://my-bucket/data/sales/",
+        format="parquet",
+    )
+
+    LLMSQLQueryOperator(
+        task_id="generate_sql",
+        prompt="Find the top 5 products by total sales amount",
+        llm_conn_id="pydantic_ai_default",
+        datasource_config=datasource_config,
+    )
+
+
+# [END howto_operator_llm_sql_with_object_storage]
+
+example_llm_sql_with_object_storage()
diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
index 4501b4c1c63..81bec01ec77 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
@@ -27,6 +27,7 @@ try:
         DEFAULT_ALLOWED_TYPES,
         validate_sql as _validate_sql,
     )
+    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
 except ImportError as e:
     from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
 
@@ -38,6 +39,7 @@ from airflow.providers.common.compat.sdk import BaseHook
 if TYPE_CHECKING:
     from sqlglot import exp
 
+    from airflow.providers.common.sql.config import DataSourceConfig
     from airflow.providers.common.sql.hooks.sql import DbApiHook
     from airflow.sdk import Context
 
@@ -101,6 +103,7 @@ class LLMSQLQueryOperator(LLMOperator):
         validate_sql: bool = True,
         allowed_sql_types: tuple[type[exp.Expression], ...] = 
DEFAULT_ALLOWED_TYPES,
         dialect: str | None = None,
+        datasource_config: DataSourceConfig | None = None,
         **kwargs: Any,
     ) -> None:
         kwargs.pop("output_type", None)  # SQL operator always returns str
@@ -111,6 +114,7 @@ class LLMSQLQueryOperator(LLMOperator):
         self.validate_sql = validate_sql
         self.allowed_sql_types = allowed_sql_types
         self.dialect = dialect
+        self.datasource_config = datasource_config
 
     @cached_property
     def db_hook(self) -> DbApiHook | None:
@@ -129,6 +133,7 @@ class LLMSQLQueryOperator(LLMOperator):
 
     def execute(self, context: Context) -> str:
         schema_info = self._get_schema_context()
+
         full_system_prompt = self._build_system_prompt(schema_info)
 
         agent = self.llm_hook.create_agent(
@@ -159,8 +164,9 @@ class LLMSQLQueryOperator(LLMOperator):
         """Return schema context from manual override or database 
introspection."""
         if self.schema_context:
             return self.schema_context
-        if self.db_hook and self.table_names:
+        if (self.db_hook and self.table_names) or self.datasource_config:
             return self._introspect_schemas()
+
         return ""
 
     def _introspect_schemas(self) -> str:
@@ -178,8 +184,19 @@ class LLMSQLQueryOperator(LLMOperator):
                 f"None of the requested tables ({self.table_names}) returned 
schema information. "
                 "Check that the table names are correct and the database 
connection has access."
             )
+
+        if self.datasource_config:
+            object_storage_schema = self._introspect_object_storage_schema()
+            parts.append(f"Table: 
{self.datasource_config.table_name}\nColumns: {object_storage_schema}")
+
         return "\n\n".join(parts)
 
+    def _introspect_object_storage_schema(self):
+        """Use DataFusion Engine to get the schema of object stores."""
+        engine = DataFusionEngine()
+        engine.register_datasource(self.datasource_config)
+        return engine.get_schema(self.datasource_config.table_name)
+
     def _build_system_prompt(self, schema_info: str) -> str:
         """Construct the system prompt for the LLM."""
         dialect_label = self._resolved_dialect or "SQL"
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
index 97d943c14c0..445df28c17c 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
@@ -22,6 +22,7 @@ import pytest
 
 from airflow.providers.common.ai.operators.llm_sql import LLMSQLQueryOperator
 from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
+from airflow.providers.common.sql.config import DataSourceConfig
 
 
 def _make_mock_agent(output: str):
@@ -213,6 +214,169 @@ class TestLLMSQLQueryOperatorSchemaIntrospection:
         )
         assert op._get_schema_context() == "My custom schema info"
 
+    @patch(
+        "airflow.providers.common.ai.operators.llm_sql.DataFusionEngine",
+        autospec=True,
+    )
+    def test_introspect_object_storage_schema(self, mock_engine_cls):
+        """_introspect_object_storage_schema registers datasource and returns 
schema."""
+        mock_engine = mock_engine_cls.return_value
+        schema_text = "cust_id: int64\nname: string\namount: float64"
+        mock_engine.get_schema.return_value = schema_text
+
+        ds_config = DataSourceConfig(
+            conn_id="aws_default",
+            table_name="sales",
+            uri="s3://bucket/data/",
+            format="parquet",
+        )
+        op = LLMSQLQueryOperator(
+            task_id="test",
+            prompt="test",
+            llm_conn_id="my_llm",
+            datasource_config=ds_config,
+        )
+        result = op._introspect_object_storage_schema()
+
+        mock_engine.register_datasource.assert_called_once_with(ds_config)
+        mock_engine.get_schema.assert_called_once_with("sales")
+        assert result == schema_text
+
+    @patch(
+        "airflow.providers.common.ai.operators.llm_sql.DataFusionEngine",
+        autospec=True,
+    )
+    def test_introspect_schemas_with_db_and_datasource_config(self, 
mock_engine_cls):
+        """_introspect_schemas includes both db table and object storage 
schema."""
+        mock_engine = mock_engine_cls.return_value
+        object_schema = "col_a: int64\ncol_b: string"
+        mock_engine.get_schema.return_value = object_schema
+
+        ds_config = DataSourceConfig(
+            conn_id="aws_default",
+            table_name="remote_table",
+            uri="s3://bucket/path/",
+            format="csv",
+        )
+        mock_db_hook = MagicMock(spec=["get_table_schema", "dialect_name"])
+        mock_db_hook.get_table_schema.return_value = [
+            {"name": "id", "type": "INTEGER"},
+        ]
+
+        op = LLMSQLQueryOperator(
+            task_id="test",
+            prompt="test",
+            llm_conn_id="my_llm",
+            db_conn_id="pg_default",
+            table_names=["local_table"],
+            datasource_config=ds_config,
+        )
+
+        with patch.object(type(op), "db_hook", new_callable=PropertyMock, 
return_value=mock_db_hook):
+            result = op._introspect_schemas()
+
+        assert "Table: local_table" in result
+        assert "id INTEGER" in result
+        assert "Table: remote_table" in result
+        assert object_schema in result
+
+    @patch(
+        "airflow.providers.common.ai.operators.llm_sql.DataFusionEngine",
+        autospec=True,
+    )
+    def test_introspect_schemas_datasource_config_without_db_tables(self, 
mock_engine_cls):
+        """_introspect_schemas works when only datasource_config is provided 
(no db tables)."""
+        mock_engine = mock_engine_cls.return_value
+        mock_engine.get_schema.return_value = "ts: TIMESTAMP\nvalue: DOUBLE"
+
+        ds_config = DataSourceConfig(
+            conn_id="aws_default",
+            table_name="s3_data",
+            uri="s3://bucket/metrics/",
+            format="parquet",
+        )
+        op = LLMSQLQueryOperator(
+            task_id="test",
+            prompt="test",
+            llm_conn_id="my_llm",
+            db_conn_id="pg_default",
+            table_names=[],
+            datasource_config=ds_config,
+        )
+        mock_db_hook = MagicMock(spec=["get_table_schema", "dialect_name"])
+        mock_db_hook.get_table_schema.return_value = []
+
+        with patch.object(type(op), "db_hook", new_callable=PropertyMock, 
return_value=mock_db_hook):
+            result = op._introspect_schemas()
+
+        assert "Table: s3_data" in result
+        assert "ts: TIMESTAMP\nvalue: DOUBLE" in result
+
+    @patch(
+        "airflow.providers.common.ai.operators.llm_sql.DataFusionEngine",
+        autospec=True,
+    )
+    def test_introspect_schemas_raises_when_no_tables_and_no_datasource(self, 
mock_engine_cls):
+        """ValueError is raised when no db tables return schema and no 
datasource_config is set."""
+        mock_db_hook = MagicMock(spec=["get_table_schema", "dialect_name"])
+        mock_db_hook.get_table_schema.return_value = []
+
+        op = LLMSQLQueryOperator(
+            task_id="test",
+            prompt="test",
+            llm_conn_id="my_llm",
+            db_conn_id="pg_default",
+            table_names=["missing_table"],
+        )
+
+        with patch.object(type(op), "db_hook", new_callable=PropertyMock, 
return_value=mock_db_hook):
+            with pytest.raises(ValueError, match="None of the requested 
tables"):
+                op._introspect_schemas()
+
+    @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook", 
autospec=True)
+    @patch(
+        "airflow.providers.common.ai.operators.llm_sql.DataFusionEngine",
+        autospec=True,
+    )
+    def test_execute_with_datasource_config_and_db_tables(self, 
mock_engine_cls, mock_hook_cls):
+        """Full execute flow with both db tables and object storage 
datasource."""
+        mock_engine = mock_engine_cls.return_value
+        mock_engine.get_schema.return_value = "event: TEXT\nts: TIMESTAMP"
+
+        mock_agent = _make_mock_agent("SELECT u.id, e.event FROM users u JOIN 
events e ON u.id = e.user_id")
+        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+        ds_config = DataSourceConfig(
+            conn_id="aws_default",
+            table_name="events",
+            uri="s3://bucket/events/",
+            format="parquet",
+        )
+        mock_db_hook = MagicMock(spec=["get_table_schema", "dialect_name"])
+        mock_db_hook.get_table_schema.return_value = [
+            {"name": "id", "type": "INTEGER"},
+            {"name": "name", "type": "VARCHAR"},
+        ]
+        mock_db_hook.dialect_name = "postgresql"
+
+        op = LLMSQLQueryOperator(
+            task_id="test",
+            prompt="Join users with events",
+            llm_conn_id="my_llm",
+            db_conn_id="pg_default",
+            table_names=["users"],
+            datasource_config=ds_config,
+        )
+
+        with patch.object(type(op), "db_hook", new_callable=PropertyMock, 
return_value=mock_db_hook):
+            result = op.execute(context=MagicMock())
+
+        assert "SELECT" in result
+        instructions = 
mock_hook_cls.return_value.create_agent.call_args[1]["instructions"]
+        assert "users" in instructions
+        assert "events" in instructions
+        assert "event: TEXT\nts: TIMESTAMP" in instructions
+
 
 class TestLLMSQLQueryOperatorDialect:
     def test_resolved_dialect_from_param(self):
diff --git a/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py b/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py
index 498655e46e5..e5c574600af 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py
@@ -165,3 +165,8 @@ class DataFusionEngine(LoggingMixin):
     def _remove_none_values(params: dict[str, Any]) -> dict[str, Any]:
         """Filter out None values from the dictionary."""
         return {k: v for k, v in params.items() if v is not None}
+
+    def get_schema(self, table_name: str):
+        """Get the schema of a table."""
+        schema = str(self.session_context.table(table_name).schema())
+        return schema
diff --git a/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py b/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py
index ef7413ca89d..24606f7e7c8 100644
--- a/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py
+++ b/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py
@@ -251,3 +251,45 @@ class TestDataFusionEngine:
 
         with pytest.raises(ValueError, match="Unknown connection type dummy"):
             engine._get_credentials(mock_conn)
+
+    def test_get_schema_success(self):
+        engine = DataFusionEngine()
+        engine.df_ctx = MagicMock(spec=SessionContext)
+        mock_table = MagicMock()
+        mock_schema = MagicMock()
+        mock_schema.__str__ = lambda self: "id: int64, name: string"
+        mock_table.schema.return_value = mock_schema
+        engine.df_ctx.table.return_value = mock_table
+
+        result = engine.get_schema("test_table")
+
+        engine.df_ctx.table.assert_called_once_with("test_table")
+        mock_table.schema.assert_called_once()
+        assert result == "id: int64, name: string"
+
+    @patch.object(DataFusionEngine, "_get_connection_config")
+    def test_get_schema_with_local_csv(self, mock_get_conn):
+        mock_get_conn.return_value = None
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", 
delete=False) as f:
+            f.write("name,age\nAlice,30\nBob,25\n")
+            csv_path = f.name
+
+        try:
+            engine = DataFusionEngine()
+            datasource_config = DataSourceConfig(
+                table_name="test_csv",
+                uri=f"file://{csv_path}",
+                format="csv",
+                storage_type="local",
+                conn_id="",
+            )
+
+            engine.register_datasource(datasource_config)
+
+            result = engine.get_schema("test_csv")
+
+            assert "name: string" in result
+            assert "age: int64" in result
+        finally:
+            os.unlink(csv_path)

Reply via email to