This is an automated email from the ASF dual-hosted git repository.

sbp pushed a commit to branch sbp
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git


The following commit(s) were added to refs/heads/sbp by this push:
     new 1e01d75  Check out GitHub trees by commit, not branch
1e01d75 is described below

commit 1e01d7562cfcddb5b187238668b8b7166cac6c31
Author: Sean B. Palmer <[email protected]>
AuthorDate: Thu Feb 5 15:04:36 2026 +0000

    Check out GitHub trees by commit, not branch
---
 atr/tasks/checks/compare.py       |  51 +++++++------
 tests/unit/test_checks_compare.py | 151 ++++++++++++++++++++++++++++++++++++--
 2 files changed, 176 insertions(+), 26 deletions(-)

diff --git a/atr/tasks/checks/compare.py b/atr/tasks/checks/compare.py
index c77c241..8f1464c 100644
--- a/atr/tasks/checks/compare.py
+++ b/atr/tasks/checks/compare.py
@@ -16,17 +16,22 @@
 # under the License.
 
 import asyncio
+import contextlib
 import json
 import os
 import pathlib
 import shutil
 import time
+from collections.abc import Mapping
 from typing import Any, Final
 
 import aiofiles
 import aiofiles.os
-import dulwich.objectspec as objectspec
-import dulwich.porcelain as porcelain
+import dulwich.client
+import dulwich.objects
+import dulwich.objectspec
+import dulwich.porcelain
+import dulwich.refs
 import pydantic
 
 import atr.attestable as attestable
@@ -40,6 +45,18 @@ _DEFAULT_EMAIL: Final[str] = "atr@localhost"
 _DEFAULT_USER: Final[str] = "atr"
 
 
+class DetermineWantsForSha:
+    def __init__(self, sha: str) -> None:
+        self.sha = sha
+
+    def __call__(
+        self,
+        refs: Mapping[dulwich.refs.Ref, dulwich.objects.ObjectID],
+        depth: int | None = None,
+    ) -> list[dulwich.objects.ObjectID]:
+        return [dulwich.objects.ObjectID(self.sha.encode("ascii"))]
+
+
 async def source_trees(args: checks.FunctionArguments) -> results.Results | None:
     recorder = await args.recorder()
     is_source = await recorder.primary_path_is_source()
@@ -79,17 +96,15 @@ async def _checkout_github_source(
     payload: github_models.TrustedPublisherPayload, checkout_dir: pathlib.Path
 ) -> str | None:
     repo_url = f"https://github.com/{payload.repository}.git"
-    branch = _ref_to_branch(payload.ref)
     started_ns = time.perf_counter_ns()
     try:
-        await asyncio.to_thread(_clone_repo, repo_url, payload.sha, branch, checkout_dir)
+        await asyncio.to_thread(_clone_repo, repo_url, payload.sha, checkout_dir)
     except Exception:
         elapsed_ms = (time.perf_counter_ns() - started_ns) / 1_000_000.0
         log.exception(
             "Failed to clone GitHub repo for compare.source_trees",
             repo_url=repo_url,
             sha=payload.sha,
-            branch=branch,
             checkout_dir=str(checkout_dir),
             clone_ms=elapsed_ms,
             git_author_name=os.environ.get("GIT_AUTHOR_NAME"),
@@ -106,24 +121,24 @@ async def _checkout_github_source(
         "Cloned GitHub repo for compare.source_trees",
         repo_url=repo_url,
         sha=payload.sha,
-        branch=branch,
         checkout_dir=str(checkout_dir),
         clone_ms=elapsed_ms,
     )
     return str(checkout_dir)
 
 
-def _clone_repo(repo_url: str, sha: str, branch: str | None, checkout_dir: pathlib.Path) -> None:
+def _clone_repo(repo_url: str, sha: str, checkout_dir: pathlib.Path) -> None:
     _ensure_clone_identity_env()
-    repo = porcelain.clone(
-        repo_url,
-        str(checkout_dir),
-        checkout=True,
-        depth=1,
-        branch=branch,
-    )
+    repo = dulwich.porcelain.init(str(checkout_dir))
+    git_client, path = dulwich.client.get_transport_and_path(repo_url, operation="pull")
     try:
-        commit = objectspec.parse_commit(repo, sha)
+        determine_wants = DetermineWantsForSha(sha)
+        git_client.fetch(path, repo, determine_wants=determine_wants, depth=1)
+    finally:
+        with contextlib.suppress(Exception):
+            git_client.close()
+    try:
+        commit = dulwich.objectspec.parse_commit(repo, sha)
         repo.get_worktree().reset_index(tree=commit.tree)
     except (KeyError, ValueError) as exc:
         raise RuntimeError(f"Commit {sha} not found in shallow clone") from exc
@@ -170,9 +185,3 @@ def _payload_summary(payload: github_models.TrustedPublisherPayload | None) -> d
         "actor": payload.actor,
         "actor_id": payload.actor_id,
     }
