kaxil commented on code in PR #62850:
URL: https://github.com/apache/airflow/pull/62850#discussion_r2885808262


##########
providers/common/ai/docs/toolsets.rst:
##########
@@ -121,6 +121,65 @@ Parameters
   Default ``False`` — only SELECT-family statements are permitted.
 - ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``.
 
+``DataFusionToolset``

Review Comment:
   The intro paragraph earlier in this file says "Two toolsets are included" — 
it should be updated to say three now.



##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py:
##########
@@ -0,0 +1,172 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Curated SQL toolset wrapping DataFusionEngine for agentic object-store 
workflows."""
+
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any
+
+try:
+    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_core import SchemaValidator, core_schema
+
+if TYPE_CHECKING:
+    from pydantic_ai._run_context import RunContext
+
+    from airflow.providers.common.sql.config import DataSourceConfig
+
+_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())
+
+# JSON Schemas for the three DataFusion tools.
+_LIST_TABLES_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {},
+}
+
+_GET_SCHEMA_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "table_name": {"type": "string", "description": "Name of the table to 
inspect."},
+    },
+    "required": ["table_name"],
+}
+
+_QUERY_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "sql": {"type": "string", "description": "SQL query to execute."},
+    },
+    "required": ["sql"],
+}
+
+
+class DataFusionToolset(AbstractToolset[Any]):
+    """
+    Curated toolset that gives an LLM agent SQL access to object-storage data 
vi Apache DataFusion.
+
+    Provides three tools — ``list_tables``, ``get_schema``, and ``query`` —
+    backed by
+    :class:`~airflow.providers.common.sql.datafusion.engine.DataFusionEngine`.
+
+    Each :class:`~airflow.providers.common.sql.config.DataSourceConfig` entry
+    registers a table backed by Parquet, CSV, Avro, or Iceberg data on S3 or
+    local storage. Multiple configs can be registered so that SQL queries can
+    join across tables.
+
+    Requires the ``datafusion`` extra of 
``apache-airflow-providers-common-sql``.
+
+    :param datasource_configs: One or more DataFusion data-source 
configurations.
+    :param max_rows: Maximum number of rows returned from the ``query`` tool.
+        Default ``50``.
+    """
+
+    def __init__(
+        self,
+        datasource_configs: list[DataSourceConfig],
+        *,
+        max_rows: int = 50,
+    ) -> None:
+        self._datasource_configs = datasource_configs
+        self._max_rows = max_rows
+        self._engine: DataFusionEngine | None = None
+
+    @property
+    def id(self) -> str:
+        table_names = sorted(c.table_name for c in self._datasource_configs)
+        return f"sql-datafusion-{'-'.join(table_names)}"
+
+    def _get_engine(self) -> DataFusionEngine:
+        """Lazily create and configure a DataFusionEngine from 
*datasource_configs*."""
+        if self._engine is None:
+            engine = DataFusionEngine()
+            for config in self._datasource_configs:
+                engine.register_datasource(config)
+            self._engine = engine
+        return self._engine
+
+    async def get_tools(self, ctx: RunContext[Any]) -> dict[str, 
ToolsetTool[Any]]:
+        tools: dict[str, ToolsetTool[Any]] = {}
+
+        for name, description, schema in (
+            ("list_tables", "List available table names.", 
_LIST_TABLES_SCHEMA),
+            ("get_schema", "Get column names and types for a table.", 
_GET_SCHEMA_SCHEMA),
+            ("query", "Execute a SQL query and return rows as JSON.", 
_QUERY_SCHEMA),
+        ):
+            tool_def = ToolDefinition(
+                name=name,
+                description=description,
+                parameters_json_schema=schema,
+                sequential=True,
+            )
+            tools[name] = ToolsetTool(
+                toolset=self,
+                tool_def=tool_def,
+                max_retries=1,
+                args_validator=_PASSTHROUGH_VALIDATOR,
+            )
+        return tools
+
+    async def call_tool(
+        self,
+        name: str,
+        tool_args: dict[str, Any],
+        ctx: RunContext[Any],
+        tool: ToolsetTool[Any],
+    ) -> Any:
+        if name == "list_tables":
+            return self._list_tables()
+        if name == "get_schema":
+            return self._get_schema(tool_args["table_name"])
+        if name == "query":
+            return self._query(tool_args["sql"])
+        raise ValueError(f"Unknown tool: {name!r}")
+
+    def _list_tables(self) -> str:
+        engine = self._get_engine()

Review Comment:
   `SQLToolset._get_schema` catches errors and returns structured JSON 
(`{"error": "..."}`) so the LLM can self-correct. In 
`DataFusionToolset._get_schema`, by contrast, if `table_name` doesn't exist, 
`session_context.table()` raises a raw DataFusion exception that propagates as a 
traceback — the LLM gets a wall of Python internals instead of something 
actionable.
   
   Also — `DataFusionEngine` already has a `get_schema(table_name)` method that 
does essentially the same thing. Any reason to reach into `session_context` 
directly instead of using the engine's API?



##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py:
##########
@@ -0,0 +1,172 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Curated SQL toolset wrapping DataFusionEngine for agentic object-store 
workflows."""
+
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any
+
+try:
+    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_core import SchemaValidator, core_schema
+
+if TYPE_CHECKING:
+    from pydantic_ai._run_context import RunContext
+
+    from airflow.providers.common.sql.config import DataSourceConfig
+
+_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())
+
+# JSON Schemas for the three DataFusion tools.
+_LIST_TABLES_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {},
+}
+
+_GET_SCHEMA_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "table_name": {"type": "string", "description": "Name of the table to 
inspect."},
+    },
+    "required": ["table_name"],
+}
+
+_QUERY_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "sql": {"type": "string", "description": "SQL query to execute."},
+    },
+    "required": ["sql"],
+}
+
+
+class DataFusionToolset(AbstractToolset[Any]):
+    """
+    Curated toolset that gives an LLM agent SQL access to object-storage data 
vi Apache DataFusion.
+
+    Provides three tools — ``list_tables``, ``get_schema``, and ``query`` —
+    backed by
+    :class:`~airflow.providers.common.sql.datafusion.engine.DataFusionEngine`.
+
+    Each :class:`~airflow.providers.common.sql.config.DataSourceConfig` entry
+    registers a table backed by Parquet, CSV, Avro, or Iceberg data on S3 or
+    local storage. Multiple configs can be registered so that SQL queries can

Review Comment:
   Nit: `vi` → `via`



##########
providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py:
##########
@@ -0,0 +1,212 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.toolsets.datafusion import DataFusionToolset
+
+
+def _make_mock_datasource_config(table_name: str = "sales_data"):
+    """Create a mock DataSourceConfig."""
+    from airflow.providers.common.sql.config import DataSourceConfig
+
+    mock = MagicMock(spec=DataSourceConfig)
+    mock.table_name = table_name
+    return mock
+
+
+def _make_mock_engine(
+    registered_tables: dict[str, str] | None = None,
+    schema_fields: list[tuple[str, str]] | None = None,
+    query_result: dict[str, list] | None = None,
+):
+    """Create a mock DataFusionEngine with sensible defaults."""
+    mock = MagicMock()
+    mock.registered_tables = registered_tables or {"sales_data": 
"s3://bucket/sales/"}
+
+    fields = []
+    for name, ftype in schema_fields or [("id", "Int64"), ("amount", 
"Float64")]:
+        field = MagicMock()
+        field.name = name
+        field.type = ftype
+        fields.append(field)
+    mock.session_context.table.return_value.schema.return_value = fields
+
+    mock.execute_query.return_value = (
+        query_result
+        if query_result is not None
+        else {
+            "id": [1, 2],
+            "amount": [10.5, 20.0],
+        }
+    )
+    return mock
+
+
+class TestDataFusionToolsetInit:
+    def test_id_includes_sorted_table_names(self):
+        cfg_a = _make_mock_datasource_config("alpha")
+        cfg_b = _make_mock_datasource_config("beta")
+        ts = DataFusionToolset([cfg_b, cfg_a])
+        assert ts.id == "sql-datafusion-alpha-beta"
+
+    def test_single_table_id(self):
+        cfg = _make_mock_datasource_config("orders")
+        ts = DataFusionToolset([cfg])
+        assert ts.id == "sql-datafusion-orders"
+
+
+class TestDataFusionToolsetGetTools:
+    def test_returns_three_tools(self):
+        cfg = _make_mock_datasource_config()
+        ts = DataFusionToolset([cfg])
+        tools = asyncio.run(ts.get_tools(ctx=MagicMock()))
+        assert set(tools.keys()) == {"list_tables", "get_schema", "query"}
+
+    def test_tool_definitions_have_descriptions(self):
+        cfg = _make_mock_datasource_config()
+        ts = DataFusionToolset([cfg])
+        tools = asyncio.run(ts.get_tools(ctx=MagicMock()))
+        for tool in tools.values():
+            assert tool.tool_def.description
+
+
+class TestDataFusionToolsetListTables:
+    def test_returns_registered_tables(self):
+        cfg = _make_mock_datasource_config()
+        ts = DataFusionToolset([cfg])
+        ts._engine = _make_mock_engine(
+            registered_tables={"sales": "s3://bucket/sales/", "orders": 
"s3://bucket/orders/"}
+        )
+
+        result = asyncio.run(ts.call_tool("list_tables", {}, ctx=MagicMock(), 
tool=MagicMock()))
+        tables = json.loads(result)
+        assert set(tables) == {"sales", "orders"}
+
+
+class TestDataFusionToolsetGetSchema:
+    def test_returns_column_info(self):
+        cfg = _make_mock_datasource_config()
+        ts = DataFusionToolset([cfg])
+        ts._engine = _make_mock_engine(
+            schema_fields=[("id", "Int64"), ("amount", "Float64"), ("name", 
"Utf8")]
+        )
+
+        result = asyncio.run(
+            ts.call_tool("get_schema", {"table_name": "sales_data"}, 
ctx=MagicMock(), tool=MagicMock())
+        )
+        columns = json.loads(result)
+        assert columns == [
+            {"name": "id", "type": "Int64"},
+            {"name": "amount", "type": "Float64"},
+            {"name": "name", "type": "Utf8"},
+        ]
+
+
+class TestDataFusionToolsetQuery:
+    def test_returns_rows_as_json(self):
+        cfg = _make_mock_datasource_config()
+        ts = DataFusionToolset([cfg])
+        ts._engine = _make_mock_engine(query_result={"id": [1, 2], "amount": 
[10.5, 20.0]})
+
+        result = asyncio.run(
+            ts.call_tool(
+                "query", {"sql": "SELECT id, amount FROM sales_data"}, 
ctx=MagicMock(), tool=MagicMock()
+            )
+        )
+        data = json.loads(result)
+        assert data["rows"] == [{"id": 1, "amount": 10.5}, {"id": 2, "amount": 
20.0}]
+        assert data["count"] == 2
+
+    def test_truncates_at_max_rows(self):
+        cfg = _make_mock_datasource_config()
+        ts = DataFusionToolset([cfg], max_rows=1)
+        ts._engine = _make_mock_engine(query_result={"id": [1, 2, 3], "name": 
["a", "b", "c"]})
+
+        result = asyncio.run(
+            ts.call_tool(
+                "query", {"sql": "SELECT id, name FROM sales_data"}, 
ctx=MagicMock(), tool=MagicMock()
+            )
+        )
+        data = json.loads(result)
+        assert len(data["rows"]) == 1
+        assert data["truncated"] is True
+        assert data["count"] == 3
+
+    def test_handles_empty_result(self):
+        cfg = _make_mock_datasource_config()
+        ts = DataFusionToolset([cfg])
+        ts._engine = _make_mock_engine(query_result={})
+
+        result = asyncio.run(
+            ts.call_tool(
+                "query", {"sql": "SELECT * FROM sales_data WHERE 1=0"}, 
ctx=MagicMock(), tool=MagicMock()
+            )
+        )
+        data = json.loads(result)
+        assert data["rows"] == []
+        assert data["count"] == 0
+
+    def test_unknown_tool_raises(self):
+        cfg = _make_mock_datasource_config()
+        ts = DataFusionToolset([cfg])
+
+        with pytest.raises(ValueError, match="Unknown tool"):
+            asyncio.run(ts.call_tool("bad_tool", {}, ctx=MagicMock(), 
tool=MagicMock()))
+
+
+class TestDataFusionToolsetEngineResolution:
+    @patch("airflow.providers.common.ai.toolsets.datafusion.DataFusionEngine", 
autospec=True)
+    def test_lazy_creates_engine(self, MockEngine):
+        mock_engine_instance = MagicMock()
+        MockEngine.return_value = mock_engine_instance
+
+        cfg = _make_mock_datasource_config("my_table")
+        ts = DataFusionToolset([cfg])
+        engine = ts._get_engine()
+
+        assert engine is mock_engine_instance
+        MockEngine.assert_called_once()
+        mock_engine_instance.register_datasource.assert_called_once_with(cfg)
+
+    @patch("airflow.providers.common.ai.toolsets.datafusion.DataFusionEngine", 
autospec=True)
+    def test_registers_multiple_datasources(self, MockEngine):
+        mock_engine_instance = MagicMock()
+        MockEngine.return_value = mock_engine_instance
+
+        cfg_a = _make_mock_datasource_config("table_a")
+        cfg_b = _make_mock_datasource_config("table_b")
+        ts = DataFusionToolset([cfg_a, cfg_b])
+        ts._get_engine()
+
+        assert mock_engine_instance.register_datasource.call_count == 2
+
+    @patch("airflow.providers.common.ai.toolsets.datafusion.DataFusionEngine", 
autospec=True)
+    def test_caches_engine_after_first_creation(self, MockEngine):
+        MockEngine.return_value = MagicMock()
+
+        cfg = _make_mock_datasource_config()
+        ts = DataFusionToolset([cfg])
+        ts._get_engine()
+        ts._get_engine()
+
+        MockEngine.assert_called_once()

Review Comment:
   Missing error-path tests — these are the most common failure modes for 
LLM-driven tool calls:
   
   1. `get_schema` with a non-existent table name (engine raises)
   2. `query` with invalid SQL (engine raises `QueryExecutionException`)
   
   Once error handling is added to the toolset methods, tests should verify 
they return structured `{"error": "..."}` JSON instead of propagating 
exceptions.



##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py:
##########
@@ -0,0 +1,172 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Curated SQL toolset wrapping DataFusionEngine for agentic object-store 
workflows."""
+
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any
+
+try:
+    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_core import SchemaValidator, core_schema
+
+if TYPE_CHECKING:
+    from pydantic_ai._run_context import RunContext
+
+    from airflow.providers.common.sql.config import DataSourceConfig
+
+_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())
+
+# JSON Schemas for the three DataFusion tools.
+_LIST_TABLES_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {},
+}
+
+_GET_SCHEMA_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "table_name": {"type": "string", "description": "Name of the table to 
inspect."},
+    },
+    "required": ["table_name"],
+}
+
+_QUERY_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "sql": {"type": "string", "description": "SQL query to execute."},
+    },
+    "required": ["sql"],
+}
+
+
+class DataFusionToolset(AbstractToolset[Any]):
+    """
+    Curated toolset that gives an LLM agent SQL access to object-storage data 
vi Apache DataFusion.
+
+    Provides three tools — ``list_tables``, ``get_schema``, and ``query`` —
+    backed by
+    :class:`~airflow.providers.common.sql.datafusion.engine.DataFusionEngine`.
+
+    Each :class:`~airflow.providers.common.sql.config.DataSourceConfig` entry
+    registers a table backed by Parquet, CSV, Avro, or Iceberg data on S3 or
+    local storage. Multiple configs can be registered so that SQL queries can
+    join across tables.
+
+    Requires the ``datafusion`` extra of 
``apache-airflow-providers-common-sql``.
+
+    :param datasource_configs: One or more DataFusion data-source 
configurations.
+    :param max_rows: Maximum number of rows returned from the ``query`` tool.
+        Default ``50``.
+    """
+
+    def __init__(
+        self,
+        datasource_configs: list[DataSourceConfig],
+        *,
+        max_rows: int = 50,
+    ) -> None:
+        self._datasource_configs = datasource_configs
+        self._max_rows = max_rows
+        self._engine: DataFusionEngine | None = None
+
+    @property
+    def id(self) -> str:
+        table_names = sorted(c.table_name for c in self._datasource_configs)
+        return f"sql-datafusion-{'-'.join(table_names)}"
+
+    def _get_engine(self) -> DataFusionEngine:
+        """Lazily create and configure a DataFusionEngine from 
*datasource_configs*."""
+        if self._engine is None:
+            engine = DataFusionEngine()
+            for config in self._datasource_configs:
+                engine.register_datasource(config)
+            self._engine = engine
+        return self._engine
+
+    async def get_tools(self, ctx: RunContext[Any]) -> dict[str, 
ToolsetTool[Any]]:
+        tools: dict[str, ToolsetTool[Any]] = {}
+
+        for name, description, schema in (
+            ("list_tables", "List available table names.", 
_LIST_TABLES_SCHEMA),
+            ("get_schema", "Get column names and types for a table.", 
_GET_SCHEMA_SCHEMA),
+            ("query", "Execute a SQL query and return rows as JSON.", 
_QUERY_SCHEMA),
+        ):
+            tool_def = ToolDefinition(
+                name=name,
+                description=description,
+                parameters_json_schema=schema,
+                sequential=True,
+            )
+            tools[name] = ToolsetTool(
+                toolset=self,
+                tool_def=tool_def,
+                max_retries=1,
+                args_validator=_PASSTHROUGH_VALIDATOR,
+            )
+        return tools
+
+    async def call_tool(
+        self,
+        name: str,
+        tool_args: dict[str, Any],
+        ctx: RunContext[Any],
+        tool: ToolsetTool[Any],
+    ) -> Any:
+        if name == "list_tables":
+            return self._list_tables()
+        if name == "get_schema":
+            return self._get_schema(tool_args["table_name"])
+        if name == "query":
+            return self._query(tool_args["sql"])
+        raise ValueError(f"Unknown tool: {name!r}")
+
+    def _list_tables(self) -> str:
+        engine = self._get_engine()
+        tables: list[str] = list(engine.registered_tables.keys())
+        return json.dumps(tables)
+

Review Comment:
   Same error-handling gap in `_query` (the `query` tool) — if `execute_query` 
raises `QueryExecutionException`, the LLM gets a raw traceback. Wrapping the call 
in try/except and returning `{"error": "..."}` would let the agent retry with 
corrected SQL.
   
   Separately: `SQLToolset` has `allow_writes=False` and calls 
`_validate_sql()` to reject INSERT/UPDATE/DELETE/DROP before execution. 
`DataFusionToolset` has no equivalent. DataFusion on object stores is mostly 
read-only by nature, but it does support `CREATE TABLE`, `CREATE VIEW`, and 
`INSERT INTO` for in-memory tables. Worth either adding the same guard (e.g. an 
`allow_writes` flag defaulting to `False`) or documenting explicitly that DDL and 
DML statements are allowed.



##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py:
##########
@@ -0,0 +1,172 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Curated SQL toolset wrapping DataFusionEngine for agentic object-store 
workflows."""
+
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any
+
+try:
+    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_core import SchemaValidator, core_schema
+
+if TYPE_CHECKING:
+    from pydantic_ai._run_context import RunContext
+
+    from airflow.providers.common.sql.config import DataSourceConfig
+
+_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())
+
+# JSON Schemas for the three DataFusion tools.
+_LIST_TABLES_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {},
+}
+
+_GET_SCHEMA_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "table_name": {"type": "string", "description": "Name of the table to 
inspect."},
+    },
+    "required": ["table_name"],
+}
+
+_QUERY_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "sql": {"type": "string", "description": "SQL query to execute."},
+    },
+    "required": ["sql"],
+}
+
+
+class DataFusionToolset(AbstractToolset[Any]):
+    """
+    Curated toolset that gives an LLM agent SQL access to object-storage data 
vi Apache DataFusion.
+
+    Provides three tools — ``list_tables``, ``get_schema``, and ``query`` —
+    backed by
+    :class:`~airflow.providers.common.sql.datafusion.engine.DataFusionEngine`.
+
+    Each :class:`~airflow.providers.common.sql.config.DataSourceConfig` entry
+    registers a table backed by Parquet, CSV, Avro, or Iceberg data on S3 or
+    local storage. Multiple configs can be registered so that SQL queries can
+    join across tables.
+
+    Requires the ``datafusion`` extra of 
``apache-airflow-providers-common-sql``.
+
+    :param datasource_configs: One or more DataFusion data-source 
configurations.
+    :param max_rows: Maximum number of rows returned from the ``query`` tool.
+        Default ``50``.
+    """
+
+    def __init__(
+        self,
+        datasource_configs: list[DataSourceConfig],
+        *,
+        max_rows: int = 50,
+    ) -> None:
+        self._datasource_configs = datasource_configs
+        self._max_rows = max_rows
+        self._engine: DataFusionEngine | None = None
+
+    @property
+    def id(self) -> str:

Review Comment:
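   Minor: if a table name contains hyphens (e.g. `sales-2024`), the ID becomes 
ambiguous — `["a-b", "c"]` and `["a", "b-c"]` both produce 
`sql-datafusion-a-b-c`. Probably fine in practice, but a separator that is 
unlikely to appear in table names (like `:`) would avoid it. `_` would not help: 
underscores are common in table names (e.g. `sales_data`), so `["a_b", "c"]` and 
`["a", "b_c"]` would collide the same way.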



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to