This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7273f90f74b Support tracking a tag and force pushes in the git bundle (#46983)
7273f90f74b is described below

commit 7273f90f74b7cc487483e40af204959608e75d4b
Author: Jed Cunningham <[email protected]>
AuthorDate: Sat Feb 22 18:35:58 2025 -0700

    Support tracking a tag and force pushes in the git bundle (#46983)
    
    - Tracking tags, while intended to be supported, weren't.
    - For branches, we were doing a naive pull, but this breaks if the repo had a force
    push. We instead do an explicit fetch and hard reset now.
---
 airflow/dag_processing/bundles/git.py    | 15 ++++++++--
 tests/dag_processing/bundles/test_git.py | 49 ++++++++++++++++++++++++++++++--
 2 files changed, 58 insertions(+), 6 deletions(-)

diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py
index e6dc68ea343..6fd7d815191 100644
--- a/airflow/dag_processing/bundles/git.py
+++ b/airflow/dag_processing/bundles/git.py
@@ -243,11 +243,12 @@ class GitDagBundle(BaseDagBundle, LoggingMixin):
             return False
 
     def _fetch_bare_repo(self):
+        refspecs = ["+refs/heads/*:refs/heads/*", "+refs/tags/*:refs/tags/*"]
         if self.hook.env:
             with 
self.bare_repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")):
-                
self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
+                self.bare_repo.remotes.origin.fetch(refspecs)
         else:
-            self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
+            self.bare_repo.remotes.origin.fetch(refspecs)
 
     def refresh(self) -> None:
         if self.version:
@@ -256,7 +257,15 @@ class GitDagBundle(BaseDagBundle, LoggingMixin):
         with self.lock():
             with self.hook.configure_hook_env():
                 self._fetch_bare_repo()
-                self.repo.remotes.origin.pull()
+                self.repo.remotes.origin.fetch(
+                    ["+refs/heads/*:refs/remotes/origin/*", 
"+refs/tags/*:refs/tags/*"]
+                )
+                remote_branch = f"origin/{self.tracking_ref}"
+                if remote_branch in [ref.name for ref in 
self.repo.remotes.origin.refs]:
+                    target = remote_branch
+                else:
+                    target = self.tracking_ref
+                self.repo.head.reset(target, index=True, working_tree=True)
 
     @staticmethod
     def _convert_git_ssh_url_to_https(url: str) -> str:
diff --git a/tests/dag_processing/bundles/test_git.py b/tests/dag_processing/bundles/test_git.py
index acbc84b4018..42a96a24f95 100644
--- a/tests/dag_processing/bundles/test_git.py
+++ b/tests/dag_processing/bundles/test_git.py
@@ -295,7 +295,6 @@ class TestGitDagBundle:
 
         # add tag
         repo.create_tag("test")
-        print(repo.tags)
 
         # Add new file to the repo
         file_path = repo_path / "new_test.py"
@@ -335,12 +334,24 @@ class TestGitDagBundle:
         files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
         assert {"test_dag.py", "new_test.py"} == files_in_repo
 
+    @pytest.mark.parametrize(
+        "amend",
+        [
+            True,
+            False,
+        ],
+    )
     @mock.patch("airflow.dag_processing.bundles.git.GitHook")
-    def test_refresh(self, mock_githook, git_repo):
+    def test_refresh(self, mock_githook, git_repo, amend):
+        """Ensure that the bundle refresh works when tracking a branch, with a 
new commit and amending the commit"""
         repo_path, repo = git_repo
         mock_githook.return_value.repo_url = repo_path
         starting_commit = repo.head.commit
 
+        with repo.config_writer() as writer:
+            writer.set_value("user", "name", "Test User")
+            writer.set_value("user", "email", "[email protected]")
+
         bundle = GitDagBundle(name="test", tracking_ref=GIT_DEFAULT_BRANCH)
         bundle.initialize()
 
@@ -349,12 +360,43 @@ class TestGitDagBundle:
         files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
         assert {"test_dag.py"} == files_in_repo
 
+        file_path = repo_path / "new_test.py"
+        with open(file_path, "w") as f:
+            f.write("hello world")
+        repo.index.add([file_path])
+        commit = repo.git.commit(amend=amend, message="Another commit")
+
+        bundle.refresh()
+
+        assert bundle.get_current_version()[:6] in commit
+
+        files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
+        assert {"test_dag.py", "new_test.py"} == files_in_repo
+
+    @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+    def test_refresh_tag(self, mock_githook, git_repo):
+        """Ensure that the bundle refresh works when tracking a tag"""
+        repo_path, repo = git_repo
+        mock_githook.return_value.repo_url = repo_path
+        starting_commit = repo.head.commit
+
+        # add tag
+        repo.create_tag("test123")
+
+        bundle = GitDagBundle(name="test", tracking_ref="test123")
+        bundle.initialize()
+        assert bundle.get_current_version() == starting_commit.hexsha
+
+        # Add new file to the repo
         file_path = repo_path / "new_test.py"
         with open(file_path, "w") as f:
             f.write("hello world")
         repo.index.add([file_path])
         commit = repo.index.commit("Another commit")
 
+        # update tag
+        repo.create_tag("test123", force=True)
+
         bundle.refresh()
 
         assert bundle.get_current_version() == commit.hexsha
@@ -440,8 +482,9 @@ class TestGitDagBundle:
             tracking_ref=GIT_DEFAULT_BRANCH,
         )
         bundle.initialize()
+        assert mock_gitRepo.return_value.remotes.origin.fetch.call_count == 2  
# 1 in bare, 1 in main repo
+        mock_gitRepo.return_value.remotes.origin.fetch.reset_mock()
         bundle.refresh()
-        # check remotes called twice. one at initialize and one at refresh 
above
         assert mock_gitRepo.return_value.remotes.origin.fetch.call_count == 2
 
     @pytest.mark.parametrize(

Reply via email to