-
-
-def _ref_to_branch(ref: str) -> str | None:
-    if ref.startswith("refs/heads/"):
-        return ref.removeprefix("refs/heads/")
-    return None
diff --git a/tests/unit/test_checks_compare.py b/tests/unit/test_checks_compare.py
index 4bd6f5c..b57ee97 100644
--- a/tests/unit/test_checks_compare.py
+++ b/tests/unit/test_checks_compare.py
@@ -16,8 +16,11 @@
 # under the License.
 
 import pathlib
+from collections.abc import Callable, Mapping
 
 import aiofiles.os
+import dulwich.objects
+import dulwich.refs
 import pytest
 
 import atr.sbom.models.github
@@ -44,10 +47,62 @@ class CheckoutRecorder:
 
 class CloneRecorder:
     def __init__(self) -> None:
-        self.calls: list[tuple[str, str, str | None, pathlib.Path]] = []
+        self.calls: list[tuple[str, str, pathlib.Path]] = []
 
-    def __call__(self, repo_url: str, sha: str, branch: str | None, checkout_dir: pathlib.Path) -> None:
-        self.calls.append((repo_url, sha, branch, checkout_dir))
+    def __call__(self, repo_url: str, sha: str, checkout_dir: pathlib.Path) -> None:
+        self.calls.append((repo_url, sha, checkout_dir))
+
+
+class CommitStub:
+    def __init__(self, tree: object) -> None:
+        self.tree = tree
+
+
+class GitClientStub:
+    def __init__(self) -> None:
+        self.closed = False
+        self.fetch_calls: list[tuple[str, object, int | None]] = []
+        self.wants: list[dulwich.objects.ObjectID] | None = None
+
+    def fetch(
+        self,
+        path: str,
+        repo: object,
+        determine_wants: Callable[
+            [Mapping[dulwich.refs.Ref, dulwich.objects.ObjectID], int | None],
+            list[dulwich.objects.ObjectID],
+        ],
+        depth: int | None = None,
+    ) -> None:
+        self.fetch_calls.append((path, repo, depth))
+        if self.wants is None:
+            self.wants = determine_wants({}, depth)
+
+    def close(self) -> None:
+        self.closed = True
+
+
+class InitRecorder:
+    def __init__(self, repo: object) -> None:
+        self.calls: list[str] = []
+        self.repo = repo
+
+    def __call__(self, path: str) -> object:
+        self.calls.append(path)
+        return self.repo
+
+
+class ParseCommitRecorder:
+    def __init__(self, commit: CommitStub, raise_exc: Exception | None = None) -> None:
+        self.calls: list[tuple[object, str]] = []
+        self.commit = commit
+        self.raise_exc = raise_exc
+
+    def __call__(self, repo: object, sha: str) -> CommitStub:
+        self.calls.append((repo, sha))
+        if self.raise_exc is not None:
+            raise self.raise_exc
+        return self.commit
 
 
 class PayloadLoader:
@@ -101,6 +156,18 @@ class RecorderStub(atr.tasks.checks.Recorder):
         return self._is_source
 
 
+class RepoStub:
+    def __init__(self, controldir: pathlib.Path, worktree: object) -> None:
+        self._controldir = controldir
+        self._worktree = worktree
+
+    def controldir(self) -> str:
+        return str(self._controldir)
+
+    def get_worktree(self) -> object:
+        return self._worktree
+
+
 class ReturnValue:
     def __init__(self, value: pathlib.Path) -> None:
         self.value = value
@@ -109,6 +176,25 @@ class ReturnValue:
         return self.value
 
 
