This is an automated email from the ASF dual-hosted git repository.
gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 01f62b990a8 AIP-99: Add DataFusionToolset (#62850)
01f62b990a8 is described below
commit 01f62b990a8b0569fd1ddf7044089d27bf72c555
Author: GPK <[email protected]>
AuthorDate: Thu Mar 5 06:28:31 2026 +0000
AIP-99: Add DataFusionToolset (#62850)
* Add objectstorage support to SQLToolset via DataFusion
* Add DataFusionToolset
* Update tests
* Resolve comments
* Resolve comments
* Resolve comments
---
providers/common/ai/docs/toolsets.rst | 70 ++++-
.../providers/common/ai/toolsets/datafusion.py | 207 +++++++++++++
.../unit/common/ai/toolsets/test_datafusion.py | 344 +++++++++++++++++++++
3 files changed, 620 insertions(+), 1 deletion(-)
diff --git a/providers/common/ai/docs/toolsets.rst
b/providers/common/ai/docs/toolsets.rst
index 7334a5ae0a4..b9c3feb7477 100644
--- a/providers/common/ai/docs/toolsets.rst
+++ b/providers/common/ai/docs/toolsets.rst
@@ -24,7 +24,7 @@ Airflow's 350+ provider hooks already have typed methods,
rich docstrings,
and managed credentials. Toolsets expose them as pydantic-ai tools so that
LLM agents can call them during multi-turn reasoning.
-Two toolsets are included:
+Three toolsets are included:
- :class:`~airflow.providers.common.ai.toolsets.hook.HookToolset` — generic
adapter for any Airflow Hook.
@@ -121,6 +121,69 @@ Parameters
Default ``False`` — only SELECT-family statements are permitted.
- ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``.
+``DataFusionToolset``
+---------------------
+
+Curated toolset wrapping
+:class:`~airflow.providers.common.sql.datafusion.engine.DataFusionEngine`
+with three tools — ``list_tables``, ``get_schema``, and ``query`` — for
+querying files on object stores (S3), local filesystems, and Iceberg tables
+via Apache DataFusion.
+
+.. list-table::
+ :header-rows: 1
+ :widths: 20 50
+
+ * - Tool
+ - Description
+ * - ``list_tables``
+ - Lists registered table names
+ * - ``get_schema``
+ - Returns column names and types for a table (Arrow schema)
+ * - ``query``
+ - Executes a SQL query and returns rows as JSON
+
+Each :class:`~airflow.providers.common.sql.config.DataSourceConfig` entry
+registers a table backed by Parquet, CSV, Avro, or Iceberg data. Multiple
+configs can be registered so that SQL queries can join across tables.
+
+.. code-block:: python
+
+ from airflow.providers.common.ai.toolsets.datafusion import
DataFusionToolset
+ from airflow.providers.common.sql.config import DataSourceConfig
+
+ toolset = DataFusionToolset(
+ datasource_configs=[
+ DataSourceConfig(
+ conn_id="aws_default",
+ table_name="sales",
+ uri="s3://my-bucket/data/sales/",
+ format="parquet",
+ ),
+ DataSourceConfig(
+ conn_id="aws_default",
+ table_name="returns",
+ uri="s3://my-bucket/data/returns/",
+ format="csv",
+ ),
+ ],
+ max_rows=100,
+ )
+
+The ``DataFusionEngine`` is created lazily on the first tool call. This
+toolset requires the ``datafusion`` extra of
+``apache-airflow-providers-common-sql``.
+
+Parameters
+^^^^^^^^^^
+
+- ``datasource_configs``: One or more
+ :class:`~airflow.providers.common.sql.config.DataSourceConfig` entries.
+ Requires ``apache-airflow-providers-common-sql[datafusion]``.
+- ``allow_writes``: Allow data-modifying SQL (CREATE TABLE, CREATE VIEW,
+ INSERT INTO, etc.). Default ``False`` — only SELECT-family statements are
+ permitted. DataFusion on object stores is mostly read-only, but it does
+ support DDL for in-memory tables; this guard blocks those by default.
+- ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``.
Security
--------
@@ -155,6 +218,11 @@ No single layer is sufficient — they work together.
``validate_sql()`` and rejects INSERT, UPDATE, DELETE, DROP, etc.
- Does not prevent the agent from reading sensitive data that the
database user has SELECT access to.
+ * - **DataFusionToolset: read-only by default**
+ - ``allow_writes=False`` (default) validates every SQL query through
+ ``validate_sql()`` and rejects CREATE TABLE, CREATE VIEW, INSERT
+ INTO, and other non-SELECT statements.
+ - Does not prevent the agent from reading any registered data source.
* - **SQLToolset: allowed_tables**
- Restricts which tables appear in ``list_tables`` and ``get_schema``
responses, limiting the agent's knowledge of the schema.
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py
new file mode 100644
index 00000000000..7c3de86241e
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py
@@ -0,0 +1,207 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Curated SQL toolset wrapping DataFusionEngine for agentic object-store
workflows."""
+
+from __future__ import annotations
+
+import json
+import logging
+from typing import TYPE_CHECKING, Any
+
+try:
+ from airflow.providers.common.ai.utils.sql_validation import
SQLSafetyError, validate_sql as _validate_sql
+ from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
+ from airflow.providers.common.sql.datafusion.exceptions import
QueryExecutionException
+except ImportError as e:
+ from airflow.providers.common.compat.sdk import
AirflowOptionalProviderFeatureException
+
+ raise AirflowOptionalProviderFeatureException(e)
+
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_core import SchemaValidator, core_schema
+
+if TYPE_CHECKING:
+ from pydantic_ai._run_context import RunContext
+
+ from airflow.providers.common.sql.config import DataSourceConfig
+
log = logging.getLogger(__name__)

# Validator that passes tool arguments through unchanged; argument shape is
# constrained by the JSON schemas below rather than by a pydantic model.
_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())

# JSON Schemas for the three DataFusion tools.
# list_tables takes no arguments.
_LIST_TABLES_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {},
}

# get_schema requires a single table name.
_GET_SCHEMA_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "table_name": {"type": "string", "description": "Name of the table to inspect."},
    },
    "required": ["table_name"],
}

# query requires a single SQL string.
_QUERY_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "sql": {"type": "string", "description": "SQL query to execute."},
    },
    "required": ["sql"],
}
+
+
class DataFusionToolset(AbstractToolset[Any]):
    """
    Curated toolset that gives an LLM agent SQL access to object-storage data via Apache DataFusion.

    Provides three tools — ``list_tables``, ``get_schema``, and ``query`` —
    backed by
    :class:`~airflow.providers.common.sql.datafusion.engine.DataFusionEngine`.

    Each :class:`~airflow.providers.common.sql.config.DataSourceConfig` entry
    registers a table backed by Parquet, CSV, Avro, or Iceberg data on S3 or
    local storage. Multiple configs can be registered so that SQL queries can
    join across tables.

    Requires the ``datafusion`` extra of ``apache-airflow-providers-common-sql``.

    :param datasource_configs: One or more DataFusion data-source configurations.
    :param allow_writes: Allow data-modifying SQL (CREATE TABLE, CREATE VIEW,
        INSERT INTO, etc.). Default ``False`` — only SELECT-family statements
        are permitted.
    :param max_rows: Maximum number of rows returned from the ``query`` tool.
        Default ``50``.
    """

    def __init__(
        self,
        datasource_configs: list[DataSourceConfig],
        *,
        allow_writes: bool = False,
        max_rows: int = 50,
    ) -> None:
        if not datasource_configs:
            raise ValueError("datasource_configs must contain at least one DataSourceConfig")
        self._datasource_configs = datasource_configs
        self._allow_writes = allow_writes
        self._max_rows = max_rows
        # Engine creation is deferred to the first tool call; see _get_engine().
        self._engine: DataFusionEngine | None = None

    @property
    def id(self) -> str:
        """Toolset identifier derived from the registered table names."""
        # Hyphens in table names are normalized to underscores so the id is
        # a single snake_case token.
        suffix = "_".join(config.table_name.replace("-", "_") for config in self._datasource_configs)
        return f"sql_datafusion_{suffix}"

    def _get_engine(self) -> DataFusionEngine:
        """Lazily create and configure a DataFusionEngine from *datasource_configs*."""
        if self._engine is None:
            engine = DataFusionEngine()
            for config in self._datasource_configs:
                engine.register_datasource(config)
            # Cache after successful registration so later calls reuse it.
            self._engine = engine
        return self._engine

    async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]:
        """Build the fixed registry of three tools advertised to the agent."""
        tools: dict[str, ToolsetTool[Any]] = {}

        for name, description, schema in (
            ("list_tables", "List available table names.", _LIST_TABLES_SCHEMA),
            ("get_schema", "Get column names and types for a table.", _GET_SCHEMA_SCHEMA),
            ("query", "Execute a SQL query and return rows as JSON.", _QUERY_SCHEMA),
        ):
            tool_def = ToolDefinition(
                name=name,
                description=description,
                parameters_json_schema=schema,
                sequential=True,
            )
            tools[name] = ToolsetTool(
                toolset=self,
                tool_def=tool_def,
                max_retries=1,
                # Argument shape is enforced via parameters_json_schema, so the
                # validator just passes values through.
                args_validator=_PASSTHROUGH_VALIDATOR,
            )
        return tools

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[Any],
        tool: ToolsetTool[Any],
    ) -> Any:
        """Dispatch a tool invocation by *name* to the matching private helper.

        :raises ValueError: If *name* is not one of the three known tools.
        """
        if name == "list_tables":
            return self._list_tables()
        if name == "get_schema":
            # KeyError propagates if "table_name" is missing from tool_args.
            return self._get_schema(tool_args["table_name"])
        if name == "query":
            return self._query(tool_args["sql"])
        raise ValueError(f"Unknown tool: {name!r}")

    def _list_tables(self) -> str:
        """Return a JSON array of table names, or a JSON error object on failure."""
        try:
            engine = self._get_engine()
            tables: list[str] = list(engine.session_context.catalog().schema().table_names())
            return json.dumps(tables)
        except Exception as ex:
            # Errors are reported back to the agent as JSON rather than raised.
            log.warning("list_tables failed: %s", ex)
            return json.dumps({"error": str(ex)})

    def _get_schema(self, table_name: str) -> str:
        """Return a JSON list of ``{"name", "type"}`` column entries for *table_name*."""
        engine = self._get_engine()
        # session_context lookup is required here instead of engine.registered_tables,
        # because registered_tables only tracks tables registered via datasource config.
        # When allow_writes is enabled, the agent may create temporary in-memory tables
        # that would not be captured there.
        if not engine.session_context.table_exist(table_name):
            return json.dumps({"error": f"Table {table_name!r} is not available"})
        # Intentionally using session_context instead of engine.get_schema() —
        # the latter returns a pre-formatted string intended for other operators,
        # not a JSON-compatible format.
        # TODO: refactor engine.get_schema() to return JSON and update this accordingly
        table = engine.session_context.table(table_name)
        columns = [{"name": f.name, "type": str(f.type)} for f in table.schema()]
        return json.dumps(columns)

    def _query(self, sql: str) -> str:
        """Execute *sql* and return rows as JSON, truncated to ``max_rows``.

        :raises SQLSafetyError: If ``allow_writes`` is off and *sql* is not a
            SELECT-family statement.
        """
        try:
            if not self._allow_writes:
                # Raises SQLSafetyError for non-SELECT statements.
                _validate_sql(sql)

            engine = self._get_engine()
            pydict = engine.execute_query(sql)
            col_names = list(pydict.keys())
            # Row count is taken from the first column (0 for an empty result);
            # assumes all columns have equal length.
            num_rows = len(next(iter(pydict.values()), []))

            # Pivot the column-oriented dict into row dicts, capped at max_rows.
            result: list[dict[str, Any]] = [
                {col: pydict[col][i] for col in col_names} for i in range(min(num_rows, self._max_rows))
            ]

            truncated = num_rows > self._max_rows
            # "count" always reflects the full pre-truncation row count.
            output: dict[str, Any] = {"rows": result, "count": num_rows}
            if truncated:
                output["truncated"] = True
                output["max_rows"] = self._max_rows
            # default=str keeps non-JSON-native values (dates, decimals, ...) serializable.
            return json.dumps(output, default=str)
        except SQLSafetyError as ex:
            # Safety violations are re-raised so the framework surfaces them.
            log.warning("query failed SQL safety validation: %s", ex)
            raise
        except QueryExecutionException as ex:
            # Execution failures go back to the agent so it can adjust the query.
            return json.dumps({"error": str(ex), "query": sql})
diff --git
a/providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py
b/providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py
new file mode 100644
index 00000000000..77bc0cc80ea
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py
@@ -0,0 +1,344 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+from pydantic_ai._run_context import RunContext
+from pydantic_ai.toolsets.abstract import ToolsetTool
+
+from airflow.providers.common.ai.toolsets.datafusion import DataFusionToolset
+from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
+from airflow.providers.common.sql.config import DataSourceConfig
+
+
def _make_mock_datasource_config(table_name: str = "sales_data"):
    """Build a ``DataSourceConfig`` stand-in exposing only ``table_name``."""
    config = MagicMock(spec=DataSourceConfig)
    config.table_name = table_name
    return config
+
+
+def _make_mock_engine(
+ registered_tables: dict[str, str] | None = None,
+ schema_fields: list[tuple[str, str]] | None = None,
+ query_result: dict[str, list] | None = None,
+):
+ """Create a mock DataFusionEngine with sensible defaults."""
+ mock = MagicMock()
+ tables = registered_tables or {"sales_data": "s3://bucket/sales/"}
+ mock.registered_tables = tables
+ mock.session_context.catalog().schema().table_names.return_value =
list(tables.keys())
+ mock.session_context.table_exist.side_effect = lambda name: name in tables
+
+ fields = schema_fields or [("id", "Int64"), ("amount", "Float64")]
+ arrow_fields = []
+ for name, ftype in fields:
+ field = MagicMock()
+ field.name = name
+ field.type = ftype
+ arrow_fields.append(field)
+ for tname in tables:
+ mock.session_context.table(tname).schema.return_value = arrow_fields
+
+ mock.execute_query.return_value = (
+ query_result
+ if query_result is not None
+ else {
+ "id": [1, 2],
+ "amount": [10.5, 20.0],
+ }
+ )
+ return mock
+
+
class TestDataFusionToolsetInit:
    """Constructor validation and the ``id`` property."""

    def test_id_includes_table_names(self):
        # The id concatenates table names in registration order.
        cfg_a = _make_mock_datasource_config("alpha")
        cfg_b = _make_mock_datasource_config("beta")
        ts = DataFusionToolset([cfg_b, cfg_a])
        assert ts.id == "sql_datafusion_beta_alpha"

    def test_single_table_id(self):
        cfg = _make_mock_datasource_config("orders")
        ts = DataFusionToolset([cfg])
        assert ts.id == "sql_datafusion_orders"

    def test_requires_datasource_configs(self):
        # An empty config list is rejected eagerly at construction time.
        with pytest.raises(ValueError, match="datasource_configs must contain at least one DataSourceConfig"):
            DataFusionToolset([])
+
+
class TestDataFusionToolsetGetTools:
    """Shape of the tool registry returned by ``get_tools``."""

    @staticmethod
    def _fetch_tools():
        # Shared setup: a single-config toolset queried with a mocked run context.
        toolset = DataFusionToolset([_make_mock_datasource_config()])
        return asyncio.run(toolset.get_tools(ctx=MagicMock(spec=RunContext)))

    def test_returns_three_tools(self):
        tools = self._fetch_tools()
        assert set(tools) == {"list_tables", "get_schema", "query"}

    def test_tool_definitions_have_descriptions(self):
        for tool in self._fetch_tools().values():
            assert tool.tool_def.description
+
+
class TestDataFusionToolsetListTables:
    """``list_tables`` tool output."""

    def test_returns_registered_tables(self):
        cfg = _make_mock_datasource_config()
        ts = DataFusionToolset([cfg])
        # Inject the mock engine directly to bypass lazy engine creation.
        ts._engine = _make_mock_engine(
            registered_tables={"sales": "s3://bucket/sales/", "orders": "s3://bucket/orders/"}
        )

        result = asyncio.run(
            ts.call_tool("list_tables", {}, ctx=MagicMock(spec=RunContext), tool=MagicMock(spec=ToolsetTool))
        )
        # Result is a JSON array of table names (order not asserted).
        tables = json.loads(result)
        assert set(tables) == {"sales", "orders"}
+
+
class TestDataFusionToolsetGetSchema:
    """``get_schema`` tool output."""

    def test_returns_column_info(self):
        cfg = _make_mock_datasource_config()
        ts = DataFusionToolset([cfg])
        ts._engine = _make_mock_engine(
            schema_fields=[("id", "Int64"), ("amount", "Float64"), ("name", "Utf8")]
        )

        result = asyncio.run(
            ts.call_tool(
                "get_schema",
                {"table_name": "sales_data"},
                ctx=MagicMock(spec=RunContext),
                tool=MagicMock(spec=ToolsetTool),
            )
        )
        # Field order from the (mocked) Arrow schema is preserved in the JSON output.
        columns = json.loads(result)
        assert columns == [
            {"name": "id", "type": "Int64"},
            {"name": "amount", "type": "Float64"},
            {"name": "name", "type": "Utf8"},
        ]
+
+
class TestDataFusionToolsetQuery:
    """``query`` tool: row serialization, truncation, and the write guard."""

    def test_returns_rows_as_json(self):
        cfg = _make_mock_datasource_config()
        ts = DataFusionToolset([cfg])
        ts._engine = _make_mock_engine(query_result={"id": [1, 2], "amount": [10.5, 20.0]})

        result = asyncio.run(
            ts.call_tool(
                "query",
                {"sql": "SELECT id, amount FROM sales_data"},
                ctx=MagicMock(spec=RunContext),
                tool=MagicMock(spec=ToolsetTool),
            )
        )
        data = json.loads(result)
        # Column-oriented engine output is pivoted into row dicts.
        assert data["rows"] == [{"id": 1, "amount": 10.5}, {"id": 2, "amount": 20.0}]
        assert data["count"] == 2

    def test_truncates_at_max_rows(self):
        cfg = _make_mock_datasource_config()
        ts = DataFusionToolset([cfg], max_rows=1)
        ts._engine = _make_mock_engine(query_result={"id": [1, 2, 3], "name": ["a", "b", "c"]})

        result = asyncio.run(
            ts.call_tool(
                "query",
                {"sql": "SELECT id, name FROM sales_data"},
                ctx=MagicMock(spec=RunContext),
                tool=MagicMock(spec=ToolsetTool),
            )
        )
        data = json.loads(result)
        assert len(data["rows"]) == 1
        assert data["truncated"] is True
        # "count" reports the full result size, not the truncated size.
        assert data["count"] == 3

    def test_handles_empty_result(self):
        cfg = _make_mock_datasource_config()
        ts = DataFusionToolset([cfg])
        # An empty dict from the engine means zero columns and zero rows.
        ts._engine = _make_mock_engine(query_result={})

        result = asyncio.run(
            ts.call_tool(
                "query",
                {"sql": "SELECT * FROM sales_data WHERE 1=0"},
                ctx=MagicMock(spec=RunContext),
                tool=MagicMock(spec=ToolsetTool),
            )
        )
        data = json.loads(result)
        assert data["rows"] == []
        assert data["count"] == 0

    def test_blocks_create_table_by_default(self):
        cfg = _make_mock_datasource_config()
        ts = DataFusionToolset([cfg])
        ts._engine = _make_mock_engine()

        # With allow_writes=False (default), DDL is rejected before execution.
        with pytest.raises(SQLSafetyError, match="Statement type 'Create' is not allowed"):
            asyncio.run(
                ts.call_tool(
                    "query",
                    {"sql": "CREATE TABLE new_table (id INT, name TEXT)"},
                    ctx=MagicMock(spec=RunContext),
                    tool=MagicMock(spec=ToolsetTool),
                )
            )

    def test_allows_create_table_when_writes_enabled(self):
        cfg = _make_mock_datasource_config()
        ts = DataFusionToolset([cfg], allow_writes=True)
        ts._engine = _make_mock_engine(query_result={})

        # allow_writes=True skips SQL validation entirely.
        result = asyncio.run(
            ts.call_tool(
                "query",
                {"sql": "CREATE TABLE new_table (id INT, name TEXT)"},
                ctx=MagicMock(spec=RunContext),
                tool=MagicMock(spec=ToolsetTool),
            )
        )
        data = json.loads(result)
        assert "error" not in data

    def test_unknown_tool_raises(self):
        cfg = _make_mock_datasource_config()
        ts = DataFusionToolset([cfg])

        with pytest.raises(ValueError, match="Unknown tool"):
            asyncio.run(
                ts.call_tool("bad_tool", {}, ctx=MagicMock(spec=RunContext), tool=MagicMock(spec=ToolsetTool))
            )
+
+
class TestDataFusionToolsetGetSchemaErrors:
    """Error paths of the ``get_schema`` tool."""

    def test_unknown_table_returns_error_json(self):
        cfg = _make_mock_datasource_config()
        ts = DataFusionToolset([cfg])
        ts._engine = _make_mock_engine(registered_tables={"sales_data": "s3://bucket/sales/"})

        result = asyncio.run(
            ts.call_tool(
                "get_schema",
                {"table_name": "nonexistent"},
                ctx=MagicMock(spec=RunContext),
                tool=MagicMock(spec=ToolsetTool),
            )
        )
        # An unknown table produces an error payload, not an exception.
        data = json.loads(result)
        assert "error" in data
        assert "nonexistent" in data["error"]

    def test_missing_table_name_arg_raises_key_error(self):
        cfg = _make_mock_datasource_config()
        ts = DataFusionToolset([cfg])
        ts._engine = _make_mock_engine()

        # A missing required argument surfaces as a KeyError from call_tool.
        with pytest.raises(KeyError):
            asyncio.run(
                ts.call_tool(
                    "get_schema", {}, ctx=MagicMock(spec=RunContext), tool=MagicMock(spec=ToolsetTool)
                )
            )
+
+
class TestDataFusionToolsetQueryErrors:
    """Error paths of the ``query`` tool."""

    def test_query_execution_exception_returns_error_json(self):
        from airflow.providers.common.sql.datafusion.exceptions import QueryExecutionException

        cfg = _make_mock_datasource_config()
        ts = DataFusionToolset([cfg])
        engine = _make_mock_engine()
        engine.execute_query.side_effect = QueryExecutionException("execution failed: column x not found")
        ts._engine = engine

        result = asyncio.run(
            ts.call_tool(
                "query",
                {"sql": "SELECT x FROM t"},
                ctx=MagicMock(spec=RunContext),
                tool=MagicMock(spec=ToolsetTool),
            )
        )
        # Execution failures are reported as JSON, together with the offending query.
        data = json.loads(result)
        assert "column x not found" in data["error"]
        assert data["query"] == "SELECT x FROM t"

    def test_unexpected_exception_propagates(self):
        cfg = _make_mock_datasource_config()
        ts = DataFusionToolset([cfg])
        engine = _make_mock_engine()
        engine.execute_query.side_effect = TypeError("unexpected error")
        ts._engine = engine

        # Only QueryExecutionException is converted to JSON; other exceptions propagate.
        with pytest.raises(TypeError, match="unexpected error"):
            asyncio.run(
                ts.call_tool(
                    "query",
                    {"sql": "SELECT 1"},
                    ctx=MagicMock(spec=RunContext),
                    tool=MagicMock(spec=ToolsetTool),
                )
            )
+
+
class TestDataFusionToolsetEngineResolution:
    """Lazy creation, datasource registration, and caching of the engine."""

    @patch("airflow.providers.common.ai.toolsets.datafusion.DataFusionEngine", autospec=True)
    def test_lazy_creates_engine(self, MockEngine):
        mock_engine_instance = MagicMock()
        MockEngine.return_value = mock_engine_instance

        cfg = _make_mock_datasource_config("my_table")
        ts = DataFusionToolset([cfg])
        engine = ts._get_engine()

        # First access constructs the engine and registers the datasource.
        assert engine is mock_engine_instance
        MockEngine.assert_called_once()
        mock_engine_instance.register_datasource.assert_called_once_with(cfg)

    @patch("airflow.providers.common.ai.toolsets.datafusion.DataFusionEngine", autospec=True)
    def test_registers_multiple_datasources(self, MockEngine):
        mock_engine_instance = MagicMock()
        MockEngine.return_value = mock_engine_instance

        cfg_a = _make_mock_datasource_config("table_a")
        cfg_b = _make_mock_datasource_config("table_b")
        ts = DataFusionToolset([cfg_a, cfg_b])
        ts._get_engine()

        # Every configured datasource is registered on the freshly built engine.
        assert mock_engine_instance.register_datasource.call_count == 2

    @patch("airflow.providers.common.ai.toolsets.datafusion.DataFusionEngine", autospec=True)
    def test_caches_engine_after_first_creation(self, MockEngine):
        MockEngine.return_value = MagicMock()

        cfg = _make_mock_datasource_config()
        ts = DataFusionToolset([cfg])
        # Two lookups must construct the engine only once.
        ts._get_engine()
        ts._get_engine()

        MockEngine.assert_called_once()