jedcunningham commented on code in PR #44976:
URL: https://github.com/apache/airflow/pull/44976#discussion_r1915069817
##########
airflow/dag_processing/bundles/git.py:
##########
@@ -121,12 +197,32 @@ def _has_version(repo: Repo, version: str) -> bool:
except BadName:
return False
+ def _fetch_repo(self):
+ if self.hook.env:
+ with
self.repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")):
+ self.repo.remotes.origin.fetch()
+ else:
+ self.repo.remotes.origin.fetch()
+
+ def _fetch_bare_repo(self):
+ if self.hook.env:
+ with
self.bare_repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")):
+
self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
+ else:
+ self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
+
+ def _pull_repo(self):
+ if self.hook.env:
+ with
self.repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")):
Review Comment:
And that means this isn't needed here either, since upstream in this case will
be the local bare repo. It also means we don't really need a helper for it, imo.
##########
tests/dag_processing/test_dag_bundles.py:
##########
@@ -107,27 +112,113 @@ def git_repo(tmp_path_factory):
return (directory, repo)
+AIRFLOW_HTTPS_URL = "https://github.com/apache/airflow.git"
+AIRFLOW_GIT = "git@github.com:apache/airflow.git"
+ACCESS_TOKEN = "my_access_token"
+CONN_DEFAULT = "git_default"
+CONN_HTTPS = "my_git_conn"
+CONN_HTTPS_PASSWORD = "my_git_conn_https_password"
+CONN_ONLY_PATH = "my_git_conn_only_path"
+CONN_NO_REPO_URL = "my_git_conn_no_repo_url"
+
+
+class TestGitHook:
+ @classmethod
+ def teardown_class(cls) -> None:
+ clear_db_connections()
+
+ @classmethod
+ def setup_class(cls) -> None:
+ db.merge_conn(
+ Connection(
+ conn_id=CONN_DEFAULT,
+ host=AIRFLOW_GIT,
+ conn_type="git",
+ )
+ )
+ db.merge_conn(
+ Connection(
+ conn_id=CONN_HTTPS,
+ host=AIRFLOW_HTTPS_URL,
+ password=ACCESS_TOKEN,
+ conn_type="git",
+ )
+ )
+ db.merge_conn(
+ Connection(
+ conn_id=CONN_HTTPS_PASSWORD,
+ host=AIRFLOW_HTTPS_URL,
+ conn_type="git",
+ password=ACCESS_TOKEN,
+ )
+ )
+ db.merge_conn(
+ Connection(
+ conn_id=CONN_ONLY_PATH,
+ host="path/to/repo",
+ conn_type="git",
+ )
+ )
+
+ @pytest.mark.parametrize(
+ "conn_id, expected_repo_url",
+ [
+ (CONN_DEFAULT, AIRFLOW_GIT),
+ (CONN_HTTPS,
f"https://{ACCESS_TOKEN}@github.com/apache/airflow.git"),
+ (CONN_HTTPS_PASSWORD,
f"https://{ACCESS_TOKEN}@github.com/apache/airflow.git"),
+ (CONN_ONLY_PATH, "path/to/repo"),
+ ],
+ )
+ def test_correct_repo_urls(self, conn_id, expected_repo_url):
+ hook = GitHook(git_conn_id=conn_id)
+ assert hook.repo_url == expected_repo_url
+
+
class TestGitDagBundle:
+ @classmethod
+ def teardown_class(cls) -> None:
+ clear_db_connections()
+
+ @classmethod
+ def setup_class(cls) -> None:
+ db.merge_conn(
+ Connection(
+ conn_id="git_default",
+ host="git@github.com:apache/airflow.git",
+ conn_type="git",
+ )
+ )
+ db.merge_conn(
+ Connection(
+ conn_id=CONN_NO_REPO_URL,
+ conn_type="git",
+ )
+ )
+
def test_supports_versioning(self):
assert GitDagBundle.supports_versioning is True
def test_uses_dag_bundle_root_storage_path(self, git_repo):
repo_path, repo = git_repo
- bundle = GitDagBundle(
- name="test", refresh_interval=300, repo_url=repo_path,
tracking_ref=GIT_DEFAULT_BRANCH
- )
+ bundle = GitDagBundle(name="test", refresh_interval=300,
tracking_ref=GIT_DEFAULT_BRANCH)
assert str(bundle._dag_bundle_root_storage_path) in str(bundle.path)
- def test_get_current_version(self, git_repo):
+ @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+ def test_get_current_version(self, mock_githook, git_repo):
+ mock_githook.get_conn.return_value = mock.MagicMock()
Review Comment:
```suggestion
```
This is the default behavior, so there is no need to do it explicitly.
##########
tests/dag_processing/test_dag_bundles.py:
##########
@@ -258,15 +365,78 @@ def test_subdir(self, git_repo):
bundle = GitDagBundle(
name="test",
refresh_interval=300,
- repo_url=repo_path,
tracking_ref=GIT_DEFAULT_BRANCH,
subdir=subdir,
)
+ bundle.initialize()
files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
assert str(bundle.path).endswith(subdir)
assert {"some_new_file.py"} == files_in_repo
+ def test_raises_when_no_repo_url(self):
+ bundle = GitDagBundle(
+ name="test",
+ refresh_interval=300,
+ git_conn_id=CONN_NO_REPO_URL,
+ tracking_ref=GIT_DEFAULT_BRANCH,
+ )
+ with pytest.raises(
+ AirflowException, match=f"Connection {CONN_NO_REPO_URL} doesn't
have a git_repo_url"
+ ):
+ bundle.initialize()
+
+ @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+ @mock.patch("airflow.dag_processing.bundles.git.Repo")
+ @mock.patch.object(GitDagBundle, "_clone_from")
+ def test_with_path_as_repo_url(self, mock_clone_from, mock_gitRepo,
mock_githook):
+ bundle = GitDagBundle(
+ name="test",
+ refresh_interval=300,
+ git_conn_id=CONN_ONLY_PATH,
+ tracking_ref=GIT_DEFAULT_BRANCH,
+ )
+ bundle.initialize()
+ assert mock_clone_from.call_count == 2
+ assert mock_gitRepo.return_value.git.checkout.call_count == 1
+
+ @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+ @mock.patch("airflow.dag_processing.bundles.git.Repo")
+ def test_refresh_with_git_connection(self, mock_gitRepo, mock_hook):
+ bundle = GitDagBundle(
+ name="test",
+ refresh_interval=300,
+ git_conn_id="git_default",
+ tracking_ref=GIT_DEFAULT_BRANCH,
+ )
+ bundle.initialize()
+ bundle.refresh()
+ # check remotes called twice. one at initialize and one at refresh
above
+ assert mock_gitRepo.return_value.remotes.origin.fetch.call_count == 2
+
+ @pytest.mark.parametrize(
+ "repo_url",
+ [
+ pytest.param("https://github.com/apache/airflow", id="https_url"),
+ pytest.param("airflow@example:apache/airflow.git",
id="does_not_start_with_git_at"),
+ pytest.param("git@example:apache/airflow",
id="does_not_end_with_dot_git"),
+ ],
+ )
+ @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+ def test_repo_url_starts_with_git_when_using_ssh_conn_id(self, mock_hook,
repo_url, session):
Review Comment:
```suggestion
def test_repo_url_validation_for_ssh(self, mock_hook, repo_url, session):
```
or something
##########
airflow/dag_processing/bundles/git.py:
##########
@@ -17,70 +17,146 @@
from __future__ import annotations
+import json
import os
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from git import Repo
from git.exc import BadName
from airflow.dag_processing.bundles.base import BaseDagBundle
from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+from airflow.utils.log.logging_mixin import LoggingMixin
if TYPE_CHECKING:
from pathlib import Path
-class GitDagBundle(BaseDagBundle):
+class GitHook(BaseHook):
+ """
+ Hook for git repositories.
+
+ :param git_conn_id: Connection ID for SSH connection to the repository
+
+ """
+
+ conn_name_attr = "git_conn_id"
+ default_conn_name = "git_default"
+ conn_type = "git"
+ hook_name = "GIT"
+
+ @classmethod
+ def get_ui_field_behaviour(cls) -> dict[str, Any]:
+ return {
+ "hidden_fields": ["schema"],
+ "relabeling": {
+ "login": "Username",
+ "host": "Repository URL",
+ "password": "Access Token (optional)",
+ },
+ "placeholders": {
+ "extra": json.dumps(
+ {
+ "key_file": "optional/path/to/keyfile",
+ }
+ )
+ },
+ }
+
+ def __init__(self, git_conn_id="git_default", *args, **kwargs):
+ super().__init__()
+ connection = self.get_connection(git_conn_id)
+ self.repo_url = connection.host
+ self.auth_token = connection.password
+ self.key_file = connection.extra_dejson.get("key_file")
+ self.env: dict[str, str] = {}
+ if self.key_file:
+ self.env["GIT_SSH_COMMAND"] = f"ssh -i {self.key_file} -o
IdentitiesOnly=yes"
+ self._process_git_auth_url()
+
+ def _process_git_auth_url(self):
+ if not isinstance(self.repo_url, str):
+ return
+ if self.auth_token and self.repo_url.startswith("https://"):
+ self.repo_url = self.repo_url.replace("https://",
f"https://{self.auth_token}@")
+ elif not self.repo_url.startswith("git@") or not
self.repo_url.startswith("https://"):
+ self.repo_url = os.path.expanduser(self.repo_url)
+
+
+class GitDagBundle(BaseDagBundle, LoggingMixin):
"""
git DAG bundle - exposes a git repository as a DAG bundle.
Instead of cloning the repository every time, we clone the repository once
into a bare repo from the source
and then do a clone for each version from there.
- :param repo_url: URL of the git repository
:param tracking_ref: Branch or tag for this DAG bundle
:param subdir: Subdirectory within the repository where the DAGs are
stored (Optional)
+ :param git_conn_id: Connection ID for SSH/token based connection to the
repository (Optional)
"""
supports_versioning = True
- def __init__(self, *, repo_url: str, tracking_ref: str, subdir: str | None
= None, **kwargs) -> None:
+ def __init__(
+ self,
+ *,
+ tracking_ref: str,
+ subdir: str | None = None,
+ git_conn_id: str = "git_default",
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
- self.repo_url = repo_url
self.tracking_ref = tracking_ref
self.subdir = subdir
-
self.bare_repo_path = self._dag_bundle_root_storage_path / "git" /
self.name
self.repo_path = (
self._dag_bundle_root_storage_path / "git" / (self.name +
f"+{self.version or self.tracking_ref}")
)
+ self.git_conn_id = git_conn_id
+ self.hook = GitHook(git_conn_id=self.git_conn_id)
+ self.repo_url = self.hook.repo_url
+
+ def _clone_from(self, to_path: Path, bare: bool = False) -> Repo:
+ self.log.info("Cloning %s to %s", self.repo_url, to_path)
+ return Repo.clone_from(self.repo_url, to_path, bare=bare,
env=self.hook.env)
+
+ def _initialize(self):
self._clone_bare_repo_if_required()
self._ensure_version_in_bare_repo()
self._clone_repo_if_required()
self.repo.git.checkout(self.tracking_ref)
-
if self.version:
if not self._has_version(self.repo, self.version):
- self.repo.remotes.origin.fetch()
-
+ self._fetch_repo()
self.repo.head.set_reference(self.repo.commit(self.version))
self.repo.head.reset(index=True, working_tree=True)
else:
self.refresh()
+ def initialize(self) -> None:
+ if not self.repo_url:
+ raise AirflowException(f"Connection {self.git_conn_id} doesn't
have a git_repo_url")
+ if isinstance(self.repo_url, os.PathLike):
+ self._initialize()
+ elif not self.repo_url.startswith("git@") or not
self.repo_url.endswith(".git"):
+ raise AirflowException(
+ f"Invalid git URL: {self.repo_url}. URL must start with git@
and end with .git"
+ )
+ else:
+ self._initialize()
+
def _clone_repo_if_required(self) -> None:
if not os.path.exists(self.repo_path):
- Repo.clone_from(
- url=self.bare_repo_path,
Review Comment:
We do want to clone from the bare repo for this one, so you'll want to pass
`url` as well.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org