+class TransportRecorder:
+    def __init__(self, client: object, path: str) -> None:
+        self.calls: list[tuple[str, str | None]] = []
+        self.client = client
+        self.path = path
+
+    def __call__(self, repo_url: str, operation: str | None = None) -> tuple[object, str]:
+        self.calls.append((repo_url, operation))
+        return self.client, self.path
+
+
+class WorktreeStub:
+    def __init__(self) -> None:
+        self.reset_calls: list[object] = []
+
+    def reset_index(self, tree: object | None = None) -> None:
+        self.reset_calls.append(tree)
+
+
 @pytest.mark.asyncio
 async def test_checkout_github_source_uses_provided_dir(
     monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path
@@ -123,13 +209,68 @@ async def test_checkout_github_source_uses_provided_dir(
 
     assert result == str(checkout_dir)
     assert len(clone_recorder.calls) == 1
-    repo_url, sha, branch, called_dir = clone_recorder.calls[0]
+    repo_url, sha, called_dir = clone_recorder.calls[0]
     assert repo_url == "https://github.com/apache/test.git"
     assert sha == "0000000000000000000000000000000000000000"
-    assert branch == "main"
     assert called_dir == checkout_dir
 
 
+def test_clone_repo_fetches_requested_sha(monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path) -> None:
+    checkout_dir = tmp_path / "checkout"
+    git_dir = checkout_dir / ".git"
+    git_dir.mkdir(parents=True)
+    worktree = WorktreeStub()
+    repo = RepoStub(git_dir, worktree)
+    init_recorder = InitRecorder(repo)
+    git_client = GitClientStub()
+    transport = TransportRecorder(git_client, "remote-path")
+    tree_marker = object()
+    parse_commit = ParseCommitRecorder(CommitStub(tree_marker))
+    sha = "0000000000000000000000000000000000000000"
+
+    monkeypatch.setattr(atr.tasks.checks.compare.dulwich.porcelain, "init", init_recorder)
+    monkeypatch.setattr(atr.tasks.checks.compare.dulwich.client, "get_transport_and_path", transport)
+    monkeypatch.setattr(atr.tasks.checks.compare.dulwich.objectspec, "parse_commit", parse_commit)
+
+    atr.tasks.checks.compare._clone_repo("https://github.com/apache/test.git", sha, checkout_dir)
+
+    assert init_recorder.calls == [str(checkout_dir)]
+    assert transport.calls == [("https://github.com/apache/test.git", "pull")]
+    assert git_client.fetch_calls == [("remote-path", repo, 1)]
+    assert git_client.wants == [dulwich.objects.ObjectID(sha.encode("ascii"))]
+    assert parse_commit.calls == [(repo, sha)]
+    assert worktree.reset_calls == [tree_marker]
+    assert not git_dir.exists()
+    assert git_client.closed is True
+
+
+def test_clone_repo_raises_when_commit_missing(monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path) -> None:
+    checkout_dir = tmp_path / "checkout"
+    git_dir = checkout_dir / ".git"
+    git_dir.mkdir(parents=True)
+    worktree = WorktreeStub()
+    repo = RepoStub(git_dir, worktree)
+    init_recorder = InitRecorder(repo)
+    git_client = GitClientStub()
+    transport = TransportRecorder(git_client, "remote-path")
+    parse_commit = ParseCommitRecorder(CommitStub(object()), raise_exc=KeyError("missing"))
+    sha = "1111111111111111111111111111111111111111"
+
+    monkeypatch.setattr(atr.tasks.checks.compare.dulwich.porcelain, "init", init_recorder)
+    monkeypatch.setattr(atr.tasks.checks.compare.dulwich.client, "get_transport_and_path", transport)
+    monkeypatch.setattr(atr.tasks.checks.compare.dulwich.objectspec, "parse_commit", parse_commit)
+
+    with pytest.raises(RuntimeError, match=r"Commit .* not found in shallow clone"):
+        atr.tasks.checks.compare._clone_repo("https://github.com/apache/test.git", sha, checkout_dir)
+
+    assert git_client.fetch_calls == [("remote-path", repo, 1)]
+    assert git_client.wants == [dulwich.objects.ObjectID(sha.encode("ascii"))]
+    assert parse_commit.calls == [(repo, sha)]
+    assert worktree.reset_calls == []
+    assert git_dir.exists()
+    assert git_client.closed is True
+
+
 @pytest.mark.asyncio
 async def test_source_trees_creates_temp_workspace_and_cleans_up(
     monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to