jedcunningham commented on code in PR #44976:
URL: https://github.com/apache/airflow/pull/44976#discussion_r1915069817


##########
airflow/dag_processing/bundles/git.py:
##########
@@ -121,12 +197,32 @@ def _has_version(repo: Repo, version: str) -> bool:
         except BadName:
             return False
 
+    def _fetch_repo(self):
+        if self.hook.env:
+            with 
self.repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")):
+                self.repo.remotes.origin.fetch()
+        else:
+            self.repo.remotes.origin.fetch()
+
+    def _fetch_bare_repo(self):
+        if self.hook.env:
+            with 
self.bare_repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")):
+                
self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
+        else:
+            self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
+
+    def _pull_repo(self):
+        if self.hook.env:
+            with 
self.repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")):

Review Comment:
   And that means this won't need the custom SSH environment either, since upstream in this
   case will be the local bare repo. It also means we don't really need a helper for it, imo.



##########
tests/dag_processing/test_dag_bundles.py:
##########
@@ -107,27 +112,113 @@ def git_repo(tmp_path_factory):
     return (directory, repo)
 
 
+AIRFLOW_HTTPS_URL = "https://github.com/apache/airflow.git"
+AIRFLOW_GIT = "git@github.com:apache/airflow.git"
+ACCESS_TOKEN = "my_access_token"
+CONN_DEFAULT = "git_default"
+CONN_HTTPS = "my_git_conn"
+CONN_HTTPS_PASSWORD = "my_git_conn_https_password"
+CONN_ONLY_PATH = "my_git_conn_only_path"
+CONN_NO_REPO_URL = "my_git_conn_no_repo_url"
+
+
+class TestGitHook:
+    @classmethod
+    def teardown_class(cls) -> None:
+        clear_db_connections()
+
+    @classmethod
+    def setup_class(cls) -> None:
+        db.merge_conn(
+            Connection(
+                conn_id=CONN_DEFAULT,
+                host=AIRFLOW_GIT,
+                conn_type="git",
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=CONN_HTTPS,
+                host=AIRFLOW_HTTPS_URL,
+                password=ACCESS_TOKEN,
+                conn_type="git",
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=CONN_HTTPS_PASSWORD,
+                host=AIRFLOW_HTTPS_URL,
+                conn_type="git",
+                password=ACCESS_TOKEN,
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=CONN_ONLY_PATH,
+                host="path/to/repo",
+                conn_type="git",
+            )
+        )
+
+    @pytest.mark.parametrize(
+        "conn_id, expected_repo_url",
+        [
+            (CONN_DEFAULT, AIRFLOW_GIT),
+            (CONN_HTTPS, 
f"https://{ACCESS_TOKEN}@github.com/apache/airflow.git"),
+            (CONN_HTTPS_PASSWORD, 
f"https://{ACCESS_TOKEN}@github.com/apache/airflow.git"),
+            (CONN_ONLY_PATH, "path/to/repo"),
+        ],
+    )
+    def test_correct_repo_urls(self, conn_id, expected_repo_url):
+        hook = GitHook(git_conn_id=conn_id)
+        assert hook.repo_url == expected_repo_url
+
+
 class TestGitDagBundle:
+    @classmethod
+    def teardown_class(cls) -> None:
+        clear_db_connections()
+
+    @classmethod
+    def setup_class(cls) -> None:
+        db.merge_conn(
+            Connection(
+                conn_id="git_default",
+                host="git@github.com:apache/airflow.git",
+                conn_type="git",
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=CONN_NO_REPO_URL,
+                conn_type="git",
+            )
+        )
+
     def test_supports_versioning(self):
         assert GitDagBundle.supports_versioning is True
 
     def test_uses_dag_bundle_root_storage_path(self, git_repo):
         repo_path, repo = git_repo
-        bundle = GitDagBundle(
-            name="test", refresh_interval=300, repo_url=repo_path, 
tracking_ref=GIT_DEFAULT_BRANCH
-        )
+        bundle = GitDagBundle(name="test", refresh_interval=300, 
tracking_ref=GIT_DEFAULT_BRANCH)
         assert str(bundle._dag_bundle_root_storage_path) in str(bundle.path)
 
