This is an automated email from the ASF dual-hosted git repository.

hanahmily pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git


The following commit(s) were added to refs/heads/main by this push:
     new 7ba863118 Fix backup restore path traversal (#1135)
7ba863118 is described below

commit 7ba863118469076a3d34e8a1aa2657fe9286f792
Author: Tanay Paul <[email protected]>
AuthorDate: Thu May 21 06:18:52 2026 +0530

    Fix backup restore path traversal (#1135)
---
 banyand/backup/restore.go         |  35 +++--
 banyand/backup/restore_test.go    |  97 ++++++++++++
 pkg/fs/remote/local/local.go      |  94 +++++++++++-
 pkg/fs/remote/local/local_test.go | 307 ++++++++++++++++++++++++++++++++++++++
 pkg/path/path.go                  |  14 ++
 pkg/path/path_test.go             |  36 +++++
 6 files changed, 567 insertions(+), 16 deletions(-)

diff --git a/banyand/backup/restore.go b/banyand/backup/restore.go
index 93a3de4b1..dd5db44b8 100644
--- a/banyand/backup/restore.go
+++ b/banyand/backup/restore.go
@@ -23,6 +23,7 @@ import (
        "fmt"
        "io"
        "os"
+       "path"
        "path/filepath"
        "strings"
 
@@ -193,13 +194,12 @@ func restoreByName(fs remote.FS, timeDir, rootPath, catalogName string) error {
        logger.Infof("Restoring %s to %s from %s, remote total %d files", catalogName, localDir, remotePrefix, len(remoteFiles))
 
        remoteRelSet := make(map[string]bool)
-       var relPath string
        for _, remoteFile := range remoteFiles {
-               relPath, err = filepath.Rel(timeDir, remoteFile)
-               if err != nil {
-                       return fmt.Errorf("failed to get relative path for %s: %w", remoteFile, err)
+               relPath, relPathErr := validatedRemoteRelPath(timeDir, catalogName, remoteFile)
+               if relPathErr != nil {
+                       return relPathErr
                }
-               remoteRelSet[filepath.ToSlash(relPath)] = true
+               remoteRelSet[path.Join(catalogName, relPath)] = true
        }
 
        localFiles, err := getAllFiles(localDir)
@@ -208,7 +208,7 @@ func restoreByName(fs remote.FS, timeDir, rootPath, catalogName string) error {
        }
 
        for _, localRelPath := range localFiles {
-               localRelPathWithCatalog := filepath.Join(catalogName, localRelPath)
+               localRelPathWithCatalog := path.Join(catalogName, filepath.ToSlash(localRelPath))
                if !remoteRelSet[localRelPathWithCatalog] {
                        localPath := filepath.Join(localDir, localRelPath)
                        logger.Infof("found local file: %s not exist in the remote storage, so delete it", localRelPathWithCatalog)
@@ -220,11 +220,10 @@ func restoreByName(fs remote.FS, timeDir, rootPath, catalogName string) error {
        }
 
        for _, remoteFile := range remoteFiles {
-               relPath, err := filepath.Rel(filepath.Join(timeDir, catalogName), remoteFile)
+               relPath, err := validatedRemoteRelPath(timeDir, catalogName, remoteFile)
                if err != nil {
-                       return fmt.Errorf("failed to get relative path for %s: %w", remoteFile, err)
+                       return err
                }
-               relPath = filepath.ToSlash(relPath)
                localPath := filepath.Join(rootPath, catalogName, storage.DataDir, relPath)
 
                if !contains(localFiles, relPath) {
@@ -244,6 +243,24 @@ func restoreByName(fs remote.FS, timeDir, rootPath, catalogName string) error {
        return nil
 }
 
+func validatedRemoteRelPath(timeDir, catalogName, remoteFile string) (string, error) {
+       remotePath := filepath.ToSlash(remoteFile)
+       if path.IsAbs(remotePath) || banyandbpath.HasVolumeName(remotePath) {
+               return "", fmt.Errorf("remote file %q escapes backup prefix", remoteFile)
+       }
+       prefix := path.Clean(path.Join(filepath.ToSlash(timeDir), filepath.ToSlash(catalogName)))
+       cleanRemotePath := path.Clean(remotePath)
+       if banyandbpath.HasVolumeName(cleanRemotePath) || cleanRemotePath == "." || prefix == "." {
+               return "", fmt.Errorf("remote file %q escapes backup prefix", remoteFile)
+       }
+       prefixWithSlash := prefix + "/"
+       if !strings.HasPrefix(cleanRemotePath, prefixWithSlash) {
+               return "", fmt.Errorf("remote file %q escapes backup prefix", remoteFile)
+       }
+       relPath := strings.TrimPrefix(cleanRemotePath, prefixWithSlash)
+       return relPath, nil
+}
+
 func cleanEmptyDirs(dir, stopDir string) {
        for {
                if dir == stopDir || dir == "." {
diff --git a/banyand/backup/restore_test.go b/banyand/backup/restore_test.go
index 086f18a23..ca76f984e 100644
--- a/banyand/backup/restore_test.go
+++ b/banyand/backup/restore_test.go
@@ -19,6 +19,7 @@ package backup
 
 import (
        "context"
+       "io"
        "os"
        "path/filepath"
        "strings"
@@ -192,3 +193,99 @@ func TestRestoreSame(t *testing.T) {
                t.Fatalf("expected extra file %q exist", extraFilePath)
        }
 }
+
+func TestRestoreRejectsRemotePathTraversal(t *testing.T) {
+       timeDir := testTimeDir
+       catalogName := snapshot.CatalogName(commonv1.Catalog_CATALOG_STREAM)
+       localRestoreDir := t.TempDir()
+       escapedFile := filepath.Join(localRestoreDir, catalogName, 
"escaped.txt")
+       fs := &restoreTraversalFS{
+               files: []string{
+                       filepath.ToSlash(filepath.Join(timeDir, catalogName, 
"..", "escaped.txt")),
+               },
+       }
+
+       err := restoreByName(fs, timeDir, localRestoreDir, catalogName)
+       if err == nil {
+               t.Fatal("expected restoreByName to reject remote path 
traversal")
+       }
+       if _, statErr := os.Stat(escapedFile); !os.IsNotExist(statErr) {
+               t.Fatalf("escaped file exists or stat failed with unexpected 
error: %v", statErr)
+       }
+}
+
+func TestValidatedRemoteRelPath(t *testing.T) {
+       timeDir := testTimeDir
+       catalogName := snapshot.CatalogName(commonv1.Catalog_CATALOG_STREAM)
+       validRemoteFile := filepath.ToSlash(filepath.Join(timeDir, catalogName, 
"nested", "test.txt"))
+
+       relPath, err := validatedRemoteRelPath(timeDir, catalogName, 
validRemoteFile)
+       if err != nil {
+               t.Fatalf("validatedRemoteRelPath failed: %v", err)
+       }
+       if relPath != "nested/test.txt" {
+               t.Fatalf("relPath = %q, want %q", relPath, "nested/test.txt")
+       }
+}
+
+func TestValidatedRemoteRelPathRejectsInvalidPaths(t *testing.T) {
+       timeDir := testTimeDir
+       catalogName := snapshot.CatalogName(commonv1.Catalog_CATALOG_STREAM)
+       tests := []struct {
+               remoteFile string
+               name       string
+       }{
+               {
+                       name:       "absolute path",
+                       remoteFile: "/tmp/backup/file.txt",
+               },
+               {
+                       name:       "catalog prefix only",
+                       remoteFile: filepath.ToSlash(filepath.Join(timeDir, 
catalogName)),
+               },
+               {
+                       name:       "outside catalog prefix",
+                       remoteFile: filepath.ToSlash(filepath.Join(timeDir, 
"measure", "test.txt")),
+               },
+               {
+                       name:       "parent traversal",
+                       remoteFile: filepath.ToSlash(filepath.Join(timeDir, 
catalogName, "..", "escaped.txt")),
+               },
+               {
+                       name:       "volume name",
+                       remoteFile: `C:/backup/file.txt`,
+               },
+       }
+
+       for _, testCase := range tests {
+               t.Run(testCase.name, func(t *testing.T) {
+                       if _, err := validatedRemoteRelPath(timeDir, 
catalogName, testCase.remoteFile); err == nil {
+                               t.Fatal("expected invalid remote path to be 
rejected")
+                       }
+               })
+       }
+}
+
+type restoreTraversalFS struct {
+       files []string
+}
+
+func (r *restoreTraversalFS) Upload(_ context.Context, _ string, _ io.Reader) 
error {
+       return nil
+}
+
+func (r *restoreTraversalFS) Download(_ context.Context, _ string) 
(io.ReadCloser, error) {
+       return io.NopCloser(strings.NewReader("escape")), nil
+}
+
+func (r *restoreTraversalFS) List(_ context.Context, _ string) ([]string, 
error) {
+       return r.files, nil
+}
+
+func (r *restoreTraversalFS) Delete(_ context.Context, _ string) error {
+       return nil
+}
+
+func (r *restoreTraversalFS) Close() error {
+       return nil
+}
diff --git a/pkg/fs/remote/local/local.go b/pkg/fs/remote/local/local.go
index 01f204552..176e18e2f 100644
--- a/pkg/fs/remote/local/local.go
+++ b/pkg/fs/remote/local/local.go
@@ -20,11 +20,14 @@ package local
 
 import (
        "context"
+       "fmt"
        "io"
        "os"
        "path/filepath"
+       "strings"
 
        "github.com/apache/skywalking-banyandb/pkg/fs/remote"
+       pathutil "github.com/apache/skywalking-banyandb/pkg/path"
 )
 
 const dirPerm = 0o755
@@ -40,14 +43,28 @@ func NewFS(baseDir string) (remote.FS, error) {
        if err := os.MkdirAll(baseDir, dirPerm); err != nil {
                return nil, err
        }
-       return &fs{baseDir: baseDir}, nil
+       cleanBaseDir, err := filepath.Abs(baseDir)
+       if err != nil {
+               return nil, err
+       }
+       realBaseDir, err := filepath.EvalSymlinks(cleanBaseDir)
+       if err != nil {
+               return nil, err
+       }
+       return &fs{baseDir: realBaseDir}, nil
 }
 
 func (l *fs) Upload(_ context.Context, path string, data io.Reader) error {
-       fullPath := filepath.Join(l.baseDir, path)
-       if err := os.MkdirAll(filepath.Dir(fullPath), dirPerm); err != nil {
+       fullPath, err := l.fullPath(path, false)
+       if err != nil {
                return err
        }
+       if mkdirErr := os.MkdirAll(filepath.Dir(fullPath), dirPerm); mkdirErr != nil {
+               return mkdirErr
+       }
+       if err = l.ensureResolvedWithinBase(filepath.Dir(fullPath)); err != nil {
+               return fmt.Errorf("path %q escapes base directory: %w", path, err)
+       }
 
        file, err := os.Create(fullPath)
        if err != nil {
@@ -60,15 +77,21 @@ func (l *fs) Upload(_ context.Context, path string, data io.Reader) error {
 }
 
 func (l *fs) Download(_ context.Context, path string) (io.ReadCloser, error) {
-       fullPath := filepath.Join(l.baseDir, path)
+       fullPath, err := l.fullPath(path, false)
+       if err != nil {
+               return nil, err
+       }
        return os.Open(fullPath)
 }
 
 func (l *fs) List(_ context.Context, prefix string) ([]string, error) {
        var files []string
-       fullPath := filepath.Join(l.baseDir, prefix)
+       fullPath, err := l.fullPath(prefix, true)
+       if err != nil {
+               return nil, err
+       }
 
-       err := filepath.Walk(fullPath, func(path string, info os.FileInfo, err error) error {
+       err = filepath.Walk(fullPath, func(path string, info os.FileInfo, err error) error {
                if err != nil {
                        return err
                }
@@ -89,10 +112,67 @@ func (l *fs) List(_ context.Context, prefix string) ([]string, error) {
 }
 
 func (l *fs) Delete(_ context.Context, path string) error {
-       fullPath := filepath.Join(l.baseDir, path)
+       fullPath, err := l.fullPath(path, false)
+       if err != nil {
+               return err
+       }
        return os.Remove(fullPath)
 }
 
 func (l *fs) Close() error {
        return nil
 }
+
+func (l *fs) fullPath(path string, allowRoot bool) (string, error) {
+       if filepath.IsAbs(path) || pathutil.HasVolumeName(path) {
+               return "", fmt.Errorf("path %q escapes base directory", path)
+       }
+       cleanPath := filepath.Clean(path)
+       if !allowRoot && cleanPath == "." {
+               return "", fmt.Errorf("path %q escapes base directory", path)
+       }
+       if pathutil.HasVolumeName(cleanPath) || cleanPath == ".." || strings.HasPrefix(cleanPath, ".."+string(filepath.Separator)) {
+               return "", fmt.Errorf("path %q escapes base directory", path)
+       }
+       fullPath := filepath.Join(l.baseDir, cleanPath)
+       relPath, err := filepath.Rel(l.baseDir, fullPath)
+       if err != nil {
+               return "", err
+       }
+       if relPath == ".." || strings.HasPrefix(relPath, ".."+string(filepath.Separator)) || filepath.IsAbs(relPath) {
+               return "", fmt.Errorf("path %q escapes base directory", path)
+       }
+       if err := l.ensureResolvedWithinBase(fullPath); err != nil {
+               return "", fmt.Errorf("path %q escapes base directory: %w", path, err)
+       }
+       return fullPath, nil
+}
+
+func (l *fs) ensureResolvedWithinBase(path string) error {
+       existingPath := path
+       for {
+               if _, err := os.Lstat(existingPath); err == nil {
+                       break
+               } else if !os.IsNotExist(err) {
+                       return err
+               }
+               parentPath := filepath.Dir(existingPath)
+               if parentPath == existingPath {
+                       return os.ErrNotExist
+               }
+               existingPath = parentPath
+       }
+
+       realPath, err := filepath.EvalSymlinks(existingPath)
+       if err != nil {
+               return err
+       }
+       relPath, err := filepath.Rel(l.baseDir, realPath)
+       if err != nil {
+               return err
+       }
+       if relPath == ".." || strings.HasPrefix(relPath, ".."+string(filepath.Separator)) || filepath.IsAbs(relPath) {
+               return fmt.Errorf("resolved path %q is outside base directory %q", realPath, l.baseDir)
+       }
+       return nil
+}
diff --git a/pkg/fs/remote/local/local_test.go b/pkg/fs/remote/local/local_test.go
new file mode 100644
index 000000000..8ad28f1ca
--- /dev/null
+++ b/pkg/fs/remote/local/local_test.go
@@ -0,0 +1,307 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package local
+
+import (
+       "context"
+       "errors"
+       "io"
+       "os"
+       "path/filepath"
+       "strings"
+       "testing"
+)
+
+func TestFSOperationsStayWithinBase(t *testing.T) {
+       baseDir := filepath.Join(t.TempDir(), "remote")
+       fs, err := NewFS(baseDir)
+       if err != nil {
+               t.Fatalf("NewFS failed: %v", err)
+       }
+
+       const content = "hello"
+       filePath := filepath.Join("snapshot", "data", "test.txt")
+       if err = fs.Upload(context.Background(), filePath, 
strings.NewReader(content)); err != nil {
+               t.Fatalf("Upload failed: %v", err)
+       }
+
+       files, err := fs.List(context.Background(), "snapshot")
+       if err != nil {
+               t.Fatalf("List failed: %v", err)
+       }
+       if len(files) != 1 || files[0] != filepath.ToSlash(filePath) {
+               t.Fatalf("files = %v, want [%s]", files, 
filepath.ToSlash(filePath))
+       }
+
+       reader, err := fs.Download(context.Background(), filePath)
+       if err != nil {
+               t.Fatalf("Download failed: %v", err)
+       }
+       got, err := io.ReadAll(reader)
+       closeErr := reader.Close()
+       if err != nil {
+               t.Fatalf("failed to read downloaded content: %v", err)
+       }
+       if closeErr != nil {
+               t.Fatalf("failed to close downloaded content: %v", closeErr)
+       }
+       if string(got) != content {
+               t.Fatalf("content = %q, want %q", string(got), content)
+       }
+
+       if err = fs.Delete(context.Background(), filePath); err != nil {
+               t.Fatalf("Delete failed: %v", err)
+       }
+       if _, statErr := os.Stat(filepath.Join(baseDir, filePath)); 
!errors.Is(statErr, os.ErrNotExist) {
+               t.Fatalf("deleted file exists or stat failed with unexpected 
error: %v", statErr)
+       }
+}
+
+func TestFSListMissingPrefixReturnsEmpty(t *testing.T) {
+       fs, err := NewFS(t.TempDir())
+       if err != nil {
+               t.Fatalf("NewFS failed: %v", err)
+       }
+
+       files, err := fs.List(context.Background(), "missing")
+       if err != nil {
+               t.Fatalf("List failed: %v", err)
+       }
+       if len(files) != 0 {
+               t.Fatalf("files = %v, want empty", files)
+       }
+}
+
+func TestFSListEmptyPrefixReturnsFiles(t *testing.T) {
+       baseDir := filepath.Join(t.TempDir(), "remote")
+       fs, err := NewFS(baseDir)
+       if err != nil {
+               t.Fatalf("NewFS failed: %v", err)
+       }
+
+       const filePath = "snapshot/data/test.txt"
+       if err = fs.Upload(context.Background(), filePath, 
strings.NewReader("hello")); err != nil {
+               t.Fatalf("Upload failed: %v", err)
+       }
+
+       files, err := fs.List(context.Background(), "")
+       if err != nil {
+               t.Fatalf("List failed: %v", err)
+       }
+       if len(files) != 1 || files[0] != filePath {
+               t.Fatalf("files = %v, want [%s]", files, filePath)
+       }
+}
+
+func TestFSRejectsRootFileOperationPaths(t *testing.T) {
+       fs, err := NewFS(t.TempDir())
+       if err != nil {
+               t.Fatalf("NewFS failed: %v", err)
+       }
+
+       tests := []struct {
+               run  func(string) error
+               name string
+               path string
+       }{
+               {
+                       name: "upload empty",
+                       path: "",
+                       run: func(path string) error {
+                               return fs.Upload(context.Background(), path, 
strings.NewReader("root"))
+                       },
+               },
+               {
+                       name: "upload dot",
+                       path: ".",
+                       run: func(path string) error {
+                               return fs.Upload(context.Background(), path, 
strings.NewReader("root"))
+                       },
+               },
+               {
+                       name: "download empty",
+                       path: "",
+                       run: func(path string) error {
+                               reader, downloadErr := 
fs.Download(context.Background(), path)
+                               if reader != nil {
+                                       reader.Close()
+                               }
+                               return downloadErr
+                       },
+               },
+               {
+                       name: "delete dot",
+                       path: ".",
+                       run: func(path string) error {
+                               return fs.Delete(context.Background(), path)
+                       },
+               },
+       }
+
+       for _, testCase := range tests {
+               t.Run(testCase.name, func(t *testing.T) {
+                       if err := testCase.run(testCase.path); err == nil {
+                               t.Fatal("expected root file operation path to 
be rejected")
+                       }
+               })
+       }
+}
+
+func TestFSRejectsPathTraversal(t *testing.T) {
+       baseDir := filepath.Join(t.TempDir(), "remote")
+       fs, err := NewFS(baseDir)
+       if err != nil {
+               t.Fatalf("NewFS failed: %v", err)
+       }
+
+       escapedPath := filepath.Join("..", "escaped.txt")
+       tests := []struct {
+               run  func() error
+               name string
+       }{
+               {
+                       name: "upload",
+                       run: func() error {
+                               return fs.Upload(context.Background(), 
escapedPath, strings.NewReader("escape"))
+                       },
+               },
+               {
+                       name: "download",
+                       run: func() error {
+                               reader, downloadErr := 
fs.Download(context.Background(), escapedPath)
+                               if reader != nil {
+                                       reader.Close()
+                               }
+                               return downloadErr
+                       },
+               },
+               {
+                       name: "list",
+                       run: func() error {
+                               _, listErr := fs.List(context.Background(), 
escapedPath)
+                               return listErr
+                       },
+               },
+               {
+                       name: "delete",
+                       run: func() error {
+                               return fs.Delete(context.Background(), 
escapedPath)
+                       },
+               },
+       }
+
+       for _, testCase := range tests {
+               t.Run(testCase.name, func(t *testing.T) {
+                       if err := testCase.run(); err == nil {
+                               t.Fatal("expected path traversal to be 
rejected")
+                       }
+                       if _, statErr := 
os.Stat(filepath.Join(filepath.Dir(baseDir), "escaped.txt")); 
!errors.Is(statErr, os.ErrNotExist) {
+                               t.Fatalf("escaped file exists or stat failed 
with unexpected error: %v", statErr)
+                       }
+               })
+       }
+}
+
+func TestFSRejectsAbsolutePath(t *testing.T) {
+       baseDir := filepath.Join(t.TempDir(), "remote")
+       fs, err := NewFS(baseDir)
+       if err != nil {
+               t.Fatalf("NewFS failed: %v", err)
+       }
+
+       absolutePath := filepath.Join(baseDir, "escaped.txt")
+       if err = fs.Upload(context.Background(), absolutePath, 
strings.NewReader("escape")); err == nil {
+               t.Fatal("expected absolute path to be rejected")
+       }
+       if _, statErr := os.Stat(absolutePath); !errors.Is(statErr, 
os.ErrNotExist) {
+               t.Fatalf("absolute path was written or stat failed with 
unexpected error: %v", statErr)
+       }
+}
+
+func TestFSRejectsVolumeName(t *testing.T) {
+       fs, err := NewFS(t.TempDir())
+       if err != nil {
+               t.Fatalf("NewFS failed: %v", err)
+       }
+
+       if err = fs.Upload(context.Background(), `C:escaped.txt`, 
strings.NewReader("escape")); err == nil {
+               t.Fatal("expected volume-name path to be rejected")
+       }
+}
+
+func TestFSRejectsSymlinkEscape(t *testing.T) {
+       baseDir := filepath.Join(t.TempDir(), "remote")
+       outsideDir := filepath.Join(t.TempDir(), "outside")
+       if err := os.MkdirAll(outsideDir, dirPerm); err != nil {
+               t.Fatalf("failed to create outside directory: %v", err)
+       }
+       fs, err := NewFS(baseDir)
+       if err != nil {
+               t.Fatalf("NewFS failed: %v", err)
+       }
+       if err = os.Symlink(outsideDir, filepath.Join(baseDir, "link")); err != 
nil {
+               t.Skipf("symlinks are not available: %v", err)
+       }
+
+       escapedPath := filepath.Join("link", "escaped.txt")
+       tests := []struct {
+               run  func() error
+               name string
+       }{
+               {
+                       name: "upload",
+                       run: func() error {
+                               return fs.Upload(context.Background(), 
escapedPath, strings.NewReader("escape"))
+                       },
+               },
+               {
+                       name: "download",
+                       run: func() error {
+                               reader, downloadErr := 
fs.Download(context.Background(), escapedPath)
+                               if reader != nil {
+                                       reader.Close()
+                               }
+                               return downloadErr
+                       },
+               },
+               {
+                       name: "list",
+                       run: func() error {
+                               _, listErr := fs.List(context.Background(), 
"link")
+                               return listErr
+                       },
+               },
+               {
+                       name: "delete",
+                       run: func() error {
+                               return fs.Delete(context.Background(), 
escapedPath)
+                       },
+               },
+       }
+
+       for _, testCase := range tests {
+               t.Run(testCase.name, func(t *testing.T) {
+                       if err := testCase.run(); err == nil {
+                               t.Fatal("expected symlink escape to be 
rejected")
+                       }
+                       if _, statErr := os.Stat(filepath.Join(outsideDir, 
"escaped.txt")); !errors.Is(statErr, os.ErrNotExist) {
+                               t.Fatalf("escaped file exists or stat failed 
with unexpected error: %v", statErr)
+                       }
+               })
+       }
+}
diff --git a/pkg/path/path.go b/pkg/path/path.go
index 6cb49c1ed..b0590ed88 100644
--- a/pkg/path/path.go
+++ b/pkg/path/path.go
@@ -32,3 +32,17 @@ func Get(p string) (string, error) {
        }
        return filepath.Abs(p)
 }
+
+// HasVolumeName reports whether p starts with a platform volume name or a
+// Windows drive prefix. The explicit drive-prefix check keeps validation
+// portable on non-Windows platforms.
+func HasVolumeName(p string) bool {
+       if filepath.VolumeName(p) != "" || filepath.VolumeName(filepath.FromSlash(p)) != "" {
+               return true
+       }
+       return len(p) >= 2 && isASCIILetter(p[0]) && p[1] == ':'
+}
+
+func isASCIILetter(c byte) bool {
+       return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
+}
diff --git a/pkg/path/path_test.go b/pkg/path/path_test.go
index 472b85af0..f8f92dc6e 100644
--- a/pkg/path/path_test.go
+++ b/pkg/path/path_test.go
@@ -81,3 +81,39 @@ func TestGet(t *testing.T) {
                })
        }
 }
+
+func TestHasVolumeName(t *testing.T) {
+       tests := []struct {
+               name string
+               path string
+               want bool
+       }{
+               {
+                       name: "relative path",
+                       path: "snapshot/data/file.txt",
+               },
+               {
+                       name: "windows drive relative",
+                       path: `C:snapshot\data.txt`,
+                       want: true,
+               },
+               {
+                       name: "windows drive absolute slash",
+                       path: "C:/snapshot/data.txt",
+                       want: true,
+               },
+               {
+                       name: "windows drive absolute backslash",
+                       path: `C:\snapshot\data.txt`,
+                       want: true,
+               },
+       }
+
+       for _, test := range tests {
+               t.Run(test.name, func(t *testing.T) {
+                       if got := HasVolumeName(test.path); got != test.want {
+                               t.Fatalf("HasVolumeName(%q) = %t, want %t", 
test.path, got, test.want)
+                       }
+               })
+       }
+}

Reply via email to