This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e7d8f5bdaba Use SSH to authenticate GitDagBundle (#44976)
e7d8f5bdaba is described below
commit e7d8f5bdaba74f9a3980f4b71e89879f4ff8bfc0
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Thu Jan 16 12:04:09 2025 +0100
Use SSH to authenticate GitDagBundle (#44976)
* Use SSH to authenticate GitDagBundle
This uses the SSH hook to authenticate GitDagBundle when one is provided.
* Add tests
* Account for remotes with ssh
* renames
* fix tests
* Refactor code
* Use githook
* fixup! Use githook
* Populate the connection form with git type connection
* Mark test_dag_bundles as db test
* Add names to the extra items
* Update airflow/dag_processing/bundles/git.py
Co-authored-by: Felix Uellendall <[email protected]>
* Fix refresh
* Apply suggestions from code review
Co-authored-by: Jed Cunningham <[email protected]>
* Remove ssh hook inheritance
* fixup! Remove ssh hook inheritance
* Apply suggestions from code review
Co-authored-by: Jed Cunningham <[email protected]>
* Fix code and link to dag processor
* Apply suggestions from code review
Co-authored-by: Jed Cunningham <[email protected]>
---------
Co-authored-by: Felix Uellendall <[email protected]>
Co-authored-by: Jed Cunningham <[email protected]>
---
airflow/dag_processing/bundles/base.py | 10 ++
airflow/dag_processing/bundles/git.py | 107 ++++++++++--
airflow/dag_processing/bundles/manager.py | 1 +
airflow/dag_processing/bundles/provider.yaml | 44 +++++
airflow/dag_processing/manager.py | 4 +
airflow/providers_manager.py | 5 +
tests/dag_processing/test_dag_bundles.py | 237 +++++++++++++++++++++++----
7 files changed, 361 insertions(+), 47 deletions(-)
diff --git a/airflow/dag_processing/bundles/base.py
b/airflow/dag_processing/bundles/base.py
index ea560f1be26..cf0467b372a 100644
--- a/airflow/dag_processing/bundles/base.py
+++ b/airflow/dag_processing/bundles/base.py
@@ -50,6 +50,16 @@ class BaseDagBundle(ABC):
self.name = name
self.version = version
self.refresh_interval = refresh_interval
+ self.is_initialized: bool = False
+
+ def initialize(self) -> None:
+ """
+ Initialize the bundle.
+
+ This method is called by the DAG processor before the bundle is used,
+ and allows for deferring expensive operations until that point in time.
+ """
+ self.is_initialized = True
@property
def _dag_bundle_root_storage_path(self) -> Path:
diff --git a/airflow/dag_processing/bundles/git.py
b/airflow/dag_processing/bundles/git.py
index d731f65db3b..4b2a19de364 100644
--- a/airflow/dag_processing/bundles/git.py
+++ b/airflow/dag_processing/bundles/git.py
@@ -17,8 +17,9 @@
from __future__ import annotations
+import json
import os
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from git import Repo
@@ -26,63 +27,141 @@ from git.exc import BadName
from airflow.dag_processing.bundles.base import BaseDagBundle
from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+from airflow.utils.log.logging_mixin import LoggingMixin
if TYPE_CHECKING:
from pathlib import Path
-class GitDagBundle(BaseDagBundle):
+class GitHook(BaseHook):
+ """
+ Hook for git repositories.
+
+ :param git_conn_id: Connection ID for SSH connection to the repository
+
+ """
+
+ conn_name_attr = "git_conn_id"
+ default_conn_name = "git_default"
+ conn_type = "git"
+ hook_name = "GIT"
+
+ @classmethod
+ def get_ui_field_behaviour(cls) -> dict[str, Any]:
+ return {
+ "hidden_fields": ["schema"],
+ "relabeling": {
+ "login": "Username",
+ "host": "Repository URL",
+ "password": "Access Token (optional)",
+ },
+ "placeholders": {
+ "extra": json.dumps(
+ {
+ "key_file": "optional/path/to/keyfile",
+ }
+ )
+ },
+ }
+
+ def __init__(self, git_conn_id="git_default", *args, **kwargs):
+ super().__init__()
+ connection = self.get_connection(git_conn_id)
+ self.repo_url = connection.host
+ self.auth_token = connection.password
+ self.key_file = connection.extra_dejson.get("key_file")
+ self.env: dict[str, str] = {}
+ if self.key_file:
+ self.env["GIT_SSH_COMMAND"] = f"ssh -i {self.key_file} -o
IdentitiesOnly=yes"
+ self._process_git_auth_url()
+
+ def _process_git_auth_url(self):
+ if not isinstance(self.repo_url, str):
+ return
+ if self.auth_token and self.repo_url.startswith("https://"):
+ self.repo_url = self.repo_url.replace("https://",
f"https://{self.auth_token}@")
+ elif not self.repo_url.startswith("git@") or not
self.repo_url.startswith("https://"):
+ self.repo_url = os.path.expanduser(self.repo_url)
+
+
+class GitDagBundle(BaseDagBundle, LoggingMixin):
"""
git DAG bundle - exposes a git repository as a DAG bundle.
Instead of cloning the repository every time, we clone the repository once
into a bare repo from the source
and then do a clone for each version from there.
- :param repo_url: URL of the git repository
:param tracking_ref: Branch or tag for this DAG bundle
:param subdir: Subdirectory within the repository where the DAGs are
stored (Optional)
+ :param git_conn_id: Connection ID for SSH/token based connection to the
repository (Optional)
"""
supports_versioning = True
- def __init__(self, *, repo_url: str, tracking_ref: str, subdir: str | None
= None, **kwargs) -> None:
+ def __init__(
+ self,
+ *,
+ tracking_ref: str,
+ subdir: str | None = None,
+ git_conn_id: str = "git_default",
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
- self.repo_url = repo_url
self.tracking_ref = tracking_ref
self.subdir = subdir
-
self.bare_repo_path = self._dag_bundle_root_storage_path / "git" /
self.name
self.repo_path = (
self._dag_bundle_root_storage_path / "git" / (self.name +
f"+{self.version or self.tracking_ref}")
)
+ self.git_conn_id = git_conn_id
+ self.hook = GitHook(git_conn_id=self.git_conn_id)
+ self.repo_url = self.hook.repo_url
+
+ def _initialize(self):
self._clone_bare_repo_if_required()
self._ensure_version_in_bare_repo()
self._clone_repo_if_required()
self.repo.git.checkout(self.tracking_ref)
-
if self.version:
if not self._has_version(self.repo, self.version):
self.repo.remotes.origin.fetch()
-
self.repo.head.set_reference(self.repo.commit(self.version))
self.repo.head.reset(index=True, working_tree=True)
else:
self.refresh()
+ def initialize(self) -> None:
+ if not self.repo_url:
+ raise AirflowException(f"Connection {self.git_conn_id} doesn't
have a git_repo_url")
+ if isinstance(self.repo_url, os.PathLike):
+ self._initialize()
+ elif not self.repo_url.startswith("git@") or not
self.repo_url.endswith(".git"):
+ raise AirflowException(
+ f"Invalid git URL: {self.repo_url}. URL must start with git@
and end with .git"
+ )
+ else:
+ self._initialize()
+ super().initialize()
+
def _clone_repo_if_required(self) -> None:
if not os.path.exists(self.repo_path):
+ self.log.info("Cloning repository to %s from %s", self.repo_path,
self.bare_repo_path)
Repo.clone_from(
url=self.bare_repo_path,
to_path=self.repo_path,
)
+
self.repo = Repo(self.repo_path)
def _clone_bare_repo_if_required(self) -> None:
if not os.path.exists(self.bare_repo_path):
+ self.log.info("Cloning bare repository to %s", self.bare_repo_path)
Repo.clone_from(
url=self.repo_url,
to_path=self.bare_repo_path,
bare=True,
+ env=self.hook.env,
)
self.bare_repo = Repo(self.bare_repo_path)
@@ -90,7 +169,7 @@ class GitDagBundle(BaseDagBundle):
if not self.version:
return
if not self._has_version(self.bare_repo, self.version):
- self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
+ self._fetch_bare_repo()
if not self._has_version(self.bare_repo, self.version):
raise AirflowException(f"Version {self.version} not found in
the repository")
@@ -121,11 +200,17 @@ class GitDagBundle(BaseDagBundle):
except BadName:
return False
+ def _fetch_bare_repo(self):
+ if self.hook.env:
+ with
self.bare_repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")):
+
self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
+ else:
+ self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
+
def refresh(self) -> None:
if self.version:
raise AirflowException("Refreshing a specific version is not
supported")
-
- self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
+ self._fetch_bare_repo()
self.repo.remotes.origin.pull()
def _convert_git_ssh_url_to_https(self) -> str:
diff --git a/airflow/dag_processing/bundles/manager.py
b/airflow/dag_processing/bundles/manager.py
index 1ae751f8d33..ad1ebc58891 100644
--- a/airflow/dag_processing/bundles/manager.py
+++ b/airflow/dag_processing/bundles/manager.py
@@ -96,6 +96,7 @@ class DagBundlesManager(LoggingMixin):
class_ = import_string(cfg["classpath"])
kwargs = cfg["kwargs"]
self._bundle_config[name] = (class_, kwargs)
+ self.log.info("DAG bundles loaded: %s", ",
".join(self._bundle_config.keys()))
@provide_session
def sync_bundles_to_db(self, *, session: Session = NEW_SESSION) -> None:
diff --git a/airflow/dag_processing/bundles/provider.yaml
b/airflow/dag_processing/bundles/provider.yaml
new file mode 100644
index 00000000000..9ca5d1479f2
--- /dev/null
+++ b/airflow/dag_processing/bundles/provider.yaml
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+---
+package-name: apache-airflow-providers-bundles
+name: GIT
+description: |
+ `GIT <https://git-scm.com/>`__
+
+state: not-ready
+source-date-epoch: 1726861127
+# note that those versions are maintained by release manager - do not update
them manually
+versions:
+ - 1.0.0
+
+dependencies:
+ - apache-airflow-providers-ssh
+
+integrations:
+ - integration-name: GIT (Git)
+
+hooks:
+ - integration-name: GIT
+ python-modules:
+ - airflow.dag_processing.bundles.git
+
+
+connection-types:
+ - hook-class-name: airflow.dag_processing.bundles.git.GitHook
+ connection-type: git
diff --git a/airflow/dag_processing/manager.py
b/airflow/dag_processing/manager.py
index 96c7fe4f0ed..220b55edce6 100644
--- a/airflow/dag_processing/manager.py
+++ b/airflow/dag_processing/manager.py
@@ -653,6 +653,10 @@ class DagFileProcessorManager:
self.log.info("Refreshing DAG bundles")
for bundle in self._dag_bundles:
+ # TODO: AIP-66 handle errors in the case of incomplete cloning?
And test this.
+ # What if the cloning/refreshing took too long(longer than the
dag processor timeout)
+ if not bundle.is_initialized:
+ bundle.initialize()
# TODO: AIP-66 test to make sure we get a fresh record from the db
and it's not cached
with create_session() as session:
bundle_model = session.get(DagBundleModel, bundle.name)
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 575306a840b..9b39439384f 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -175,6 +175,9 @@ def
_create_customized_form_field_behaviours_schema_validator():
def _check_builtin_provider_prefix(provider_package: str, class_name: str) ->
bool:
+ if "bundles" in provider_package:
+ # TODO: AIP-66: remove this when this package is moved to providers
directory
+ return True
if provider_package.startswith("apache-airflow"):
provider_path = provider_package[len("apache-") :].replace("-", ".")
if not class_name.startswith(provider_path):
@@ -676,6 +679,8 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
self._add_provider_info_from_local_source_files_on_path(path)
except Exception as e:
log.warning("Error when loading 'provider.yaml' files from %s
airflow sources: %s", path, e)
+ # TODO: AIP-66: Remove this when the package is moved to providers
+
self._add_provider_info_from_local_source_files_on_path("airflow/dag_processing")
def _add_provider_info_from_local_source_files_on_path(self, path) -> None:
"""
diff --git a/tests/dag_processing/test_dag_bundles.py
b/tests/dag_processing/test_dag_bundles.py
index d450a561313..49b7da1a03a 100644
--- a/tests/dag_processing/test_dag_bundles.py
+++ b/tests/dag_processing/test_dag_bundles.py
@@ -26,11 +26,16 @@ from git import Repo
from airflow.dag_processing.bundles.base import BaseDagBundle
from airflow.dag_processing.bundles.dagfolder import DagsFolderDagBundle
-from airflow.dag_processing.bundles.git import GitDagBundle
+from airflow.dag_processing.bundles.git import GitDagBundle, GitHook
from airflow.dag_processing.bundles.local import LocalDagBundle
from airflow.exceptions import AirflowException
+from airflow.models import Connection
+from airflow.utils import db
from tests_common.test_utils.config import conf_vars
+from tests_common.test_utils.db import clear_db_connections
+
+pytestmark = pytest.mark.db_test
@pytest.fixture(autouse=True)
@@ -107,27 +112,111 @@ def git_repo(tmp_path_factory):
return (directory, repo)
+AIRFLOW_HTTPS_URL = "https://github.com/apache/airflow.git"
+AIRFLOW_GIT = "[email protected]:apache/airflow.git"
+ACCESS_TOKEN = "my_access_token"
+CONN_DEFAULT = "git_default"
+CONN_HTTPS = "my_git_conn"
+CONN_HTTPS_PASSWORD = "my_git_conn_https_password"
+CONN_ONLY_PATH = "my_git_conn_only_path"
+CONN_NO_REPO_URL = "my_git_conn_no_repo_url"
+
+
+class TestGitHook:
+ @classmethod
+ def teardown_class(cls) -> None:
+ clear_db_connections()
+
+ @classmethod
+ def setup_class(cls) -> None:
+ db.merge_conn(
+ Connection(
+ conn_id=CONN_DEFAULT,
+ host=AIRFLOW_GIT,
+ conn_type="git",
+ )
+ )
+ db.merge_conn(
+ Connection(
+ conn_id=CONN_HTTPS,
+ host=AIRFLOW_HTTPS_URL,
+ password=ACCESS_TOKEN,
+ conn_type="git",
+ )
+ )
+ db.merge_conn(
+ Connection(
+ conn_id=CONN_HTTPS_PASSWORD,
+ host=AIRFLOW_HTTPS_URL,
+ conn_type="git",
+ password=ACCESS_TOKEN,
+ )
+ )
+ db.merge_conn(
+ Connection(
+ conn_id=CONN_ONLY_PATH,
+ host="path/to/repo",
+ conn_type="git",
+ )
+ )
+
+ @pytest.mark.parametrize(
+ "conn_id, expected_repo_url",
+ [
+ (CONN_DEFAULT, AIRFLOW_GIT),
+ (CONN_HTTPS,
f"https://{ACCESS_TOKEN}@github.com/apache/airflow.git"),
+ (CONN_HTTPS_PASSWORD,
f"https://{ACCESS_TOKEN}@github.com/apache/airflow.git"),
+ (CONN_ONLY_PATH, "path/to/repo"),
+ ],
+ )
+ def test_correct_repo_urls(self, conn_id, expected_repo_url):
+ hook = GitHook(git_conn_id=conn_id)
+ assert hook.repo_url == expected_repo_url
+
+
class TestGitDagBundle:
+ @classmethod
+ def teardown_class(cls) -> None:
+ clear_db_connections()
+
+ @classmethod
+ def setup_class(cls) -> None:
+ db.merge_conn(
+ Connection(
+ conn_id="git_default",
+ host="[email protected]:apache/airflow.git",
+ conn_type="git",
+ )
+ )
+ db.merge_conn(
+ Connection(
+ conn_id=CONN_NO_REPO_URL,
+ conn_type="git",
+ )
+ )
+
def test_supports_versioning(self):
assert GitDagBundle.supports_versioning is True
def test_uses_dag_bundle_root_storage_path(self, git_repo):
repo_path, repo = git_repo
- bundle = GitDagBundle(
- name="test", refresh_interval=300, repo_url=repo_path,
tracking_ref=GIT_DEFAULT_BRANCH
- )
+ bundle = GitDagBundle(name="test", refresh_interval=300,
tracking_ref=GIT_DEFAULT_BRANCH)
assert str(bundle._dag_bundle_root_storage_path) in str(bundle.path)
- def test_get_current_version(self, git_repo):
+ @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+ def test_get_current_version(self, mock_githook, git_repo):
repo_path, repo = git_repo
- bundle = GitDagBundle(
- name="test", refresh_interval=300, repo_url=repo_path,
tracking_ref=GIT_DEFAULT_BRANCH
- )
+ mock_githook.return_value.repo_url = repo_path
+ bundle = GitDagBundle(name="test", refresh_interval=300,
tracking_ref=GIT_DEFAULT_BRANCH)
+
+ bundle.initialize()
assert bundle.get_current_version() == repo.head.commit.hexsha
- def test_get_specific_version(self, git_repo):
+ @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+ def test_get_specific_version(self, mock_githook, git_repo):
repo_path, repo = git_repo
+ mock_githook.return_value.repo_url = repo_path
starting_commit = repo.head.commit
# Add new file to the repo
@@ -141,17 +230,19 @@ class TestGitDagBundle:
name="test",
refresh_interval=300,
version=starting_commit.hexsha,
- repo_url=repo_path,
tracking_ref=GIT_DEFAULT_BRANCH,
)
+ bundle.initialize()
assert bundle.get_current_version() == starting_commit.hexsha
files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
assert {"test_dag.py"} == files_in_repo
- def test_get_tag_version(self, git_repo):
+ @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+ def test_get_tag_version(self, mock_githook, git_repo):
repo_path, repo = git_repo
+ mock_githook.return_value.repo_url = repo_path
starting_commit = repo.head.commit
# add tag
@@ -169,17 +260,18 @@ class TestGitDagBundle:
name="test",
refresh_interval=300,
version="test",
- repo_url=repo_path,
tracking_ref=GIT_DEFAULT_BRANCH,
)
-
+ bundle.initialize()
assert bundle.get_current_version() == starting_commit.hexsha
files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
assert {"test_dag.py"} == files_in_repo
- def test_get_latest(self, git_repo):
+ @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+ def test_get_latest(self, mock_githook, git_repo):
repo_path, repo = git_repo
+ mock_githook.return_value.repo_url = repo_path
starting_commit = repo.head.commit
file_path = repo_path / "new_test.py"
@@ -188,22 +280,22 @@ class TestGitDagBundle:
repo.index.add([file_path])
repo.index.commit("Another commit")
- bundle = GitDagBundle(
- name="test", refresh_interval=300, repo_url=repo_path,
tracking_ref=GIT_DEFAULT_BRANCH
- )
+ bundle = GitDagBundle(name="test", refresh_interval=300,
tracking_ref=GIT_DEFAULT_BRANCH)
+ bundle.initialize()
assert bundle.get_current_version() != starting_commit.hexsha
files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
assert {"test_dag.py", "new_test.py"} == files_in_repo
- def test_refresh(self, git_repo):
+ @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+ def test_refresh(self, mock_githook, git_repo):
repo_path, repo = git_repo
+ mock_githook.return_value.repo_url = repo_path
starting_commit = repo.head.commit
- bundle = GitDagBundle(
- name="test", refresh_interval=300, repo_url=repo_path,
tracking_ref=GIT_DEFAULT_BRANCH
- )
+ bundle = GitDagBundle(name="test", refresh_interval=300,
tracking_ref=GIT_DEFAULT_BRANCH)
+ bundle.initialize()
assert bundle.get_current_version() == starting_commit.hexsha
@@ -223,27 +315,34 @@ class TestGitDagBundle:
files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
assert {"test_dag.py", "new_test.py"} == files_in_repo
- def test_head(self, git_repo):
+ @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+ def test_head(self, mock_githook, git_repo):
repo_path, repo = git_repo
+ mock_githook.return_value.repo_url = repo_path
repo.create_head("test")
- bundle = GitDagBundle(name="test", refresh_interval=300,
repo_url=repo_path, tracking_ref="test")
+ bundle = GitDagBundle(name="test", refresh_interval=300,
tracking_ref="test")
+ bundle.initialize()
assert bundle.repo.head.ref.name == "test"
- def test_version_not_found(self, git_repo):
+ @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+ def test_version_not_found(self, mock_githook, git_repo):
repo_path, repo = git_repo
+ mock_githook.return_value.repo_url = repo_path
+ bundle = GitDagBundle(
+ name="test",
+ refresh_interval=300,
+ version="not_found",
+ tracking_ref=GIT_DEFAULT_BRANCH,
+ )
with pytest.raises(AirflowException, match="Version not_found not
found in the repository"):
- GitDagBundle(
- name="test",
- refresh_interval=300,
- version="not_found",
- repo_url=repo_path,
- tracking_ref=GIT_DEFAULT_BRANCH,
- )
+ bundle.initialize()
- def test_subdir(self, git_repo):
+ @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+ def test_subdir(self, mock_githook, git_repo):
repo_path, repo = git_repo
+ mock_githook.return_value.repo_url = repo_path
subdir = "somesubdir"
subdir_path = repo_path / subdir
@@ -258,15 +357,75 @@ class TestGitDagBundle:
bundle = GitDagBundle(
name="test",
refresh_interval=300,
- repo_url=repo_path,
tracking_ref=GIT_DEFAULT_BRANCH,
subdir=subdir,
)
+ bundle.initialize()
files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
assert str(bundle.path).endswith(subdir)
assert {"some_new_file.py"} == files_in_repo
+ def test_raises_when_no_repo_url(self):
+ bundle = GitDagBundle(
+ name="test",
+ refresh_interval=300,
+ git_conn_id=CONN_NO_REPO_URL,
+ tracking_ref=GIT_DEFAULT_BRANCH,
+ )
+ with pytest.raises(
+ AirflowException, match=f"Connection {CONN_NO_REPO_URL} doesn't
have a git_repo_url"
+ ):
+ bundle.initialize()
+
+ @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+ @mock.patch("airflow.dag_processing.bundles.git.Repo")
+ def test_with_path_as_repo_url(self, mock_gitRepo, mock_githook):
+ bundle = GitDagBundle(
+ name="test",
+ refresh_interval=300,
+ git_conn_id=CONN_ONLY_PATH,
+ tracking_ref=GIT_DEFAULT_BRANCH,
+ )
+ bundle.initialize()
+ assert mock_gitRepo.clone_from.call_count == 2
+ assert mock_gitRepo.return_value.git.checkout.call_count == 1
+
+ @mock.patch("airflow.dag_processing.bundles.git.Repo")
+ def test_refresh_with_git_connection(self, mock_gitRepo):
+ bundle = GitDagBundle(
+ name="test",
+ refresh_interval=300,
+ git_conn_id="git_default",
+ tracking_ref=GIT_DEFAULT_BRANCH,
+ )
+ bundle.initialize()
+ bundle.refresh()
+ # check remotes called twice. one at initialize and one at refresh
above
+ assert mock_gitRepo.return_value.remotes.origin.fetch.call_count == 2
+
+ @pytest.mark.parametrize(
+ "repo_url",
+ [
+ pytest.param("https://github.com/apache/airflow", id="https_url"),
+ pytest.param("airflow@example:apache/airflow.git",
id="does_not_start_with_git_at"),
+ pytest.param("git@example:apache/airflow",
id="does_not_end_with_dot_git"),
+ ],
+ )
+ @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+ def test_repo_url_validation_for_ssh(self, mock_hook, repo_url, session):
+ mock_hook.return_value.repo_url = repo_url
+ bundle = GitDagBundle(
+ name="test",
+ refresh_interval=300,
+ git_conn_id="git_default",
+ tracking_ref=GIT_DEFAULT_BRANCH,
+ )
+ with pytest.raises(
+ AirflowException, match=f"Invalid git URL: {repo_url}. URL must
start with git@ and end with .git"
+ ):
+ bundle.initialize()
+
@pytest.mark.parametrize(
"repo_url, expected_url",
[
@@ -280,11 +439,18 @@ class TestGitDagBundle:
],
)
@mock.patch("airflow.dag_processing.bundles.git.Repo")
- def test_view_url(self, mock_gitrepo, repo_url, expected_url):
+ def test_view_url(self, mock_gitrepo, repo_url, expected_url, session):
+ session.query(Connection).delete()
+ conn = Connection(
+ conn_id="git_default",
+ host=repo_url,
+ conn_type="git",
+ )
+ session.add(conn)
+ session.commit()
bundle = GitDagBundle(
name="test",
refresh_interval=300,
- repo_url=repo_url,
tracking_ref="main",
)
view_url = bundle.view_url("0f0f0f")
@@ -295,7 +461,6 @@ class TestGitDagBundle:
bundle = GitDagBundle(
name="test",
refresh_interval=300,
- repo_url="[email protected]:apache/airflow.git",
tracking_ref="main",
)
view_url = bundle.view_url(None)