This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 75e47927f7d Fix mypy static errors in databricks provider (#57768)
75e47927f7d is described below
commit 75e47927f7d5b5457c1a31bd95e6bdf01faeac9c
Author: Vincent <[email protected]>
AuthorDate: Tue Nov 4 23:00:45 2025 -0500
Fix mypy static errors in databricks provider (#57768)
---
.../providers/databricks/hooks/databricks_base.py | 25 ++++++++++++++--------
.../providers/databricks/hooks/databricks_sql.py | 4 ++--
.../databricks/plugins/databricks_workflow.py | 2 +-
.../databricks/plugins/test_databricks_workflow.py | 2 +-
4 files changed, 20 insertions(+), 13 deletions(-)
diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
index 919e21c3287..6415740d90e 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
@@ -166,12 +166,12 @@ class BaseDatabricksHook(BaseHook):
return ua_string
@cached_property
- def host(self) -> str:
+ def host(self) -> str | None:
+ host = None
if "host" in self.databricks_conn.extra_dejson:
host = self._parse_host(self.databricks_conn.extra_dejson["host"])
- else:
+ elif self.databricks_conn.host:
host = self._parse_host(self.databricks_conn.host)
-
return host
async def __aenter__(self):
@@ -207,6 +207,11 @@ class BaseDatabricksHook(BaseHook):
# In this case, host = xx.cloud.databricks.com
return host
+ def _get_connection_attr(self, attr_name: str) -> str:
+ if not (attr := getattr(self.databricks_conn, attr_name)):
+ raise ValueError(f"`{attr_name}` must be present in Connection")
+ return attr
+
def _get_retry_object(self) -> Retrying:
"""
Instantiate a retry object.
@@ -235,7 +240,7 @@ class BaseDatabricksHook(BaseHook):
with attempt:
resp = requests.post(
resource,
- auth=HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password),
+ auth=HTTPBasicAuth(self._get_connection_attr("login"), self.databricks_conn.password),
data="grant_type=client_credentials&scope=all-apis",
headers={
**self.user_agent_header,
@@ -271,7 +276,9 @@ class BaseDatabricksHook(BaseHook):
with attempt:
async with self._session.post(
resource,
- auth=aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password),
+ auth=aiohttp.BasicAuth(
+ self._get_connection_attr("login"), self.databricks_conn.password
+ ),
data="grant_type=client_credentials&scope=all-apis",
headers={
**self.user_agent_header,
@@ -316,7 +323,7 @@ class BaseDatabricksHook(BaseHook):
token = ManagedIdentityCredential().get_token(f"{resource}/.default")
else:
credential = ClientSecretCredential(
- client_id=self.databricks_conn.login,
+ client_id=self._get_connection_attr("login"),
client_secret=self.databricks_conn.password,
tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
)
@@ -364,7 +371,7 @@ class BaseDatabricksHook(BaseHook):
token = await credential.get_token(f"{resource}/.default")
else:
async with AsyncClientSecretCredential(
- client_id=self.databricks_conn.login,
+ client_id=self._get_connection_attr("login"),
client_secret=self.databricks_conn.password,
tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
) as credential:
@@ -678,7 +685,7 @@ class BaseDatabricksHook(BaseHook):
auth = _TokenAuth(token)
else:
self.log.info("Using basic auth.")
- auth = HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password)
+ auth = HTTPBasicAuth(self._get_connection_attr("login"), self.databricks_conn.password)
request_func: Any
if method == "GET":
@@ -745,7 +752,7 @@ class BaseDatabricksHook(BaseHook):
auth = BearerAuth(token)
else:
self.log.info("Using basic auth.")
- auth = aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password)
+ auth = aiohttp.BasicAuth(self._get_connection_attr("login"), self.databricks_conn.password)
request_func: Any
if method == "GET":
diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
index d86d0453fa5..f7619bfbb2e 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -184,13 +184,13 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
"catalog": self.catalog,
"schema": self.schema,
}
- url_query = {k: v for k, v in url_query.items() if v is not None}
+ url_query_formatted: dict[str, str] = {k: v for k, v in url_query.items() if v is not None}
return URL.create(
drivername="databricks",
username="token",
password=self._get_token(raise_error=True),
host=self.host,
- query=url_query,
+ query=url_query_formatted,
)
def get_uri(self) -> str:
diff --git a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py
index 8513ac27f1c..c53739634ee 100644
--- a/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py
+++ b/providers/databricks/src/airflow/providers/databricks/plugins/databricks_workflow.py
@@ -143,7 +143,7 @@ if not AIRFLOW_V_3_0_PLUS:
if not session:
raise AirflowException("Session not provided.")
- return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).first()
+ return session.query(DagRun).filter(DagRun.dag_id == dag.dag_id, DagRun.run_id == run_id).one()
@provide_session
def _clear_task_instances(
diff --git a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py
index c41f02b9239..1bb7974df8f 100644
--- a/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py
+++ b/providers/databricks/tests/unit/databricks/plugins/test_databricks_workflow.py
@@ -82,7 +82,7 @@ def test_get_dagrun_airflow2():
session = MagicMock()
dag = MagicMock(dag_id=DAG_ID)
- session.query.return_value.filter.return_value.first.return_value = DagRun()
+ session.query.return_value.filter.return_value.one.return_value = DagRun()
result = _get_dagrun(dag, RUN_ID, session=session)