This is an automated email from the ASF dual-hosted git repository.
gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0ffb28ed0d1 Add ObjectStorage support to LLMSQLQueryOperator via DataFusion (#62640)
0ffb28ed0d1 is described below
commit 0ffb28ed0d1d21e3f7a232ebec5dbf6ff13772f8
Author: GPK <[email protected]>
AuthorDate: Sat Feb 28 22:11:31 2026 +0000
Add ObjectStorage support to LLMSQLQueryOperator via DataFusion (#62640)
* Add ObjectStorage support to LLMSQLQueryOperator via DataFusion
* Move DataFusionEngine import to AirflowOptionalProviderFeatureException try block
* Update deps
* Fixup tests
---
providers/common/ai/docs/operators/llm_sql.rst | 47 ++++++
providers/common/ai/pyproject.toml | 1 +
.../common/ai/example_dags/example_llm_sql.py | 24 +++
.../providers/common/ai/operators/llm_sql.py | 19 ++-
.../tests/unit/common/ai/operators/test_llm_sql.py | 164 +++++++++++++++++++++
.../providers/common/sql/datafusion/engine.py | 5 +
.../unit/common/sql/datafusion/test_engine.py | 42 ++++++
7 files changed, 301 insertions(+), 1 deletion(-)
diff --git a/providers/common/ai/docs/operators/llm_sql.rst b/providers/common/ai/docs/operators/llm_sql.rst
index d2ccadb12bf..7efe4aaaa7b 100644
--- a/providers/common/ai/docs/operators/llm_sql.rst
+++ b/providers/common/ai/docs/operators/llm_sql.rst
@@ -51,6 +51,53 @@ the actual column names and types:
:start-after: [START howto_operator_llm_sql_schema]
:end-before: [END howto_operator_llm_sql_schema]
+With Object Storage
+-------------------
+
+Use ``datasource_config`` to generate queries for data stored in object storage
+(e.g., S3, local filesystem) via `DataFusion <https://datafusion.apache.org/>`_.
+The operator uses :class:`~airflow.providers.common.sql.config.DataSourceConfig`
+to register the object storage source as a table so the LLM can include it in
+the schema context.
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
+ :language: python
+ :start-after: [START howto_operator_llm_sql_with_object_storage]
+ :end-before: [END howto_operator_llm_sql_with_object_storage]
+
+Once the SQL is generated, you can execute it against object storage data using
+:class:`~airflow.providers.common.sql.operators.analytics.AnalyticsOperator`.
+Chain the two operators so that the generated query flows into the analytics
+execution step:
+
+.. code-block:: python
+
+ from airflow.providers.common.ai.operators.llm_sql import LLMSQLQueryOperator
+ from airflow.providers.common.sql.config import DataSourceConfig
+ from airflow.providers.common.sql.operators.analytics import AnalyticsOperator
+
+ datasource_config = DataSourceConfig(
+ conn_id="aws_default",
+ table_name="sales_data",
+ uri="s3://my-bucket/data/sales/",
+ format="parquet",
+ )
+
+ generate_sql = LLMSQLQueryOperator(
+ task_id="generate_sql",
+ prompt="Find the top 5 products by total sales amount",
+ llm_conn_id="pydantic_ai_default",
+ datasource_config=datasource_config,
+ )
+
+ run_query = AnalyticsOperator(
+ task_id="run_query",
+ datasource_configs=[datasource_config],
+ queries=["{{ ti.xcom_pull(task_ids='generate_sql') }}"],
+ )
+
+ generate_sql >> run_query
+
TaskFlow Decorator
------------------
diff --git a/providers/common/ai/pyproject.toml b/providers/common/ai/pyproject.toml
index c80c500a50a..b8726b0248b 100644
--- a/providers/common/ai/pyproject.toml
+++ b/providers/common/ai/pyproject.toml
@@ -89,6 +89,7 @@ dev = [
"apache-airflow-providers-common-sql",
# Additional devel dependencies (do not remove this line and add extra development dependencies)
"sqlglot>=26.0.0",
+ "apache-airflow-providers-common-sql[datafusion]"
]
# To build docs:
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
index 2a7e52f5b6a..77bb6b63dc6 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
@@ -20,6 +20,7 @@ from __future__ import annotations
from airflow.providers.common.ai.operators.llm_sql import LLMSQLQueryOperator
from airflow.providers.common.compat.sdk import dag, task
+from airflow.providers.common.sql.config import DataSourceConfig
# [START howto_operator_llm_sql_basic]
@@ -100,3 +101,26 @@ def example_llm_sql_expand():
# [END howto_operator_llm_sql_expand]
example_llm_sql_expand()
+
+
+# [START howto_operator_llm_sql_with_object_storage]
+@dag
+def example_llm_sql_with_object_storage():
+ datasource_config = DataSourceConfig(
+ conn_id="aws_default",
+ table_name="sales_data",
+ uri="s3://my-bucket/data/sales/",
+ format="parquet",
+ )
+
+ LLMSQLQueryOperator(
+ task_id="generate_sql",
+ prompt="Find the top 5 products by total sales amount",
+ llm_conn_id="pydantic_ai_default",
+ datasource_config=datasource_config,
+ )
+
+
+# [END howto_operator_llm_sql_with_object_storage]
+
+example_llm_sql_with_object_storage()
diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
index 4501b4c1c63..81bec01ec77 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
@@ -27,6 +27,7 @@ try:
DEFAULT_ALLOWED_TYPES,
validate_sql as _validate_sql,
)
+ from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
except ImportError as e:
from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException
@@ -38,6 +39,7 @@ from airflow.providers.common.compat.sdk import BaseHook
if TYPE_CHECKING:
from sqlglot import exp
+ from airflow.providers.common.sql.config import DataSourceConfig
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.sdk import Context
@@ -101,6 +103,7 @@ class LLMSQLQueryOperator(LLMOperator):
validate_sql: bool = True,
allowed_sql_types: tuple[type[exp.Expression], ...] = DEFAULT_ALLOWED_TYPES,
dialect: str | None = None,
+ datasource_config: DataSourceConfig | None = None,
**kwargs: Any,
) -> None:
kwargs.pop("output_type", None) # SQL operator always returns str
@@ -111,6 +114,7 @@ class LLMSQLQueryOperator(LLMOperator):
self.validate_sql = validate_sql
self.allowed_sql_types = allowed_sql_types
self.dialect = dialect
+ self.datasource_config = datasource_config
@cached_property
def db_hook(self) -> DbApiHook | None:
@@ -129,6 +133,7 @@ class LLMSQLQueryOperator(LLMOperator):
def execute(self, context: Context) -> str:
schema_info = self._get_schema_context()
+
full_system_prompt = self._build_system_prompt(schema_info)
agent = self.llm_hook.create_agent(
@@ -159,8 +164,9 @@ class LLMSQLQueryOperator(LLMOperator):
"""Return schema context from manual override or database
introspection."""
if self.schema_context:
return self.schema_context
- if self.db_hook and self.table_names:
+ if (self.db_hook and self.table_names) or self.datasource_config:
return self._introspect_schemas()
+
return ""
def _introspect_schemas(self) -> str:
@@ -178,8 +184,19 @@ class LLMSQLQueryOperator(LLMOperator):
f"None of the requested tables ({self.table_names}) returned
schema information. "
"Check that the table names are correct and the database
connection has access."
)
+
+ if self.datasource_config:
+ object_storage_schema = self._introspect_object_storage_schema()
+ parts.append(f"Table:
{self.datasource_config.table_name}\nColumns: {object_storage_schema}")
+
return "\n\n".join(parts)
+ def _introspect_object_storage_schema(self) -> str:
+ """Use the DataFusion engine to get the schema of the object storage table."""
+ engine = DataFusionEngine()
+ engine.register_datasource(self.datasource_config)
+ return engine.get_schema(self.datasource_config.table_name)
+
def _build_system_prompt(self, schema_info: str) -> str:
"""Construct the system prompt for the LLM."""
dialect_label = self._resolved_dialect or "SQL"
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
index 97d943c14c0..445df28c17c 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
@@ -22,6 +22,7 @@ import pytest
from airflow.providers.common.ai.operators.llm_sql import LLMSQLQueryOperator
from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
+from airflow.providers.common.sql.config import DataSourceConfig
def _make_mock_agent(output: str):
@@ -213,6 +214,169 @@ class TestLLMSQLQueryOperatorSchemaIntrospection:
)
assert op._get_schema_context() == "My custom schema info"
+ @patch(
+ "airflow.providers.common.ai.operators.llm_sql.DataFusionEngine",
+ autospec=True,
+ )
+ def test_introspect_object_storage_schema(self, mock_engine_cls):
+ """_introspect_object_storage_schema registers datasource and returns
schema."""
+ mock_engine = mock_engine_cls.return_value
+ schema_text = "cust_id: int64\nname: string\namount: float64"
+ mock_engine.get_schema.return_value = schema_text
+
+ ds_config = DataSourceConfig(
+ conn_id="aws_default",
+ table_name="sales",
+ uri="s3://bucket/data/",
+ format="parquet",
+ )
+ op = LLMSQLQueryOperator(
+ task_id="test",
+ prompt="test",
+ llm_conn_id="my_llm",
+ datasource_config=ds_config,
+ )
+ result = op._introspect_object_storage_schema()
+
+ mock_engine.register_datasource.assert_called_once_with(ds_config)
+ mock_engine.get_schema.assert_called_once_with("sales")
+ assert result == schema_text
+
+ @patch(
+ "airflow.providers.common.ai.operators.llm_sql.DataFusionEngine",
+ autospec=True,
+ )
+ def test_introspect_schemas_with_db_and_datasource_config(self, mock_engine_cls):
+ """_introspect_schemas includes both db table and object storage schema."""
+ mock_engine = mock_engine_cls.return_value
+ object_schema = "col_a: int64\ncol_b: string"
+ mock_engine.get_schema.return_value = object_schema
+
+ ds_config = DataSourceConfig(
+ conn_id="aws_default",
+ table_name="remote_table",
+ uri="s3://bucket/path/",
+ format="csv",
+ )
+ mock_db_hook = MagicMock(spec=["get_table_schema", "dialect_name"])
+ mock_db_hook.get_table_schema.return_value = [
+ {"name": "id", "type": "INTEGER"},
+ ]
+
+ op = LLMSQLQueryOperator(
+ task_id="test",
+ prompt="test",
+ llm_conn_id="my_llm",
+ db_conn_id="pg_default",
+ table_names=["local_table"],
+ datasource_config=ds_config,
+ )
+
+ with patch.object(type(op), "db_hook", new_callable=PropertyMock, return_value=mock_db_hook):
+ result = op._introspect_schemas()
+
+ assert "Table: local_table" in result
+ assert "id INTEGER" in result
+ assert "Table: remote_table" in result
+ assert object_schema in result
+
+ @patch(
+ "airflow.providers.common.ai.operators.llm_sql.DataFusionEngine",
+ autospec=True,
+ )
+ def test_introspect_schemas_datasource_config_without_db_tables(self, mock_engine_cls):
+ """_introspect_schemas works when only datasource_config is provided (no db tables)."""
+ mock_engine = mock_engine_cls.return_value
+ mock_engine.get_schema.return_value = "ts: TIMESTAMP\nvalue: DOUBLE"
+
+ ds_config = DataSourceConfig(
+ conn_id="aws_default",
+ table_name="s3_data",
+ uri="s3://bucket/metrics/",
+ format="parquet",
+ )
+ op = LLMSQLQueryOperator(
+ task_id="test",
+ prompt="test",
+ llm_conn_id="my_llm",
+ db_conn_id="pg_default",
+ table_names=[],
+ datasource_config=ds_config,
+ )
+ mock_db_hook = MagicMock(spec=["get_table_schema", "dialect_name"])
+ mock_db_hook.get_table_schema.return_value = []
+
+ with patch.object(type(op), "db_hook", new_callable=PropertyMock, return_value=mock_db_hook):
+ result = op._introspect_schemas()
+
+ assert "Table: s3_data" in result
+ assert "ts: TIMESTAMP\nvalue: DOUBLE" in result
+
+ @patch(
+ "airflow.providers.common.ai.operators.llm_sql.DataFusionEngine",
+ autospec=True,
+ )
+ def test_introspect_schemas_raises_when_no_tables_and_no_datasource(self, mock_engine_cls):
+ """ValueError is raised when no db tables return schema and no datasource_config is set."""
+ mock_db_hook = MagicMock(spec=["get_table_schema", "dialect_name"])
+ mock_db_hook.get_table_schema.return_value = []
+
+ op = LLMSQLQueryOperator(
+ task_id="test",
+ prompt="test",
+ llm_conn_id="my_llm",
+ db_conn_id="pg_default",
+ table_names=["missing_table"],
+ )
+
+ with patch.object(type(op), "db_hook", new_callable=PropertyMock, return_value=mock_db_hook):
+ with pytest.raises(ValueError, match="None of the requested tables"):
+ op._introspect_schemas()
+
+ @patch("airflow.providers.common.ai.operators.llm.PydanticAIHook",
autospec=True)
+ @patch(
+ "airflow.providers.common.ai.operators.llm_sql.DataFusionEngine",
+ autospec=True,
+ )
+ def test_execute_with_datasource_config_and_db_tables(self, mock_engine_cls, mock_hook_cls):
+ """Full execute flow with both db tables and object storage datasource."""
+ mock_engine = mock_engine_cls.return_value
+ mock_engine.get_schema.return_value = "event: TEXT\nts: TIMESTAMP"
+
+ mock_agent = _make_mock_agent("SELECT u.id, e.event FROM users u JOIN events e ON u.id = e.user_id")
+ mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+ ds_config = DataSourceConfig(
+ conn_id="aws_default",
+ table_name="events",
+ uri="s3://bucket/events/",
+ format="parquet",
+ )
+ mock_db_hook = MagicMock(spec=["get_table_schema", "dialect_name"])
+ mock_db_hook.get_table_schema.return_value = [
+ {"name": "id", "type": "INTEGER"},
+ {"name": "name", "type": "VARCHAR"},
+ ]
+ mock_db_hook.dialect_name = "postgresql"
+
+ op = LLMSQLQueryOperator(
+ task_id="test",
+ prompt="Join users with events",
+ llm_conn_id="my_llm",
+ db_conn_id="pg_default",
+ table_names=["users"],
+ datasource_config=ds_config,
+ )
+
+ with patch.object(type(op), "db_hook", new_callable=PropertyMock, return_value=mock_db_hook):
+ result = op.execute(context=MagicMock())
+
+ assert "SELECT" in result
+ instructions = mock_hook_cls.return_value.create_agent.call_args[1]["instructions"]
+ assert "users" in instructions
+ assert "events" in instructions
+ assert "event: TEXT\nts: TIMESTAMP" in instructions
+
class TestLLMSQLQueryOperatorDialect:
def test_resolved_dialect_from_param(self):
diff --git a/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py b/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py
index 498655e46e5..e5c574600af 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py
@@ -165,3 +165,8 @@ class DataFusionEngine(LoggingMixin):
def _remove_none_values(params: dict[str, Any]) -> dict[str, Any]:
"""Filter out None values from the dictionary."""
return {k: v for k, v in params.items() if v is not None}
+
+ def get_schema(self, table_name: str) -> str:
+ """Return the schema of a registered table as a string."""
+ return str(self.session_context.table(table_name).schema())
diff --git a/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py b/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py
index ef7413ca89d..24606f7e7c8 100644
--- a/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py
+++ b/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py
@@ -251,3 +251,45 @@ class TestDataFusionEngine:
with pytest.raises(ValueError, match="Unknown connection type dummy"):
engine._get_credentials(mock_conn)
+
+ def test_get_schema_success(self):
+ engine = DataFusionEngine()
+ engine.df_ctx = MagicMock(spec=SessionContext)
+ mock_table = MagicMock()
+ mock_schema = MagicMock()
+ mock_schema.__str__ = lambda self: "id: int64, name: string"
+ mock_table.schema.return_value = mock_schema
+ engine.df_ctx.table.return_value = mock_table
+
+ result = engine.get_schema("test_table")
+
+ engine.df_ctx.table.assert_called_once_with("test_table")
+ mock_table.schema.assert_called_once()
+ assert result == "id: int64, name: string"
+
+ @patch.object(DataFusionEngine, "_get_connection_config")
+ def test_get_schema_with_local_csv(self, mock_get_conn):
+ mock_get_conn.return_value = None
+
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
+ f.write("name,age\nAlice,30\nBob,25\n")
+ csv_path = f.name
+
+ try:
+ engine = DataFusionEngine()
+ datasource_config = DataSourceConfig(
+ table_name="test_csv",
+ uri=f"file://{csv_path}",
+ format="csv",
+ storage_type="local",
+ conn_id="",
+ )
+
+ engine.register_datasource(datasource_config)
+
+ result = engine.get_schema("test_csv")
+
+ assert "name: string" in result
+ assert "age: int64" in result
+ finally:
+ os.unlink(csv_path)
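
For reference, a minimal sketch of how the new get_schema helper fits together with DataSourceConfig and register_datasource, using only the APIs shown in the diff above (the local CSV path, column names, and printed output are illustrative, mirroring test_get_schema_with_local_csv):

    from airflow.providers.common.sql.config import DataSourceConfig
    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine

    # Describe a local CSV file as a queryable table (illustrative values).
    config = DataSourceConfig(
        conn_id="",                  # empty conn_id for local files, as in the test above
        table_name="sales_data",
        uri="file:///tmp/sales.csv",  # illustrative path
        format="csv",
        storage_type="local",
    )

    engine = DataFusionEngine()
    engine.register_datasource(config)      # registers the file under table_name
    schema = engine.get_schema("sales_data")
    print(schema)                           # e.g. "product: string, amount: float64"

This is the same schema text the operator appends to the LLM's schema context when datasource_config is set.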