-    def test_get_current_version(self, git_repo):
+    @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+    def test_get_current_version(self, mock_githook, git_repo):
+        mock_githook.get_conn.return_value = mock.MagicMock()

Review Comment:
   ```suggestion
   ```
   
   This is the default behavior, so there is no need to do it explicitly.



##########
tests/dag_processing/test_dag_bundles.py:
##########
@@ -258,15 +365,78 @@ def test_subdir(self, git_repo):
         bundle = GitDagBundle(
             name="test",
             refresh_interval=300,
-            repo_url=repo_path,
             tracking_ref=GIT_DEFAULT_BRANCH,
             subdir=subdir,
         )
+        bundle.initialize()
 
         files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
         assert str(bundle.path).endswith(subdir)
         assert {"some_new_file.py"} == files_in_repo
 
+    def test_raises_when_no_repo_url(self):
+        bundle = GitDagBundle(
+            name="test",
+            refresh_interval=300,
+            git_conn_id=CONN_NO_REPO_URL,
+            tracking_ref=GIT_DEFAULT_BRANCH,
+        )
+        with pytest.raises(
+            AirflowException, match=f"Connection {CONN_NO_REPO_URL} doesn't 
have a git_repo_url"
+        ):
+            bundle.initialize()
+
+    @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+    @mock.patch("airflow.dag_processing.bundles.git.Repo")
+    @mock.patch.object(GitDagBundle, "_clone_from")
+    def test_with_path_as_repo_url(self, mock_clone_from, mock_gitRepo, 
mock_githook):
+        bundle = GitDagBundle(
+            name="test",
+            refresh_interval=300,
+            git_conn_id=CONN_ONLY_PATH,
+            tracking_ref=GIT_DEFAULT_BRANCH,
+        )
+        bundle.initialize()
+        assert mock_clone_from.call_count == 2
+        assert mock_gitRepo.return_value.git.checkout.call_count == 1
+
+    @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+    @mock.patch("airflow.dag_processing.bundles.git.Repo")
+    def test_refresh_with_git_connection(self, mock_gitRepo, mock_hook):
+        bundle = GitDagBundle(
+            name="test",
+            refresh_interval=300,
+            git_conn_id="git_default",
+            tracking_ref=GIT_DEFAULT_BRANCH,
+        )
+        bundle.initialize()
+        bundle.refresh()
+        # check remotes called twice. one at initialize and one at refresh 
above
+        assert mock_gitRepo.return_value.remotes.origin.fetch.call_count == 2
+
+    @pytest.mark.parametrize(
+        "repo_url",
+        [
+            pytest.param("https://github.com/apache/airflow", id="https_url"),
+            pytest.param("airflow@example:apache/airflow.git", 
id="does_not_start_with_git_at"),
+            pytest.param("git@example:apache/airflow", 
id="does_not_end_with_dot_git"),
+        ],
+    )
+    @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+    def test_repo_url_starts_with_git_when_using_ssh_conn_id(self, mock_hook, 
repo_url, session):

Review Comment:
   ```suggestion
       def test_repo_url_validation_for_ssh(self, mock_hook, repo_url, session):
   ```
   
   or something



##########
airflow/dag_processing/bundles/git.py:
##########
@@ -17,70 +17,146 @@
 
 from __future__ import annotations
 
+import json
 import os
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 from urllib.parse import urlparse
 
 from git import Repo
 from git.exc import BadName
 
 from airflow.dag_processing.bundles.base import BaseDagBundle
 from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
     from pathlib import Path
 
 
