This is an automated email from the ASF dual-hosted git repository.

gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 01f62b990a8 AIP-99: Add DataFusionToolset (#62850)
01f62b990a8 is described below

commit 01f62b990a8b0569fd1ddf7044089d27bf72c555
Author: GPK <[email protected]>
AuthorDate: Thu Mar 5 06:28:31 2026 +0000

    AIP-99: Add DataFusionToolset (#62850)
    
    * Add objectstorage support to SQLToolset via DataFusion
    
    * Add DataFusionToolset
    
    * Update tests
    
    * Resolve comments
    
    * Resolve comments
    
    * Resolve comments
---
 providers/common/ai/docs/toolsets.rst              |  70 ++++-
 .../providers/common/ai/toolsets/datafusion.py     | 207 +++++++++++++
 .../unit/common/ai/toolsets/test_datafusion.py     | 344 +++++++++++++++++++++
 3 files changed, 620 insertions(+), 1 deletion(-)

diff --git a/providers/common/ai/docs/toolsets.rst 
b/providers/common/ai/docs/toolsets.rst
index 7334a5ae0a4..b9c3feb7477 100644
--- a/providers/common/ai/docs/toolsets.rst
+++ b/providers/common/ai/docs/toolsets.rst
@@ -24,7 +24,7 @@ Airflow's 350+ provider hooks already have typed methods, 
rich docstrings,
 and managed credentials. Toolsets expose them as pydantic-ai tools so that
 LLM agents can call them during multi-turn reasoning.
 
-Two toolsets are included:
+Three toolsets are included:
 
 - :class:`~airflow.providers.common.ai.toolsets.hook.HookToolset` — generic
   adapter for any Airflow Hook.
@@ -121,6 +121,69 @@ Parameters
   Default ``False`` — only SELECT-family statements are permitted.
 - ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``.
 
+``DataFusionToolset``
+---------------------
+
+Curated toolset wrapping
+:class:`~airflow.providers.common.sql.datafusion.engine.DataFusionEngine`
+with three tools — ``list_tables``, ``get_schema``, and ``query`` — for
+querying data on object storage (e.g. S3), the local filesystem, or Iceberg
+tables via Apache DataFusion.
+
+.. list-table::
+   :header-rows: 1
+   :widths: 20 50
+
+   * - Tool
+     - Description
+   * - ``list_tables``
+     - Lists registered table names
+   * - ``get_schema``
+     - Returns column names and types for a table (Arrow schema)
+   * - ``query``
+     - Executes a SQL query and returns rows as JSON
+
+Each :class:`~airflow.providers.common.sql.config.DataSourceConfig` entry
+registers a table backed by Parquet, CSV, Avro, or Iceberg data. Multiple
+configs can be registered so that SQL queries can join across tables.
+
+.. code-block:: python
+
+    from airflow.providers.common.ai.toolsets.datafusion import 
DataFusionToolset
+    from airflow.providers.common.sql.config import DataSourceConfig
+
+    toolset = DataFusionToolset(
+        datasource_configs=[
+            DataSourceConfig(
+                conn_id="aws_default",
+                table_name="sales",
+                uri="s3://my-bucket/data/sales/",
+                format="parquet",
+            ),
+            DataSourceConfig(
+                conn_id="aws_default",
+                table_name="returns",
+                uri="s3://my-bucket/data/returns/",
+                format="csv",
+            ),
+        ],
+        max_rows=100,
+    )
+
+The ``DataFusionEngine`` is created lazily on the first tool call. This
+toolset requires the ``datafusion`` extra of
+``apache-airflow-providers-common-sql``.
+
+Parameters
+^^^^^^^^^^
+
+- ``datasource_configs``: One or more
+  :class:`~airflow.providers.common.sql.config.DataSourceConfig` entries.
+  Requires ``apache-airflow-providers-common-sql[datafusion]``.
+- ``allow_writes``: Allow data-modifying SQL (CREATE TABLE, CREATE VIEW,
+  INSERT INTO, etc.). Default ``False`` — only SELECT-family statements are
+  permitted. Tables registered from object storage are effectively read-only,
+  but DataFusion still accepts DDL and DML against in-memory tables; this guard
+  rejects such statements by default.
+- ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``.
 
 Security
 --------
@@ -155,6 +218,11 @@ No single layer is sufficient — they work together.
        ``validate_sql()`` and rejects INSERT, UPDATE, DELETE, DROP, etc.
      - Does not prevent the agent from reading sensitive data that the
        database user has SELECT access to.
+   * - **DataFusionToolset: read-only by default**
+     - ``allow_writes=False`` (default) validates every SQL query through
+       ``validate_sql()`` and rejects CREATE TABLE, CREATE VIEW, INSERT
+       INTO, and other non-SELECT statements.
+     - Does not prevent the agent from reading any registered data source.
    * - **SQLToolset: allowed_tables**
      - Restricts which tables appear in ``list_tables`` and ``get_schema``
        responses, limiting the agent's knowledge of the schema.
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py 
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py
new file mode 100644
index 00000000000..7c3de86241e
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py
@@ -0,0 +1,207 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Curated SQL toolset wrapping DataFusionEngine for agentic object-store 
workflows."""
+
+from __future__ import annotations
+
+import json
+import logging
+from typing import TYPE_CHECKING, Any
+
+try:
+    from airflow.providers.common.ai.utils.sql_validation import 
SQLSafetyError, validate_sql as _validate_sql
+    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
+    from airflow.providers.common.sql.datafusion.exceptions import 
QueryExecutionException
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_core import SchemaValidator, core_schema
+
+if TYPE_CHECKING:
+    from pydantic_ai._run_context import RunContext
+
+    from airflow.providers.common.sql.config import DataSourceConfig
+
log = logging.getLogger(__name__)

# Tool arguments are accepted as-is; each handler in DataFusionToolset reads the keys it needs.
_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())


def _object_schema(properties: dict[str, Any], required: tuple[str, ...] = ()) -> dict[str, Any]:
    """
    Build a JSON Schema ``object`` definition.

    :param properties: Mapping of property name to its JSON Schema.
    :param required: Names of mandatory properties; the ``required`` key is omitted when empty.
    :return: A fresh schema dict.
    """
    schema: dict[str, Any] = {"type": "object", "properties": properties}
    if required:
        schema["required"] = list(required)
    return schema


# JSON Schemas advertised to the model for the three DataFusion tools.
_LIST_TABLES_SCHEMA: dict[str, Any] = _object_schema({})

_GET_SCHEMA_SCHEMA: dict[str, Any] = _object_schema(
    {"table_name": {"type": "string", "description": "Name of the table to inspect."}},
    required=("table_name",),
)

_QUERY_SCHEMA: dict[str, Any] = _object_schema(
    {"sql": {"type": "string", "description": "SQL query to execute."}},
    required=("sql",),
)
+
+
class DataFusionToolset(AbstractToolset[Any]):
    """
    Curated toolset that gives an LLM agent SQL access to object-storage data via Apache DataFusion.

    Provides three tools — ``list_tables``, ``get_schema``, and ``query`` —
    backed by
    :class:`~airflow.providers.common.sql.datafusion.engine.DataFusionEngine`.

    Each :class:`~airflow.providers.common.sql.config.DataSourceConfig` entry
    registers a table backed by Parquet, CSV, Avro, or Iceberg data on S3 or
    local storage. Multiple configs can be registered so that SQL queries can
    join across tables. The engine is created lazily on the first tool call.

    Recoverable failures are returned to the model as ``{"error": ...}`` JSON so
    it can correct itself. These are engine/catalog errors in ``list_tables`` and
    ``get_schema``, unknown tables, and ``QueryExecutionException`` from
    ``query``. SQL-safety violations and other unexpected errors in ``query``
    are raised.

    Requires the ``datafusion`` extra of ``apache-airflow-providers-common-sql``.

    :param datasource_configs: One or more DataFusion data-source configurations.
    :param allow_writes: Allow data-modifying SQL (CREATE TABLE, CREATE VIEW,
        INSERT INTO, etc.). Default ``False`` — only SELECT-family statements
        are permitted.
    :param max_rows: Maximum number of rows returned from the ``query`` tool.
        Must be at least ``1``. Default ``50``.
    :raises ValueError: If *datasource_configs* is empty or *max_rows* is less than 1.
    """

    def __init__(
        self,
        datasource_configs: list[DataSourceConfig],
        *,
        allow_writes: bool = False,
        max_rows: int = 50,
    ) -> None:
        if not datasource_configs:
            raise ValueError("datasource_configs must contain at least one DataSourceConfig")
        if max_rows < 1:
            # Zero/negative limits would return no rows while always flagging the result as truncated.
            raise ValueError(f"max_rows must be at least 1, got {max_rows!r}")
        self._datasource_configs = datasource_configs
        self._allow_writes = allow_writes
        self._max_rows = max_rows
        # Built on first use by _get_engine(); constructing the toolset touches no data source.
        self._engine: DataFusionEngine | None = None

    @property
    def id(self) -> str:
        """Toolset id derived from the registered table names, in order, with ``-`` replaced by ``_``."""
        suffix = "_".join(config.table_name.replace("-", "_") for config in self._datasource_configs)
        return f"sql_datafusion_{suffix}"

    def _get_engine(self) -> DataFusionEngine:
        """Return the cached engine, creating it and registering every *datasource_configs* entry on first call."""
        if self._engine is None:
            engine = DataFusionEngine()
            for config in self._datasource_configs:
                engine.register_datasource(config)
            # Cache only after every registration succeeded, so a failure is retried on the next call.
            self._engine = engine
        return self._engine

    async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]:
        """
        Return the ``list_tables``, ``get_schema`` and ``query`` tool definitions.

        Arguments are not validated on the pydantic side (passthrough validator).
        Each tool is marked ``sequential`` and allows one retry.
        """
        tools: dict[str, ToolsetTool[Any]] = {}

        for name, description, schema in (
            ("list_tables", "List available table names.", _LIST_TABLES_SCHEMA),
            ("get_schema", "Get column names and types for a table.", _GET_SCHEMA_SCHEMA),
            ("query", "Execute a SQL query and return rows as JSON.", _QUERY_SCHEMA),
        ):
            tool_def = ToolDefinition(
                name=name,
                description=description,
                parameters_json_schema=schema,
                sequential=True,
            )
            tools[name] = ToolsetTool(
                toolset=self,
                tool_def=tool_def,
                max_retries=1,
                args_validator=_PASSTHROUGH_VALIDATOR,
            )
        return tools

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[Any],
        tool: ToolsetTool[Any],
    ) -> Any:
        """
        Dispatch a tool call to its handler.

        :raises KeyError: If a required argument is missing from *tool_args*.
        :raises ValueError: If *name* is not one of this toolset's tools.
        """
        if name == "list_tables":
            return self._list_tables()
        if name == "get_schema":
            return self._get_schema(tool_args["table_name"])
        if name == "query":
            return self._query(tool_args["sql"])
        raise ValueError(f"Unknown tool: {name!r}")

    def _list_tables(self) -> str:
        """Return a JSON array of table names (sorted), or a JSON error object on failure."""
        try:
            engine = self._get_engine()
            # Read the session catalog (not engine.registered_tables) so tables created by the
            # agent when allow_writes is enabled are listed too. Sorting keeps the output
            # deterministic across calls.
            tables: list[str] = sorted(engine.session_context.catalog().schema().table_names())
        except Exception as ex:
            log.warning("list_tables failed: %s", ex)
            return json.dumps({"error": str(ex)})
        return json.dumps(tables)

    def _get_schema(self, table_name: str) -> str:
        """
        Return a JSON array of ``{"name", "type"}`` column descriptors for *table_name*.

        An unknown table, or any engine/catalog failure, yields a JSON error
        object instead. This matches ``list_tables``, so the model can recover
        rather than abort the run.
        """
        try:
            engine = self._get_engine()
            # session_context lookup is required here instead of engine.registered_tables,
            # because registered_tables only tracks tables registered via datasource config.
            # When allow_writes is enabled, the agent may create temporary in-memory tables
            # that would not be captured there.
            if not engine.session_context.table_exist(table_name):
                return json.dumps({"error": f"Table {table_name!r} is not available"})
            # Intentionally using session_context instead of engine.get_schema() —
            # the latter returns a pre-formatted string intended for other operators,
            # not a JSON-compatible format.
            # TODO: refactor engine.get_schema() to return JSON and update this accordingly
            table = engine.session_context.table(table_name)
            columns = [{"name": f.name, "type": str(f.type)} for f in table.schema()]
        except Exception as ex:
            log.warning("get_schema failed for table %r: %s", table_name, ex)
            return json.dumps({"error": str(ex)})
        return json.dumps(columns)

    def _query(self, sql: str) -> str:
        """
        Execute *sql* and return up to ``max_rows`` rows as a JSON object.

        The payload is ``{"rows": [...], "count": <total rows>}``. When the
        result was cut off, ``"truncated": true`` and ``"max_rows"`` are added.
        Values that are not JSON-native (e.g. dates, decimals) are stringified.
        ``QueryExecutionException`` is reported as ``{"error", "query"}`` JSON.

        :raises SQLSafetyError: If writes are disabled and *sql* is not a
            SELECT-family statement.
        """
        if not self._allow_writes:
            try:
                _validate_sql(sql)
            except SQLSafetyError as ex:
                log.warning("query failed SQL safety validation: %s", ex)
                raise

        try:
            pydict = self._get_engine().execute_query(sql)
        except QueryExecutionException as ex:
            # Hand execution errors back to the model so it can fix its query.
            return json.dumps({"error": str(ex), "query": sql})

        # The result is column-oriented ({column: [values...]}); all columns share one length.
        col_names = list(pydict.keys())
        num_rows = len(next(iter(pydict.values()), []))
        shown = min(num_rows, self._max_rows)

        result: list[dict[str, Any]] = [{col: pydict[col][i] for col in col_names} for i in range(shown)]

        output: dict[str, Any] = {"rows": result, "count": num_rows}
        if num_rows > self._max_rows:
            output["truncated"] = True
            output["max_rows"] = self._max_rows
        return json.dumps(output, default=str)
diff --git 
a/providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py 
b/providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py
new file mode 100644
index 00000000000..77bc0cc80ea
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py
@@ -0,0 +1,344 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+from pydantic_ai._run_context import RunContext
+from pydantic_ai.toolsets.abstract import ToolsetTool
+
+from airflow.providers.common.ai.toolsets.datafusion import DataFusionToolset
+from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
+from airflow.providers.common.sql.config import DataSourceConfig
+
+
def _make_mock_datasource_config(table_name: str = "sales_data"):
    """Return a ``DataSourceConfig``-specced mock whose only configured attribute is *table_name*."""
    config_mock = MagicMock(spec=DataSourceConfig)
    config_mock.table_name = table_name
    return config_mock
+
+
+def _make_mock_engine(
+    registered_tables: dict[str, str] | None = None,
+    schema_fields: list[tuple[str, str]] | None = None,
+    query_result: dict[str, list] | None = None,
+):
+    """Create a mock DataFusionEngine with sensible defaults."""
+    mock = MagicMock()
+    tables = registered_tables or {"sales_data": "s3://bucket/sales/"}
+    mock.registered_tables = tables
+    mock.session_context.catalog().schema().table_names.return_value = 
list(tables.keys())
+    mock.session_context.table_exist.side_effect = lambda name: name in tables
+
+    fields = schema_fields or [("id", "Int64"), ("amount", "Float64")]
+    arrow_fields = []
+    for name, ftype in fields:
+        field = MagicMock()
+        field.name = name
+        field.type = ftype
+        arrow_fields.append(field)
+    for tname in tables:
+        mock.session_context.table(tname).schema.return_value = arrow_fields
+
+    mock.execute_query.return_value = (
+        query_result
+        if query_result is not None
+        else {
+            "id": [1, 2],
+            "amount": [10.5, 20.0],
+        }
+    )
+    return mock
+
+
class TestDataFusionToolsetInit:
    def test_id_includes_table_names(self):
        cfg_a = _make_mock_datasource_config("alpha")
        cfg_b = _make_mock_datasource_config("beta")
        ts = DataFusionToolset([cfg_b, cfg_a])
        # The id follows the order in which configs were passed, not alphabetical order.
        assert ts.id == "sql_datafusion_beta_alpha"

    def test_single_table_id(self):
        cfg = _make_mock_datasource_config("orders")
        ts = DataFusionToolset([cfg])
        assert ts.id == "sql_datafusion_orders"

    def test_id_replaces_hyphens_in_table_names(self):
        ts = DataFusionToolset([_make_mock_datasource_config("daily-sales")])
        assert ts.id == "sql_datafusion_daily_sales"

    def test_does_not_create_engine_on_init(self):
        ts = DataFusionToolset([_make_mock_datasource_config()])
        assert ts._engine is None

    def test_requires_datasource_configs(self):
        with pytest.raises(ValueError, match="datasource_configs must contain at least one DataSourceConfig"):
            DataFusionToolset([])
+
+
class TestDataFusionToolsetGetTools:
    @staticmethod
    def _tools():
        """Build a single-table toolset and return its tool definitions."""
        ts = DataFusionToolset([_make_mock_datasource_config()])
        return asyncio.run(ts.get_tools(ctx=MagicMock(spec=RunContext)))

    def test_returns_three_tools(self):
        assert set(self._tools().keys()) == {"list_tables", "get_schema", "query"}

    def test_tool_definitions_have_descriptions(self):
        for tool in self._tools().values():
            assert tool.tool_def.description

    def test_tool_schemas_declare_required_arguments(self):
        required = {
            name: tool.tool_def.parameters_json_schema.get("required", [])
            for name, tool in self._tools().items()
        }
        assert required == {"list_tables": [], "get_schema": ["table_name"], "query": ["sql"]}

    def test_tools_are_sequential(self):
        assert all(tool.tool_def.sequential for tool in self._tools().values())
+
+
class TestDataFusionToolsetListTables:
    def test_returns_registered_tables(self):
        cfg = _make_mock_datasource_config()
        ts = DataFusionToolset([cfg])
        ts._engine = _make_mock_engine(
            registered_tables={"sales": "s3://bucket/sales/", "orders": "s3://bucket/orders/"}
        )

        result = asyncio.run(
            ts.call_tool("list_tables", {}, ctx=MagicMock(spec=RunContext), tool=MagicMock(spec=ToolsetTool))
        )
        tables = json.loads(result)
        assert set(tables) == {"sales", "orders"}

    def test_catalog_error_returns_error_json(self):
        ts = DataFusionToolset([_make_mock_datasource_config()])
        engine = _make_mock_engine()
        # Failures while reading the catalog must be reported to the model, not raised.
        engine.session_context.catalog.side_effect = RuntimeError("catalog unavailable")
        ts._engine = engine

        result = asyncio.run(
            ts.call_tool("list_tables", {}, ctx=MagicMock(spec=RunContext), tool=MagicMock(spec=ToolsetTool))
        )
        assert json.loads(result) == {"error": "catalog unavailable"}
+
+
class TestDataFusionToolsetGetSchema:
    def test_returns_column_info(self):
        expected_fields = [("id", "Int64"), ("amount", "Float64"), ("name", "Utf8")]
        toolset = DataFusionToolset([_make_mock_datasource_config()])
        toolset._engine = _make_mock_engine(schema_fields=expected_fields)

        raw = asyncio.run(
            toolset.call_tool(
                "get_schema",
                {"table_name": "sales_data"},
                ctx=MagicMock(spec=RunContext),
                tool=MagicMock(spec=ToolsetTool),
            )
        )

        # Each Arrow field is reported as a {"name", "type"} pair, preserving column order.
        assert json.loads(raw) == [{"name": column, "type": arrow_type} for column, arrow_type in expected_fields]
+
+
class TestDataFusionToolsetQuery:
    @staticmethod
    def _run(ts: DataFusionToolset, tool_name: str, args: dict):
        """Invoke *tool_name* on *ts* synchronously, using throwaway context and tool mocks."""
        return asyncio.run(
            ts.call_tool(tool_name, args, ctx=MagicMock(spec=RunContext), tool=MagicMock(spec=ToolsetTool))
        )

    def test_returns_rows_as_json(self):
        ts = DataFusionToolset([_make_mock_datasource_config()])
        engine = _make_mock_engine(query_result={"id": [1, 2], "amount": [10.5, 20.0]})
        ts._engine = engine

        data = json.loads(self._run(ts, "query", {"sql": "SELECT id, amount FROM sales_data"}))

        # Exact match also proves no truncation metadata is added for small results.
        assert data == {"rows": [{"id": 1, "amount": 10.5}, {"id": 2, "amount": 20.0}], "count": 2}
        engine.execute_query.assert_called_once_with("SELECT id, amount FROM sales_data")

    def test_truncates_at_max_rows(self):
        ts = DataFusionToolset([_make_mock_datasource_config()], max_rows=1)
        ts._engine = _make_mock_engine(query_result={"id": [1, 2, 3], "name": ["a", "b", "c"]})

        data = json.loads(self._run(ts, "query", {"sql": "SELECT id, name FROM sales_data"}))

        # "count" reports the full result size; "rows" holds only the first max_rows.
        assert data == {"rows": [{"id": 1, "name": "a"}], "count": 3, "truncated": True, "max_rows": 1}

    def test_handles_empty_result(self):
        ts = DataFusionToolset([_make_mock_datasource_config()])
        ts._engine = _make_mock_engine(query_result={})

        data = json.loads(self._run(ts, "query", {"sql": "SELECT * FROM sales_data WHERE 1=0"}))

        assert data == {"rows": [], "count": 0}

    def test_blocks_create_table_by_default(self):
        ts = DataFusionToolset([_make_mock_datasource_config()])
        engine = _make_mock_engine()
        ts._engine = engine

        with pytest.raises(SQLSafetyError, match="Statement type 'Create' is not allowed"):
            self._run(ts, "query", {"sql": "CREATE TABLE new_table (id INT, name TEXT)"})

        # The guard must reject the statement before it ever reaches the engine.
        engine.execute_query.assert_not_called()

    def test_allows_create_table_when_writes_enabled(self):
        ts = DataFusionToolset([_make_mock_datasource_config()], allow_writes=True)
        engine = _make_mock_engine(query_result={})
        ts._engine = engine
        sql = "CREATE TABLE new_table (id INT, name TEXT)"

        data = json.loads(self._run(ts, "query", {"sql": sql}))

        assert data == {"rows": [], "count": 0}
        engine.execute_query.assert_called_once_with(sql)

    def test_unknown_tool_raises(self):
        ts = DataFusionToolset([_make_mock_datasource_config()])

        with pytest.raises(ValueError, match="Unknown tool"):
            self._run(ts, "bad_tool", {})
+
+
class TestDataFusionToolsetGetSchemaErrors:
    def test_unknown_table_returns_error_json(self):
        toolset = DataFusionToolset([_make_mock_datasource_config()])
        toolset._engine = _make_mock_engine(registered_tables={"sales_data": "s3://bucket/sales/"})
        ctx, tool = MagicMock(spec=RunContext), MagicMock(spec=ToolsetTool)

        payload = json.loads(
            asyncio.run(toolset.call_tool("get_schema", {"table_name": "nonexistent"}, ctx=ctx, tool=tool))
        )

        # The error message must name the missing table so the model can correct itself.
        assert "error" in payload
        assert "nonexistent" in payload["error"]

    def test_missing_table_name_arg_raises_key_error(self):
        toolset = DataFusionToolset([_make_mock_datasource_config()])
        toolset._engine = _make_mock_engine()
        pending = toolset.call_tool(
            "get_schema", {}, ctx=MagicMock(spec=RunContext), tool=MagicMock(spec=ToolsetTool)
        )

        with pytest.raises(KeyError):
            asyncio.run(pending)
+
+
class TestDataFusionToolsetQueryErrors:
    def test_query_execution_exception_returns_error_json(self):
        from airflow.providers.common.sql.datafusion.exceptions import QueryExecutionException

        failing_engine = _make_mock_engine()
        failing_engine.execute_query.side_effect = QueryExecutionException(
            "execution failed: column x not found"
        )
        toolset = DataFusionToolset([_make_mock_datasource_config()])
        toolset._engine = failing_engine

        raw = asyncio.run(
            toolset.call_tool(
                "query",
                {"sql": "SELECT x FROM t"},
                ctx=MagicMock(spec=RunContext),
                tool=MagicMock(spec=ToolsetTool),
            )
        )
        payload = json.loads(raw)

        # Execution errors are echoed back together with the offending SQL.
        assert "column x not found" in payload["error"]
        assert payload["query"] == "SELECT x FROM t"

    def test_unexpected_exception_propagates(self):
        failing_engine = _make_mock_engine()
        failing_engine.execute_query.side_effect = TypeError("unexpected error")
        toolset = DataFusionToolset([_make_mock_datasource_config()])
        toolset._engine = failing_engine
        pending = toolset.call_tool(
            "query",
            {"sql": "SELECT 1"},
            ctx=MagicMock(spec=RunContext),
            tool=MagicMock(spec=ToolsetTool),
        )

        # Only QueryExecutionException is converted to JSON; anything else surfaces to the caller.
        with pytest.raises(TypeError, match="unexpected error"):
            asyncio.run(pending)
+
+
class TestDataFusionToolsetEngineResolution:
    @patch("airflow.providers.common.ai.toolsets.datafusion.DataFusionEngine", autospec=True)
    def test_lazy_creates_engine(self, mock_engine_cls):
        mock_engine_instance = MagicMock()
        mock_engine_cls.return_value = mock_engine_instance

        cfg = _make_mock_datasource_config("my_table")
        ts = DataFusionToolset([cfg])
        # Constructing the toolset alone must not build the engine.
        mock_engine_cls.assert_not_called()

        engine = ts._get_engine()

        assert engine is mock_engine_instance
        mock_engine_cls.assert_called_once_with()
        mock_engine_instance.register_datasource.assert_called_once_with(cfg)

    @patch("airflow.providers.common.ai.toolsets.datafusion.DataFusionEngine", autospec=True)
    def test_registers_multiple_datasources(self, mock_engine_cls):
        mock_engine_instance = MagicMock()
        mock_engine_cls.return_value = mock_engine_instance

        cfg_a = _make_mock_datasource_config("table_a")
        cfg_b = _make_mock_datasource_config("table_b")
        DataFusionToolset([cfg_a, cfg_b])._get_engine()

        # Every config is registered exactly once, in the order it was supplied.
        registered = [c.args for c in mock_engine_instance.register_datasource.call_args_list]
        assert registered == [(cfg_a,), (cfg_b,)]

    @patch("airflow.providers.common.ai.toolsets.datafusion.DataFusionEngine", autospec=True)
    def test_caches_engine_after_first_creation(self, mock_engine_cls):
        mock_engine_cls.return_value = MagicMock()

        ts = DataFusionToolset([_make_mock_datasource_config()])
        first = ts._get_engine()
        second = ts._get_engine()

        assert first is second
        mock_engine_cls.assert_called_once()
        first.register_datasource.assert_called_once()

Reply via email to