jedcunningham commented on code in PR #44976:
URL: https://github.com/apache/airflow/pull/44976#discussion_r1895218790


##########
airflow/dag_processing/bundles/git.py:
##########
@@ -40,20 +41,36 @@ class GitDagBundle(BaseDagBundle):
     :param repo_url: URL of the git repository
     :param tracking_ref: Branch or tag for this DAG bundle
     :param subdir: Subdirectory within the repository where the DAGs are 
stored (Optional)
+    :param ssh_conn_id: Connection ID for SSH connection to the repository 
(Optional)
     """
 
     supports_versioning = True
 
-    def __init__(self, *, repo_url: str, tracking_ref: str, subdir: str | None 
= None, **kwargs) -> None:
+    def __init__(
+        self,
+        *,
+        repo_url: str,
+        tracking_ref: str,
+        subdir: str | None = None,
+        ssh_conn_kwargs: dict[str, str] | None = None,

Review Comment:
   ```suggestion
           ssh_hook_kwargs: dict[str, str] | None = None,
   ```
   
   These are really kwargs for `SSHHook`, so the parameter name should say so. The docstring is wrong too: it documents `ssh_conn_id` rather than this parameter.



##########
tests/dag_processing/test_dag_bundles.py:
##########
@@ -261,7 +268,53 @@ def test_subdir(self, git_repo):
             tracking_ref=GIT_DEFAULT_BRANCH,
             subdir=subdir,
         )
+        bundle.initialize()
 
         files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
         assert str(bundle.path).endswith(subdir)
         assert {"some_new_file.py"} == files_in_repo
+
+    @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook")
+    @mock.patch("airflow.dag_processing.bundles.git.Repo")
+    def test_with_ssh_conn_id(self, mock_gitRepo, mock_hook):
+        repo_url = "git@github.com:apache/airflow.git"
+        conn_id = "ssh_default"
+        bundle = GitDagBundle(
+            repo_url=repo_url,
+            name="test",
+            refresh_interval=300,
+            ssh_conn_kwargs={"ssh_conn_id": "ssh_default"},
+            tracking_ref=GIT_DEFAULT_BRANCH,
+        )
+        bundle.initialize()
+        mock_hook.assert_called_once_with(ssh_conn_id=conn_id)
+
+    @mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook")
+    @mock.patch("airflow.dag_processing.bundles.git.Repo")
+    def test_refresh_with_ssh_connection(self, mock_gitRepo, mock_hook):
+        repo_url = "git@github.com:apache/airflow.git"
+        bundle = GitDagBundle(
+            repo_url=repo_url,
+            name="test",
+            refresh_interval=300,
+            ssh_conn_kwargs={"ssh_conn_id": "ssh_default"},
+            tracking_ref=GIT_DEFAULT_BRANCH,
+        )
+        bundle.initialize()
+        bundle.refresh()
+        # check remotes called twice. one at initialize and one at refresh 
above
+        assert mock_gitRepo.return_value.remotes.origin.fetch.call_count == 2
+
+    def test_repo_url_starts_with_git_when_using_ssh_conn_id(self):
+        repo_url = "https://github.com/apache/airflow"

Review Comment:
   ```suggestion
       @pytest.mark.parametrize(
           "repo_url",
           [
               pytest.param("https://github.com/apache/airflow", id="https_url"),
               pytest.param("airflow@example:apache/airflow.git", id="does_not_start_with_git_at"),
               pytest.param("git@example:apache/airflow", id="does_not_end_with_dot_git"),
           ]
       )
       def test_repo_url_starts_with_git_when_using_ssh_conn_id(self, repo_url):
   ```
   
   Or similar. This better tests that we enforce that the URL both starts with `git@` and ends with `.git`.



##########
airflow/dag_processing/bundles/git.py:
##########
@@ -66,20 +83,37 @@ def __init__(self, *, repo_url: str, tracking_ref: str, 
subdir: str | None = Non
             self.repo.head.set_reference(self.repo.commit(self.version))
             self.repo.head.reset(index=True, working_tree=True)
         else:
-            self.refresh()
+            self._refresh()
+
+    def _ssh_hook(self):
+        try:
+            from airflow.providers.ssh.hooks.ssh import SSHHook
+        except ImportError as e:
+            raise AirflowOptionalProviderFeatureException(e)
+        return SSHHook(**self.ssh_conn_kwargs)
+
+    def initialize(self) -> None:
+        if self.ssh_conn_kwargs:
+            if not self.repo_url.startswith("git@") and not 
self.repo_url.endswith(".git"):

Review Comment:
   ```suggestion
               if not self.repo_url.startswith("git@") or not self.repo_url.endswith(".git"):
   ```
   
   We need to check each condition individually: with `and`, a URL that fails only one of the two checks is still accepted.



##########
airflow/dag_processing/bundles/git.py:
##########
@@ -120,9 +154,16 @@ def _has_version(repo: Repo, version: str) -> bool:
         except BadName:
             return False
 
+    def _refresh(self):
+        self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
+        self.repo.remotes.origin.pull()
+
     def refresh(self) -> None:
         if self.version:
             raise AirflowException("Refreshing a specific version is not 
supported")
 
-        self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
-        self.repo.remotes.origin.pull()
+        if self.ssh_hook:
+            with self.ssh_hook.get_conn():

Review Comment:
   Doesn't this create a session to the (a?) remote SSH host
[here](https://github.com/apache/airflow/blob/351ac28c7c4dde3a611e4260bbb0f01c7618f261/providers/src/airflow/providers/ssh/hooks/ssh.py#L339-L340)?
 I'm not sure we want that. I'm not even sure where this would connect, since the
remote host likely wouldn't match the `repo_url`.
   
   Unless we can somehow tell `git` to use that socket/session?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to