-class GitDagBundle(BaseDagBundle):
+class GitHook(BaseHook):
+    """
+    Hook for git repositories.
+
+    :param git_conn_id: Connection ID for SSH connection to the repository
+
+    """
+
+    conn_name_attr = "git_conn_id"
+    default_conn_name = "git_default"
+    conn_type = "git"
+    hook_name = "GIT"
+
+    @classmethod
+    def get_ui_field_behaviour(cls) -> dict[str, Any]:
+        return {
+            "hidden_fields": ["schema"],
+            "relabeling": {
+                "login": "Username",
+                "host": "Repository URL",
+                "password": "Access Token (optional)",
+            },
+            "placeholders": {
+                "extra": json.dumps(
+                    {
+                        "key_file": "optional/path/to/keyfile",
+                    }
+                )
+            },
+        }
+
+    def __init__(self, git_conn_id="git_default", *args, **kwargs):
+        super().__init__()
+        connection = self.get_connection(git_conn_id)
+        self.repo_url = connection.host
+        self.auth_token = connection.password
+        self.key_file = connection.extra_dejson.get("key_file")
+        self.env: dict[str, str] = {}
+        if self.key_file:
+            self.env["GIT_SSH_COMMAND"] = f"ssh -i {self.key_file} -o 
IdentitiesOnly=yes"
+        self._process_git_auth_url()
+
+    def _process_git_auth_url(self):
+        if not isinstance(self.repo_url, str):
+            return
+        if self.auth_token and self.repo_url.startswith("https://"):
+            self.repo_url = self.repo_url.replace("https://", 
f"https://{self.auth_token}@")
+        elif not self.repo_url.startswith("git@") or not 
self.repo_url.startswith("https://"):
+            self.repo_url = os.path.expanduser(self.repo_url)
+
+
+class GitDagBundle(BaseDagBundle, LoggingMixin):
     """
     git DAG bundle - exposes a git repository as a DAG bundle.
 
     Instead of cloning the repository every time, we clone the repository once 
into a bare repo from the source
     and then do a clone for each version from there.
 
-    :param repo_url: URL of the git repository
     :param tracking_ref: Branch or tag for this DAG bundle
     :param subdir: Subdirectory within the repository where the DAGs are 
stored (Optional)
+    :param git_conn_id: Connection ID for SSH/token based connection to the 
repository (Optional)
     """
 
     supports_versioning = True
 
-    def __init__(self, *, repo_url: str, tracking_ref: str, subdir: str | None 
= None, **kwargs) -> None:
+    def __init__(
+        self,
+        *,
+        tracking_ref: str,
+        subdir: str | None = None,
+        git_conn_id: str = "git_default",
+        **kwargs,
+    ) -> None:
         super().__init__(**kwargs)
-        self.repo_url = repo_url
         self.tracking_ref = tracking_ref
         self.subdir = subdir
-
         self.bare_repo_path = self._dag_bundle_root_storage_path / "git" / 
self.name
         self.repo_path = (
             self._dag_bundle_root_storage_path / "git" / (self.name + 
f"+{self.version or self.tracking_ref}")
         )
+        self.git_conn_id = git_conn_id
+        self.hook = GitHook(git_conn_id=self.git_conn_id)
+        self.repo_url = self.hook.repo_url
+
+    def _clone_from(self, to_path: Path, bare: bool = False) -> Repo:
+        self.log.info("Cloning %s to %s", self.repo_url, to_path)
+        return Repo.clone_from(self.repo_url, to_path, bare=bare, 
env=self.hook.env)
+
+    def _initialize(self):
         self._clone_bare_repo_if_required()
         self._ensure_version_in_bare_repo()
         self._clone_repo_if_required()
         self.repo.git.checkout(self.tracking_ref)
-
         if self.version:
             if not self._has_version(self.repo, self.version):
-                self.repo.remotes.origin.fetch()
-
+                self._fetch_repo()
             self.repo.head.set_reference(self.repo.commit(self.version))
             self.repo.head.reset(index=True, working_tree=True)
         else:
             self.refresh()
 
+    def initialize(self) -> None:
+        if not self.repo_url:
+            raise AirflowException(f"Connection {self.git_conn_id} doesn't 
have a git_repo_url")
+        if isinstance(self.repo_url, os.PathLike):
+            self._initialize()
+        elif not self.repo_url.startswith("git@") or not 
self.repo_url.endswith(".git"):
+            raise AirflowException(
+                f"Invalid git URL: {self.repo_url}. URL must start with git@ 
and end with .git"
+            )
+        else:
+            self._initialize()
+
     def _clone_repo_if_required(self) -> None:
         if not os.path.exists(self.repo_path):
-            Repo.clone_from(
-                url=self.bare_repo_path,

Review Comment:
   We do want to clone from the bare repo for this one, so you'll want to pass 
`url` as well.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to