This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 352d7f7 More operators for Databricks Repos (#22422)
352d7f7 is described below
commit 352d7f72dd1e21f1522d69b71917142430548d66
Author: Alex Ott <[email protected]>
AuthorDate: Sun Mar 27 22:24:45 2022 +0200
More operators for Databricks Repos (#22422)
---
.../example_dags/example_databricks_repos.py | 26 ++-
airflow/providers/databricks/hooks/databricks.py | 27 ++-
.../providers/databricks/hooks/databricks_base.py | 23 ++-
.../databricks/operators/databricks_repos.py | 205 ++++++++++++++++++++-
airflow/providers/databricks/provider.yaml | 2 +
.../connections/databricks.rst | 4 +-
.../operators/repos_create.rst | 69 +++++++
.../operators/repos_delete.rst | 61 ++++++
.../databricks/operators/test_databricks.py | 2 +-
.../databricks/operators/test_databricks_repos.py | 154 +++++++++++++++-
10 files changed, 552 insertions(+), 21 deletions(-)
diff --git a/airflow/providers/databricks/example_dags/example_databricks_repos.py b/airflow/providers/databricks/example_dags/example_databricks_repos.py
index 458f7cb..e33d320 100644
--- a/airflow/providers/databricks/example_dags/example_databricks_repos.py
+++ b/airflow/providers/databricks/example_dags/example_databricks_repos.py
@@ -19,20 +19,32 @@ from datetime import datetime
from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator
-from airflow.providers.databricks.operators.databricks_repos import DatabricksReposUpdateOperator
+from airflow.providers.databricks.operators.databricks_repos import (
+ DatabricksReposCreateOperator,
+ DatabricksReposDeleteOperator,
+ DatabricksReposUpdateOperator,
+)
default_args = {
'owner': 'airflow',
- 'databricks_conn_id': 'my-shard-pat',
+ 'databricks_conn_id': 'databricks',
}
with DAG(
- dag_id='example_databricks_operator',
+ dag_id='example_databricks_repos_operator',
schedule_interval='@daily',
start_date=datetime(2021, 1, 1),
+ default_args=default_args,
tags=['example'],
catchup=False,
) as dag:
+ # [START howto_operator_databricks_repo_create]
+ # Example of creating a Databricks Repo
+ repo_path = "/Repos/[email protected]/demo-repo"
+ git_url = "https://github.com/test/test"
+ create_repo = DatabricksReposCreateOperator(task_id='create_repo', repo_path=repo_path, git_url=git_url)
+ # [END howto_operator_databricks_repo_create]
+
# [START howto_operator_databricks_repo_update]
# Example of updating a Databricks Repo to the latest code
repo_path = "/Repos/[email protected]/demo-repo"
@@ -53,4 +65,10 @@ with DAG(
notebook_task = DatabricksSubmitRunOperator(task_id='notebook_task', json=notebook_task_params)
- (update_repo >> notebook_task)
+ # [START howto_operator_databricks_repo_delete]
+ # Example of deleting a Databricks Repo
+ repo_path = "/Repos/[email protected]/demo-repo"
+ delete_repo = DatabricksReposDeleteOperator(task_id='delete_repo', repo_path=repo_path)
+ # [END howto_operator_databricks_repo_delete]
+
+ (create_repo >> update_repo >> notebook_task >> delete_repo)
diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index 977800e..ffa7757 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -333,14 +333,35 @@ class DatabricksHook(BaseDatabricksHook):
def update_repo(self, repo_id: str, json: Dict[str, Any]) -> dict:
"""
+ Updates the given Databricks Repo.
- :param repo_id:
- :param json:
- :return:
+ :param repo_id: ID of the Databricks Repo
+ :param json: payload
+ :return: metadata from the update
"""
repos_endpoint = ('PATCH', f'api/2.0/repos/{repo_id}')
return self._do_api_call(repos_endpoint, json)
+ def delete_repo(self, repo_id: str):
+ """
+ Deletes the given Databricks Repo.
+
+ :param repo_id: ID of the Databricks Repo
+ :return:
+ """
+ repos_endpoint = ('DELETE', f'api/2.0/repos/{repo_id}')
+ self._do_api_call(repos_endpoint)
+
+ def create_repo(self, json: Dict[str, Any]) -> dict:
+ """
+ Creates a Databricks Repo.
+
+ :param json: payload
+ :return: metadata of the created Repo
+ """
+ repos_endpoint = ('POST', 'api/2.0/repos')
+ return self._do_api_call(repos_endpoint, json)
+
def get_repo_by_path(self, path: str) -> Optional[str]:
"""
diff --git a/airflow/providers/databricks/hooks/databricks_base.py b/airflow/providers/databricks/hooks/databricks_base.py
index ec856a0..1a418fd 100644
--- a/airflow/providers/databricks/hooks/databricks_base.py
+++ b/airflow/providers/databricks/hooks/databricks_base.py
@@ -31,6 +31,7 @@ from urllib.parse import urlparse
import requests
from requests import PreparedRequest, exceptions as requests_exceptions
from requests.auth import AuthBase, HTTPBasicAuth
+from requests.exceptions import JSONDecodeError
from tenacity import RetryError, Retrying, retry_if_exception, stop_after_attempt, wait_exponential
from airflow import __version__
@@ -340,6 +341,8 @@ class BaseDatabricksHook(BaseHook):
request_func = requests.post
elif method == 'PATCH':
request_func = requests.patch
+ elif method == 'DELETE':
+ request_func = requests.delete
else:
raise AirflowException('Unexpected HTTP Method: ' + method)
@@ -362,12 +365,30 @@ class BaseDatabricksHook(BaseHook):
raise AirflowException(f'Response: {e.response.content}, Status Code: {e.response.status_code}')
@staticmethod
+ def _get_error_code(exception: BaseException) -> str:
+ if isinstance(exception, requests_exceptions.HTTPError):
+ try:
+ jsn = exception.response.json()
+ return jsn.get('error_code', '')
+ except JSONDecodeError:
+ pass
+
+ return ""
+
+ @staticmethod
def _retryable_error(exception: BaseException) -> bool:
if not isinstance(exception, requests_exceptions.RequestException):
return False
return isinstance(exception, (requests_exceptions.ConnectionError, requests_exceptions.Timeout)) or (
exception.response is not None
- and (exception.response.status_code >= 500 or exception.response.status_code == 429)
+ and (
+ exception.response.status_code >= 500
+ or exception.response.status_code == 429
+ or (
+ exception.response.status_code == 400
+ and BaseDatabricksHook._get_error_code(exception) == 'COULD_NOT_ACQUIRE_LOCK'
+ )
+ )
)
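
To illustrate the new retry predicate (an illustrative check, not part of the commit): a 400 response whose JSON body carries ``error_code`` of ``COULD_NOT_ACQUIRE_LOCK`` is now classified as retryable:

import requests
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

# Build a fake 400 response carrying the lock error code
resp = requests.models.Response()
resp.status_code = 400
resp._content = b'{"error_code": "COULD_NOT_ACQUIRE_LOCK"}'
err = requests.exceptions.HTTPError(response=resp)

assert BaseDatabricksHook._retryable_error(err)  # now retried by tenacity
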
diff --git a/airflow/providers/databricks/operators/databricks_repos.py b/airflow/providers/databricks/operators/databricks_repos.py
index fc50730..15543cc 100644
--- a/airflow/providers/databricks/operators/databricks_repos.py
+++ b/airflow/providers/databricks/operators/databricks_repos.py
@@ -17,8 +17,9 @@
# under the License.
#
"""This module contains Databricks operators."""
-
+import re
from typing import TYPE_CHECKING, Optional, Sequence
+from urllib.parse import urlparse
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -28,12 +29,142 @@ if TYPE_CHECKING:
from airflow.utils.context import Context
+class DatabricksReposCreateOperator(BaseOperator):
+ """
+ Creates a Databricks Repo using the
+ `POST api/2.0/repos <https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/create-repo>`_
+ API endpoint, optionally checking it out to a specific branch or tag.
+
+ :param git_url: Required HTTPS URL of a Git repository
+ :param git_provider: Optional name of the Git provider. Must be provided if it can't be guessed from the URL.
+ :param repo_path: optional path for the repository. Must be in the format ``/Repos/{folder}/{repo-name}``.
+ If not specified, the repository is created in the user's directory.
+ :param branch: optional name of the branch to check out.
+ :param tag: optional name of the tag to check out.
+ :param ignore_existing_repo: don't throw an exception if a repository with the given path already exists.
+ :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
+ By default and in the common case this will be ``databricks_default``. To use
+ token based authentication, provide the key ``token`` in the extra field for the
+ connection, create the key ``host`` there as well, and leave the connection's ``host`` field empty.
+ :param databricks_retry_limit: Number of times to retry if the Databricks backend is
+ unreachable. Its value must be greater than or equal to 1.
+ :param databricks_retry_delay: Number of seconds to wait between retries (it
+ might be a floating point number).
+ """
+
+ # Used in airflow.models.BaseOperator
+ template_fields: Sequence[str] = ('repo_path', 'tag', 'branch')
+
+ __git_providers__ = {
+ "github.com": "gitHub",
+ "dev.azure.com": "azureDevOpsServices",
+ "gitlab.com": "gitLab",
+ "bitbucket.org": "bitbucketCloud",
+ }
+ __aws_code_commit_regexp__ = re.compile(r"^git-codecommit\.[^.]+\.amazonaws.com$")
+ __repos_path_regexp__ = re.compile(r"/Repos/[^/]+/[^/]+/?$")
+
+ def __init__(
+ self,
+ *,
+ git_url: str,
+ git_provider: Optional[str] = None,
+ branch: Optional[str] = None,
+ tag: Optional[str] = None,
+ repo_path: Optional[str] = None,
+ ignore_existing_repo: bool = False,
+ databricks_conn_id: str = 'databricks_default',
+ databricks_retry_limit: int = 3,
+ databricks_retry_delay: int = 1,
+ **kwargs,
+ ) -> None:
+ """Creates a new ``DatabricksReposCreateOperator``."""
+ super().__init__(**kwargs)
+ self.databricks_conn_id = databricks_conn_id
+ self.databricks_retry_limit = databricks_retry_limit
+ self.databricks_retry_delay = databricks_retry_delay
+ self.git_url = git_url
+ self.ignore_existing_repo = ignore_existing_repo
+ if git_provider is None:
+ self.git_provider = self.__detect_repo_provider__(git_url)
+ if self.git_provider is None:
+ raise AirflowException(
+ f"git_provider isn't specified and couldn't be guessed for URL {git_url}"
+ )
+ else:
+ self.git_provider = git_provider
+ self.repo_path = repo_path
+ if branch is not None and tag is not None:
+ raise AirflowException("Only one of branch or tag should be provided, but not both")
+ self.branch = branch
+ self.tag = tag
+
+ @staticmethod
+ def __detect_repo_provider__(url):
+ provider = None
+ try:
+ netloc = urlparse(url).netloc
+ idx = netloc.rfind("@")
+ if idx != -1:
+ netloc = netloc[(idx + 1) :]
+ netloc = netloc.lower()
+ provider = DatabricksReposCreateOperator.__git_providers__.get(netloc)
+ if provider is None and DatabricksReposCreateOperator.__aws_code_commit_regexp__.match(netloc):
+ provider = "awsCodeCommit"
+ except ValueError:
+ pass
+ return provider
+
+ def _get_hook(self) -> DatabricksHook:
+ return DatabricksHook(
+ self.databricks_conn_id,
+ retry_limit=self.databricks_retry_limit,
+ retry_delay=self.databricks_retry_delay,
+ )
+
+ def execute(self, context: 'Context'):
+ """
+ Creates a Databricks Repo
+
+ :param context: context
+ :return: Repo ID
+ """
+ payload = {
+ "url": self.git_url,
+ "provider": self.git_provider,
+ }
+ if self.repo_path is not None:
+ if not self.__repos_path_regexp__.match(self.repo_path):
+ raise AirflowException(
+ f"repo_path should have form of /Repos/{{folder}}/{{repo-name}}, got '{self.repo_path}'"
+ )
+ payload["path"] = self.repo_path
+ hook = self._get_hook()
+ existing_repo_id = None
+ if self.repo_path is not None:
+ existing_repo_id = hook.get_repo_by_path(self.repo_path)
+ if existing_repo_id is not None and not self.ignore_existing_repo:
+ raise AirflowException(f"Repo with path '{self.repo_path}'
already exists")
+ if existing_repo_id is None:
+ result = hook.create_repo(payload)
+ repo_id = result["id"]
+ else:
+ repo_id = existing_repo_id
+ # update repo if necessary
+ if self.branch is not None:
+ hook.update_repo(str(repo_id), {'branch': str(self.branch)})
+ elif self.tag is not None:
+ hook.update_repo(str(repo_id), {'tag': str(self.tag)})
+
+ return repo_id
+
+
class DatabricksReposUpdateOperator(BaseOperator):
"""
- Updates specified repository to a given branch or tag using
- `api/2.0/repos/
- <https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/update-repo>`_
- API endpoint.
+ Updates the specified repository to a given branch or tag using the
+ `PATCH api/2.0/repos <https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/update-repo>`_
+ API endpoint.
:param branch: optional name of branch to update to. Should be specified if ``tag`` is omitted
:param tag: optional name of tag to update to. Should be specified if ``branch`` is omitted
@@ -64,7 +195,7 @@ class DatabricksReposUpdateOperator(BaseOperator):
databricks_retry_delay: int = 1,
**kwargs,
) -> None:
- """Creates a new ``DatabricksSubmitRunOperator``."""
+ """Creates a new ``DatabricksReposUpdateOperator``."""
super().__init__(**kwargs)
self.databricks_conn_id = databricks_conn_id
self.databricks_retry_limit = databricks_retry_limit
@@ -76,7 +207,7 @@ class DatabricksReposUpdateOperator(BaseOperator):
if repo_id is not None and repo_path is not None:
raise AirflowException("Only one of repo_id or repo_path should be
provided, but not both")
if repo_id is None and repo_path is None:
- raise AirflowException("One of repo_id repo_path tag should be provided")
+ raise AirflowException("One of repo_id or repo_path should be provided")
self.repo_path = repo_path
self.repo_id = repo_id
self.branch = branch
@@ -102,3 +233,63 @@ class DatabricksReposUpdateOperator(BaseOperator):
result = hook.update_repo(str(self.repo_id), payload)
return result['head_commit_id']
+
+
+class DatabricksReposDeleteOperator(BaseOperator):
+ """
+ Deletes the specified repository using the
+ `DELETE api/2.0/repos <https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/delete-repo>`_
+ API endpoint.
+
+ :param repo_id: optional ID of an existing repository. Should be specified if ``repo_path`` is omitted
+ :param repo_path: optional path of an existing repository. Should be specified if ``repo_id`` is omitted
+ :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
+ By default and in the common case this will be ``databricks_default``. To use
+ token based authentication, provide the key ``token`` in the extra field for the
+ connection, create the key ``host`` there as well, and leave the connection's ``host`` field empty.
+ :param databricks_retry_limit: Number of times to retry if the Databricks backend is
+ unreachable. Its value must be greater than or equal to 1.
+ :param databricks_retry_delay: Number of seconds to wait between retries (it
+ might be a floating point number).
+ """
+
+ # Used in airflow.models.BaseOperator
+ template_fields: Sequence[str] = ('repo_path',)
+
+ def __init__(
+ self,
+ *,
+ repo_id: Optional[str] = None,
+ repo_path: Optional[str] = None,
+ databricks_conn_id: str = 'databricks_default',
+ databricks_retry_limit: int = 3,
+ databricks_retry_delay: int = 1,
+ **kwargs,
+ ) -> None:
+ """Creates a new ``DatabricksReposDeleteOperator``."""
+ super().__init__(**kwargs)
+ self.databricks_conn_id = databricks_conn_id
+ self.databricks_retry_limit = databricks_retry_limit
+ self.databricks_retry_delay = databricks_retry_delay
+ if repo_id is not None and repo_path is not None:
+ raise AirflowException("Only one of repo_id or repo_path should be
provided, but not both")
+ if repo_id is None and repo_path is None:
+ raise AirflowException("One of repo_id repo_path tag should be
provided")
+ self.repo_path = repo_path
+ self.repo_id = repo_id
+
+ def _get_hook(self) -> DatabricksHook:
+ return DatabricksHook(
+ self.databricks_conn_id,
+ retry_limit=self.databricks_retry_limit,
+ retry_delay=self.databricks_retry_delay,
+ )
+
+ def execute(self, context: 'Context'):
+ hook = self._get_hook()
+ if self.repo_path is not None:
+ self.repo_id = hook.get_repo_by_path(self.repo_path)
+ if self.repo_id is None:
+ raise AirflowException(f"Can't find Repo ID for path
'{self.repo_path}'")
+
+ hook.delete_repo(str(self.repo_id))
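
As a sketch of the provider auto-detection added above (the URLs are hypothetical, and the double-underscore helper is internal API):

from airflow.providers.databricks.operators.databricks_repos import DatabricksReposCreateOperator

detect = DatabricksReposCreateOperator.__detect_repo_provider__
assert detect("https://github.com/test/test") == "gitHub"
assert detect("https://[email protected]/org/project/_git/repo") == "azureDevOpsServices"
assert detect("https://git-codecommit.us-east-1.amazonaws.com/v1/repos/demo") == "awsCodeCommit"
assert detect("https://example.com/repo.git") is None  # pass git_provider explicitly here
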
diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml
index 7003fe0..ba9b3f0 100644
--- a/airflow/providers/databricks/provider.yaml
+++ b/airflow/providers/databricks/provider.yaml
@@ -57,7 +57,9 @@ integrations:
- integration-name: Databricks Repos
external-doc-url: https://docs.databricks.com/repos/index.html
how-to-guide:
+ - /docs/apache-airflow-providers-databricks/operators/repos_create.rst
- /docs/apache-airflow-providers-databricks/operators/repos_update.rst
+ - /docs/apache-airflow-providers-databricks/operators/repos_delete.rst
logo: /integration-logos/databricks/Databricks.png
tags: [service]
diff --git a/docs/apache-airflow-providers-databricks/connections/databricks.rst b/docs/apache-airflow-providers-databricks/connections/databricks.rst
index 5a87539..cb62ada 100644
--- a/docs/apache-airflow-providers-databricks/connections/databricks.rst
+++ b/docs/apache-airflow-providers-databricks/connections/databricks.rst
@@ -64,9 +64,9 @@ Password (optional)
Extra (optional)
Specify the extra parameter (as json dictionary) that can be used in the Databricks connection.
- Following parameter should be used if using the *PAT* authentication method:
+ The following parameter can be used with the *PAT* authentication method:
- * ``token``: Specify PAT to use. Note, the PAT must appear in both the Password field as the token value in Extra.
+ * ``token``: Specify the PAT to use. Consider switching to specifying the PAT in the Password field instead, as it is more secure.
Following parameters are necessary if using authentication with AAD token:
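
For example, a sketch of the recommended setup with the PAT in the Password field (the host is a placeholder workspace URL):

from airflow.models.connection import Connection

conn = Connection(
    conn_id="databricks_default",
    conn_type="databricks",
    host="https://dbc-1234567890abcdef.cloud.databricks.com",  # placeholder
    password="<personal-access-token>",  # preferred over extra={"token": ...}
)
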
diff --git a/docs/apache-airflow-providers-databricks/operators/repos_create.rst b/docs/apache-airflow-providers-databricks/operators/repos_create.rst
new file mode 100644
index 0000000..fc04340
--- /dev/null
+++ b/docs/apache-airflow-providers-databricks/operators/repos_create.rst
@@ -0,0 +1,69 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+
+DatabricksReposCreateOperator
+=============================
+
+Use the :class:`~airflow.providers.databricks.operators.DatabricksReposCreateOperator` to create
+(and optionally check out) a `Databricks Repo <https://docs.databricks.com/repos/index.html>`_
+via the `api/2.0/repos <https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/create-repo>`_ API endpoint.
+
+
+Using the Operator
+^^^^^^^^^^^^^^^^^^
+
+To use this operator you need to provide at least the ``git_url`` parameter.
+
+.. list-table::
+ :widths: 15 25
+ :header-rows: 1
+
+ * - Parameter
+ - Input
+ * - git_url: str
+ - Required HTTPS URL of a Git repository
+ * - git_provider: str
+ - Optional name of the Git provider. Must be provided if it can't be guessed from the URL. See the API documentation for the list of supported Git providers.
+ * - branch: str
+ - Optional name of the existing Git branch to check out.
+ * - tag: str
+ - Optional name of the existing Git tag to check out.
+ * - repo_path: str
+ - Optional path to the Databricks Repo, such as ``/Repos/<user_email>/repo_name``. If not specified, it will be created in the user's directory.
+ * - ignore_existing_repo: bool
+ - Don't throw an exception if a repository with the given path already exists.
+ * - databricks_conn_id: string
+ - the name of the Airflow connection to use.
+ * - databricks_retry_limit: integer
+ - number of times to retry if the Databricks backend is unreachable.
+ * - databricks_retry_delay: decimal
+ - number of seconds to wait between retries.
+
+Examples
+--------
+
+Create a Databricks Repo
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+An example usage of the DatabricksReposCreateOperator is as follows:
+
+.. exampleinclude:: /../../airflow/providers/databricks/example_dags/example_databricks_repos.py
+ :language: python
+ :start-after: [START howto_operator_databricks_repo_create]
+ :end-before: [END howto_operator_databricks_repo_create]
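
A further illustrative variant (not taken from the example DAG): pinning the Repo to a tag instead of a branch; the two parameters are mutually exclusive:

from airflow.providers.databricks.operators.databricks_repos import DatabricksReposCreateOperator

create_pinned_repo = DatabricksReposCreateOperator(
    task_id="create_pinned_repo",
    git_url="https://github.com/test/test",          # placeholder repository
    repo_path="/Repos/[email protected]/demo-repo",  # placeholder path
    tag="v1.0.0",                # cannot be combined with branch=
    ignore_existing_repo=True,   # reuse the Repo if the path already exists
)
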
diff --git a/docs/apache-airflow-providers-databricks/operators/repos_delete.rst b/docs/apache-airflow-providers-databricks/operators/repos_delete.rst
new file mode 100644
index 0000000..e359deb
--- /dev/null
+++ b/docs/apache-airflow-providers-databricks/operators/repos_delete.rst
@@ -0,0 +1,61 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+
+DatabricksReposDeleteOperator
+=============================
+
+Use the :class:`~airflow.providers.databricks.operators.DatabricksReposDeleteOperator` to delete an existing
+`Databricks Repo <https://docs.databricks.com/repos/index.html>`_
+via the `api/2.0/repos <https://docs.databricks.com/dev-tools/api/latest/repos.html#operation/delete-repo>`_ API endpoint.
+
+
+Using the Operator
+^^^^^^^^^^^^^^^^^^
+
+To use this operator you need to provide either ``repo_path`` or ``repo_id``.
+
+.. list-table::
+ :widths: 15 25
+ :header-rows: 1
+
+ * - Parameter
+ - Input
+ * - repo_path: str
+ - Path to an existing Databricks Repo, such as ``/Repos/<user_email>/repo_name`` (required if ``repo_id`` isn't provided).
+ * - repo_id: str
+ - ID of an existing Databricks Repo (required if ``repo_path`` isn't provided).
+ * - databricks_conn_id: string
+ - the name of the Airflow connection to use.
+ * - databricks_retry_limit: integer
+ - number of times to retry if the Databricks backend is unreachable.
+ * - databricks_retry_delay: decimal
+ - number of seconds to wait between retries.
+
+Examples
+--------
+
+Deleting a Databricks Repo by specifying its path
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+An example usage of the DatabricksReposDeleteOperator is as follows:
+
+.. exampleinclude:: /../../airflow/providers/databricks/example_dags/example_databricks_repos.py
+ :language: python
+ :start-after: [START howto_operator_databricks_repo_delete]
+ :end-before: [END howto_operator_databricks_repo_delete]
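
An illustrative variant (not in the example DAG): deleting by ``repo_id`` instead of path; the ID shown is made up:

from airflow.providers.databricks.operators.databricks_repos import DatabricksReposDeleteOperator

delete_repo_by_id = DatabricksReposDeleteOperator(
    task_id="delete_repo_by_id",
    repo_id="123",  # mutually exclusive with repo_path
)
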
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 0d1bd09..c467234 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -571,7 +571,7 @@ class TestDatabricksRunNowOperator(unittest.TestCase):
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run_state.assert_not_called()
- def test_init_exeption_with_job_name_and_job_id(self):
+ def test_init_exception_with_job_name_and_job_id(self):
exception_message = "Argument 'job_name' is not allowed with argument 'job_id'"
with pytest.raises(AirflowException, match=exception_message):
diff --git a/tests/providers/databricks/operators/test_databricks_repos.py b/tests/providers/databricks/operators/test_databricks_repos.py
index ad8ccdc..aaf0326 100644
--- a/tests/providers/databricks/operators/test_databricks_repos.py
+++ b/tests/providers/databricks/operators/test_databricks_repos.py
@@ -19,7 +19,14 @@
import unittest
from unittest import mock
-from airflow.providers.databricks.operators.databricks_repos import DatabricksReposUpdateOperator
+import pytest
+
+from airflow import AirflowException
+from airflow.providers.databricks.operators.databricks_repos import (
+ DatabricksReposCreateOperator,
+ DatabricksReposDeleteOperator,
+ DatabricksReposUpdateOperator,
+)
TASK_ID = 'databricks-operator'
DEFAULT_CONN_ID = 'databricks_default'
@@ -29,7 +36,7 @@ class TestDatabricksReposUpdateOperator(unittest.TestCase):
@mock.patch('airflow.providers.databricks.operators.databricks_repos.DatabricksHook')
def test_update_with_id(self, db_mock_class):
"""
- Test the execute function in case where the run is successful.
+ Test the execute function using Repo ID.
"""
op = DatabricksReposUpdateOperator(task_id=TASK_ID, branch="releases", repo_id="123")
db_mock = db_mock_class.return_value
@@ -46,7 +53,7 @@ class TestDatabricksReposUpdateOperator(unittest.TestCase):
@mock.patch('airflow.providers.databricks.operators.databricks_repos.DatabricksHook')
def test_update_with_path(self, db_mock_class):
"""
- Test the execute function in case where the run is successful.
+ Test the execute function using Repo path.
"""
op = DatabricksReposUpdateOperator(
task_id=TASK_ID, tag="v1.0.0", repo_path="/Repos/[email protected]/test-repo"
@@ -62,3 +69,144 @@ class TestDatabricksReposUpdateOperator(unittest.TestCase):
)
db_mock.update_repo.assert_called_once_with('123', {'tag': 'v1.0.0'})
+
+ def test_init_exception(self):
+ """
+ Tests handling of incorrect parameters passed to ``__init__``
+ """
+ with pytest.raises(
+ AirflowException, match="Only one of repo_id or repo_path should be provided, but not both"
+ ):
+ DatabricksReposUpdateOperator(task_id=TASK_ID, repo_id="abc", repo_path="path", branch="abc")
+
+ with pytest.raises(AirflowException, match="One of repo_id or repo_path should be provided"):
+ DatabricksReposUpdateOperator(task_id=TASK_ID, branch="abc")
+
+ with pytest.raises(
+ AirflowException, match="Only one of branch or tag should be provided, but not both"
+ ):
+ DatabricksReposUpdateOperator(task_id=TASK_ID, repo_id="123", branch="123", tag="123")
+
+ with pytest.raises(AirflowException, match="One of branch or tag should be provided"):
+ DatabricksReposUpdateOperator(task_id=TASK_ID, repo_id="123")
+
+
+class TestDatabricksReposDeleteOperator(unittest.TestCase):
+ @mock.patch('airflow.providers.databricks.operators.databricks_repos.DatabricksHook')
+ def test_delete_with_id(self, db_mock_class):
+ """
+ Test the execute function using Repo ID.
+ """
+ op = DatabricksReposDeleteOperator(task_id=TASK_ID, repo_id="123")
+ db_mock = db_mock_class.return_value
+ db_mock.delete_repo.return_value = None
+
+ op.execute(None)
+
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+ )
+
+ db_mock.delete_repo.assert_called_once_with('123')
+
+ @mock.patch('airflow.providers.databricks.operators.databricks_repos.DatabricksHook')
+ def test_delete_with_path(self, db_mock_class):
+ """
+ Test the execute function using Repo path.
+ """
+ op = DatabricksReposDeleteOperator(task_id=TASK_ID, repo_path="/Repos/[email protected]/test-repo")
+ db_mock = db_mock_class.return_value
+ db_mock.get_repo_by_path.return_value = '123'
+ db_mock.delete_repo.return_value = None
+
+ op.execute(None)
+
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+ )
+
+ db_mock.delete_repo.assert_called_once_with('123')
+
+ def test_init_exception(self):
+ """
+ Tests handling of incorrect parameters passed to ``__init__``
+ """
+ with pytest.raises(
+ AirflowException, match="Only one of repo_id or repo_path should be provided, but not both"
+ ):
+ DatabricksReposDeleteOperator(task_id=TASK_ID, repo_id="abc", repo_path="path")
+
+ with pytest.raises(AirflowException, match="One of repo_id or repo_path should be provided"):
+ DatabricksReposDeleteOperator(task_id=TASK_ID)
+
+
+class TestDatabricksReposCreateOperator(unittest.TestCase):
+ @mock.patch('airflow.providers.databricks.operators.databricks_repos.DatabricksHook')
+ def test_create_plus_checkout(self, db_mock_class):
+ """
+ Test the execute function creating a new Repo.
+ """
+ git_url = "https://github.com/test/test"
+ repo_path = '/Repos/Project1/test-repo'
+ op = DatabricksReposCreateOperator(
+ task_id=TASK_ID, git_url=git_url, repo_path=repo_path, branch="releases"
+ )
+ db_mock = db_mock_class.return_value
+ db_mock.update_repo.return_value = {'head_commit_id': '123456'}
+ db_mock.create_repo.return_value = {'id': '123', 'branch': 'main'}
+ db_mock.get_repo_by_path.return_value = None
+
+ op.execute(None)
+
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+ )
+
+ db_mock.create_repo.assert_called_once_with({'url': git_url, 'provider': 'gitHub', 'path': repo_path})
+ db_mock.update_repo.assert_called_once_with('123', {'branch': 'releases'})
+
+ @mock.patch('airflow.providers.databricks.operators.databricks_repos.DatabricksHook')
+ def test_create_ignore_existing_plus_checkout(self, db_mock_class):
+ """
+ Test the execute function when the Repo already exists and ignore_existing_repo is set.
+ """
+ git_url = "https://github.com/test/test"
+ repo_path = '/Repos/Project1/test-repo'
+ op = DatabricksReposCreateOperator(
+ task_id=TASK_ID,
+ git_url=git_url,
+ repo_path=repo_path,
+ branch="releases",
+ ignore_existing_repo=True,
+ )
+ db_mock = db_mock_class.return_value
+ db_mock.update_repo.return_value = {'head_commit_id': '123456'}
+ db_mock.get_repo_by_path.return_value = '123'
+
+ op.execute(None)
+
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, retry_delay=op.databricks_retry_delay
+ )
+
+ db_mock.get_repo_by_path.assert_called_once_with(repo_path)
+ db_mock.update_repo.assert_called_once_with('123', {'branch': 'releases'})
+
+ def test_init_exception(self):
+ """
+ Tests handling of incorrect parameters passed to ``__init__``
+ """
+ git_url = "https://github.com/test/test"
+ repo_path = '/Repos/test-repo'
+ exception_message = (
+ f"repo_path should have form of /Repos/{{folder}}/{{repo-name}},
got '{repo_path}'"
+ )
+
+ with pytest.raises(AirflowException, match=exception_message):
+ op = DatabricksReposCreateOperator(task_id=TASK_ID, git_url=git_url, repo_path=repo_path)
+ op.execute(None)
+
+ with pytest.raises(
+ AirflowException, match="Only one of branch or tag should be
provided, but not both"
+ ):
+ DatabricksReposCreateOperator(task_id=TASK_ID, git_url=git_url,
branch="123", tag="123")