This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 75e47927f7d Fix mypy static errors in databricks provider (#57768)
75e47927f7d is described below

commit 75e47927f7d5b5457c1a31bd95e6bdf01faeac9c
Author: Vincent <[email protected]>
AuthorDate: Tue Nov 4 23:00:45 2025 -0500

    Fix mypy static errors in databricks provider (#57768)
---
 .../providers/databricks/hooks/databricks_base.py  | 25 ++++++++++++++--------
 .../providers/databricks/hooks/databricks_sql.py   |  4 ++--
 .../databricks/plugins/databricks_workflow.py      |  2 +-
 .../databricks/plugins/test_databricks_workflow.py |  2 +-
 4 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
index 919e21c3287..6415740d90e 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
@@ -166,12 +166,12 @@ class BaseDatabricksHook(BaseHook):
         return ua_string
 
     @cached_property
-    def host(self) -> str:
+    def host(self) -> str | None:
+        host = None
         if "host" in self.databricks_conn.extra_dejson:
             host = self._parse_host(self.databricks_conn.extra_dejson["host"])
-        else:
+        elif self.databricks_conn.host:
             host = self._parse_host(self.databricks_conn.host)
-
         return host
 
     async def __aenter__(self):
@@ -207,6 +207,11 @@ class BaseDatabricksHook(BaseHook):
         # In this case, host = xx.cloud.databricks.com
         return host
 
+    def _get_connection_attr(self, attr_name: str) -> str:
+        if not (attr := getattr(self.databricks_conn, attr_name)):
+            raise ValueError(f"`{attr_name}` must be present in Connection")
+        return attr
+
     def _get_retry_object(self) -> Retrying:
         """
         Instantiate a retry object.
@@ -235,7 +240,7 @@ class BaseDatabricksHook(BaseHook):
                 with attempt:
                     resp = requests.post(
                         resource,
-                        auth=HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password),
+                        auth=HTTPBasicAuth(self._get_connection_attr("login"), self.databricks_conn.password),
                         data="grant_type=client_credentials&scope=all-apis",
                         headers={
                             **self.user_agent_header,
@@ -271,7 +276,9 @@ class BaseDatabricksHook(BaseHook):
                 with attempt:
                     async with self._session.post(
                         resource,
-                        auth=aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password),
+                        auth=aiohttp.BasicAuth(
+                            self._get_connection_attr("login"), self.databricks_conn.password
+                        ),
                         data="grant_type=client_credentials&scope=all-apis",
                         headers={
                             **self.user_agent_header,
@@ -316,7 +323,7 @@ class BaseDatabricksHook(BaseHook):
                         token = ManagedIdentityCredential().get_token(f"{resource}/.default")
                     else:
                         credential = ClientSecretCredential(
-                            client_id=self.databricks_conn.login,
+                            client_id=self._get_connection_attr("login"),
                             client_secret=self.databricks_conn.password,
                             tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
                         )
@@ -364,7 +371,7 @@ class BaseDatabricksHook(BaseHook):
                             token = await credential.get_token(f"{resource}/.default")
                     else:
                         async with AsyncClientSecretCredential(
-                            client_id=self.databricks_conn.login,
+                            client_id=self._get_connection_attr("login"),
                             client_secret=self.databricks_conn.password,
                             tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
                         ) as credential:
@@ -678,7 +685,7 @@ class BaseDatabricksHook(BaseHook):
             auth = _TokenAuth(token)
         else:
             self.log.info("Using basic auth.")
-            auth = HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password)
+            auth = HTTPBasicAuth(self._get_connection_attr("login"), self.databricks_conn.password)
 
         request_func: Any
         if method == "GET":
@@ -745,7 +752,7 @@ class BaseDatabricksHook(BaseHook):
             auth = BearerAuth(token)
         else:
             self.log.info("Using basic auth.")
-            auth = aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password)
+            auth = aiohttp.BasicAuth(self._get_connection_attr("login"), self.databricks_conn.password)
 
         request_func: Any
         if method == "GET":
diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
index d86d0453fa5..f7619bfbb2e 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -184,13 +184,13 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
             "catalog": self.catalog,
             "schema": self.schema,
         }
-        url_query = {k: v for k, v in url_query.items() if v is not None}
+        url_query_formatted: dict[str, str] = {k: v for k, v in url_query.items() if v is not None}
         return URL.create(
             drivername="databricks",
             username="token",
             password=self._get_token(raise_error=True),
             host=self.host,
-            query=url_query,
+            query=url_query_formatted,
         )
 
     def get_uri(self) -> str:
diff --git a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py
index 8513ac27f1c..c53739634ee 100644
--- a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py
+++ b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py
@@ -143,7 +143,7 @@ if not AIRFLOW_V_3_0_PLUS:
         if not session:
             raise AirflowException("Session not provided.")
 
-        return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).first()
+        return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).one()
 
     @provide_session
     def _clear_task_instances(
diff --git a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py
index c41f02b9239..1bb7974df8f 100644
--- a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py
+++ b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py
@@ -82,7 +82,7 @@ def test_get_dagrun_airflow2():
 
     session = MagicMock()
     dag = MagicMock(dag_id=DAG_ID)
-    session.query.return_value.filter.return_value.first.return_value = DagRun()
+    session.query.return_value.filter.return_value.one.return_value = DagRun()
 
     result = _get_dagrun(dag, RUN_ID, session=session)
 

Reply via email to