This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a6ff00f43cd Fix sql_warehouse_name resolution: handle 'warehouses' API response key (#63286)
a6ff00f43cd is described below

commit a6ff00f43cded82e1a832cbb141e4ad946b4e519
Author: ataulmujeeb-cyber <[email protected]>
AuthorDate: Wed Mar 11 09:03:19 2026 -0400

    Fix sql_warehouse_name resolution: handle 'warehouses' API response key (#63286)
    
    * Fix sql_warehouse_name resolution failing with "Can't list Databricks SQL endpoints"
    
    The _get_sql_endpoint_by_name method calls GET /api/2.0/sql/warehouses
    (the current API path) but checks for the "endpoints" key in the
    response. Since Databricks renamed SQL endpoints to SQL warehouses,
    the current API returns data under the "warehouses" key, causing the
    check to always fail.
    
    This fix handles both the current ("warehouses") and legacy
    ("endpoints") response keys for backward compatibility.
    
    Closes: #63285
    
    * Use standard Python exceptions instead of AirflowException
    
    Replace AirflowException with standard Python exceptions per
    contributing guidelines:
    - RuntimeError for unexpected API response (no warehouses/endpoints key)
    - ValueError for warehouse name not found in results
---
 .../providers/databricks/hooks/databricks_sql.py   | 16 +++--
 .../unit/databricks/hooks/test_databricks_sql.py   | 84 +++++++++++++++++++++-
 2 files changed, 94 insertions(+), 6 deletions(-)

diff --git 
a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py 
b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
index 2c2164bd9c7..021142395b2 100644
--- 
a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
+++ 
b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -129,12 +129,20 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
 
     def _get_sql_endpoint_by_name(self, endpoint_name) -> dict[str, Any]:
         result = self._do_api_call(LIST_SQL_ENDPOINTS_ENDPOINT)
-        if "endpoints" not in result:
-            raise AirflowException("Can't list Databricks SQL endpoints")
+        # The API response key depends on which endpoint path is used:
+        # - "warehouses" for the current /api/2.0/sql/warehouses path
+        # - "endpoints" for the legacy /api/2.0/sql/endpoints path
+        warehouses = result.get("warehouses") or result.get("endpoints")
+        if not warehouses:
+            raise RuntimeError(
+                "Can't list Databricks SQL warehouses. The API response 
contained neither "
+                "'warehouses' nor 'endpoints' key. Check that the connection 
has sufficient "
+                "permissions to list SQL warehouses."
+            )
         try:
-            endpoint = next(endpoint for endpoint in result["endpoints"] if 
endpoint["name"] == endpoint_name)
+            endpoint = next(ep for ep in warehouses if ep["name"] == 
endpoint_name)
         except StopIteration:
-            raise AirflowException(f"Can't find Databricks SQL endpoint with 
name '{endpoint_name}'")
+            raise ValueError(f"Can't find Databricks SQL warehouse with name 
'{endpoint_name}'")
         else:
             return endpoint
 
diff --git 
a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py 
b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
index d661f5b0714..f3f053d443c 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
@@ -80,10 +80,10 @@ def mock_get_requests():
     mock_patch = 
patch("airflow.providers.databricks.hooks.databricks_base.requests")
     mock_requests = mock_patch.start()
 
-    # Configure the mock object
+    # Configure the mock object with the current API response format 
("warehouses" key)
     mock_requests.codes.ok = 200
     mock_requests.get.return_value.json.return_value = {
-        "endpoints": [
+        "warehouses": [
             {
                 "id": "1264e5078741679a",
                 "name": "Test",
@@ -712,3 +712,83 @@ def test_get_df(df_type, df_class, description):
             assert df.row(1)[0] == result_sets[1][0]
 
         assert isinstance(df, df_class)
+
+
+class TestGetSqlEndpointByName:
+    """Tests for _get_sql_endpoint_by_name with both 'warehouses' and legacy 
'endpoints' API response keys."""
+
+    @patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_resolve_warehouse_name_with_warehouses_key(self, mock_requests):
+        """Test that the current API response format with 'warehouses' key 
works."""
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = {
+            "warehouses": [
+                {
+                    "id": "abc123",
+                    "name": "My Warehouse",
+                    "odbc_params": {
+                        "hostname": "xx.cloud.databricks.com",
+                        "path": "/sql/1.0/warehouses/abc123",
+                    },
+                }
+            ]
+        }
+        type(mock_requests.get.return_value).status_code = 
PropertyMock(return_value=200)
+
+        hook = DatabricksSqlHook(sql_endpoint_name="My Warehouse")
+        endpoint = hook._get_sql_endpoint_by_name("My Warehouse")
+        assert endpoint["id"] == "abc123"
+        assert endpoint["odbc_params"]["path"] == "/sql/1.0/warehouses/abc123"
+
+    @patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_resolve_warehouse_name_with_legacy_endpoints_key(self, 
mock_requests):
+        """Test that the legacy API response format with 'endpoints' key still 
works."""
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = {
+            "endpoints": [
+                {
+                    "id": "def456",
+                    "name": "Legacy Endpoint",
+                    "odbc_params": {
+                        "hostname": "xx.cloud.databricks.com",
+                        "path": "/sql/1.0/endpoints/def456",
+                    },
+                }
+            ]
+        }
+        type(mock_requests.get.return_value).status_code = 
PropertyMock(return_value=200)
+
+        hook = DatabricksSqlHook(sql_endpoint_name="Legacy Endpoint")
+        endpoint = hook._get_sql_endpoint_by_name("Legacy Endpoint")
+        assert endpoint["id"] == "def456"
+        assert endpoint["odbc_params"]["path"] == "/sql/1.0/endpoints/def456"
+
+    @patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_resolve_warehouse_name_not_found(self, mock_requests):
+        """Test that a clear error is raised when the warehouse name doesn't 
match any warehouse."""
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = {
+            "warehouses": [
+                {
+                    "id": "abc123",
+                    "name": "Some Other Warehouse",
+                    "odbc_params": {"path": "/sql/1.0/warehouses/abc123"},
+                }
+            ]
+        }
+        type(mock_requests.get.return_value).status_code = 
PropertyMock(return_value=200)
+
+        hook = DatabricksSqlHook(sql_endpoint_name="Nonexistent Warehouse")
+        with pytest.raises(ValueError, match="Can't find Databricks SQL 
warehouse with name"):
+            hook._get_sql_endpoint_by_name("Nonexistent Warehouse")
+
+    @patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_resolve_warehouse_name_empty_response(self, mock_requests):
+        """Test that a clear error is raised when the API returns no 
warehouses."""
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = {}
+        type(mock_requests.get.return_value).status_code = 
PropertyMock(return_value=200)
+
+        hook = DatabricksSqlHook(sql_endpoint_name="Test")
+        with pytest.raises(RuntimeError, match="Can't list Databricks SQL 
warehouses"):
+            hook._get_sql_endpoint_by_name("Test")

Reply via email to