This is an automated email from the ASF dual-hosted git repository. ButterBright pushed a commit to branch v0.10.x in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git
commit e66a0808fb578dfe7a1aa88d0d3e530b0f8dd376 Author: Tanay Paul <[email protected]> AuthorDate: Thu May 21 06:18:52 2026 +0530 Fix backup restore path traversal (#1135) --- banyand/backup/restore.go | 35 +++-- banyand/backup/restore_test.go | 97 ++++++++++++ pkg/fs/remote/local/local.go | 94 +++++++++++- pkg/fs/remote/local/local_test.go | 307 ++++++++++++++++++++++++++++++++++++++ pkg/path/path.go | 14 ++ pkg/path/path_test.go | 36 +++++ 6 files changed, 567 insertions(+), 16 deletions(-) diff --git a/banyand/backup/restore.go b/banyand/backup/restore.go index 93a3de4b1..dd5db44b8 100644 --- a/banyand/backup/restore.go +++ b/banyand/backup/restore.go @@ -23,6 +23,7 @@ import ( "fmt" "io" "os" + "path" "path/filepath" "strings" @@ -193,13 +194,12 @@ func restoreByName(fs remote.FS, timeDir, rootPath, catalogName string) error { logger.Infof("Restoring %s to %s from %s, remote total %d files", catalogName, localDir, remotePrefix, len(remoteFiles)) remoteRelSet := make(map[string]bool) - var relPath string for _, remoteFile := range remoteFiles { - relPath, err = filepath.Rel(timeDir, remoteFile) - if err != nil { - return fmt.Errorf("failed to get relative path for %s: %w", remoteFile, err) + relPath, relPathErr := validatedRemoteRelPath(timeDir, catalogName, remoteFile) + if relPathErr != nil { + return relPathErr } - remoteRelSet[filepath.ToSlash(relPath)] = true + remoteRelSet[path.Join(catalogName, relPath)] = true } localFiles, err := getAllFiles(localDir) @@ -208,7 +208,7 @@ func restoreByName(fs remote.FS, timeDir, rootPath, catalogName string) error { } for _, localRelPath := range localFiles { - localRelPathWithCatalog := filepath.Join(catalogName, localRelPath) + localRelPathWithCatalog := path.Join(catalogName, filepath.ToSlash(localRelPath)) if !remoteRelSet[localRelPathWithCatalog] { localPath := filepath.Join(localDir, localRelPath) logger.Infof("found local file: %s not exist in the remote storage, so delete it", localRelPathWithCatalog) @@ -220,11 +220,10 @@ func restoreByName(fs remote.FS, timeDir, rootPath, catalogName string) error { } for _, remoteFile := range remoteFiles { - relPath, err := filepath.Rel(filepath.Join(timeDir, catalogName), remoteFile) + relPath, err := validatedRemoteRelPath(timeDir, catalogName, remoteFile) if err != nil { - return fmt.Errorf("failed to get relative path for %s: %w", remoteFile, err) + return err } - relPath = filepath.ToSlash(relPath) localPath := filepath.Join(rootPath, catalogName, storage.DataDir, relPath) if !contains(localFiles, relPath) { @@ -244,6 +243,24 @@ func restoreByName(fs remote.FS, timeDir, rootPath, catalogName string) error { return nil } +func validatedRemoteRelPath(timeDir, catalogName, remoteFile string) (string, error) { + remotePath := filepath.ToSlash(remoteFile) + if path.IsAbs(remotePath) || banyandbpath.HasVolumeName(remotePath) { + return "", fmt.Errorf("remote file %q escapes backup prefix", remoteFile) + } + prefix := path.Clean(path.Join(filepath.ToSlash(timeDir), filepath.ToSlash(catalogName))) + cleanRemotePath := path.Clean(remotePath) + if banyandbpath.HasVolumeName(cleanRemotePath) || cleanRemotePath == "." || prefix == "." { + return "", fmt.Errorf("remote file %q escapes backup prefix", remoteFile) + } + prefixWithSlash := prefix + "/" + if !strings.HasPrefix(cleanRemotePath, prefixWithSlash) { + return "", fmt.Errorf("remote file %q escapes backup prefix", remoteFile) + } + relPath := strings.TrimPrefix(cleanRemotePath, prefixWithSlash) + return relPath, nil +} + func cleanEmptyDirs(dir, stopDir string) { for { if dir == stopDir || dir == "." { diff --git a/banyand/backup/restore_test.go b/banyand/backup/restore_test.go index 086f18a23..ca76f984e 100644 --- a/banyand/backup/restore_test.go +++ b/banyand/backup/restore_test.go @@ -19,6 +19,7 @@ package backup import ( "context" + "io" "os" "path/filepath" "strings" @@ -192,3 +193,99 @@ func TestRestoreSame(t *testing.T) { t.Fatalf("expected extra file %q exist", extraFilePath) } } + +func TestRestoreRejectsRemotePathTraversal(t *testing.T) { + timeDir := testTimeDir + catalogName := snapshot.CatalogName(commonv1.Catalog_CATALOG_STREAM) + localRestoreDir := t.TempDir() + escapedFile := filepath.Join(localRestoreDir, catalogName, "escaped.txt") + fs := &restoreTraversalFS{ + files: []string{ + filepath.ToSlash(filepath.Join(timeDir, catalogName, "..", "escaped.txt")), + }, + } + + err := restoreByName(fs, timeDir, localRestoreDir, catalogName) + if err == nil { + t.Fatal("expected restoreByName to reject remote path traversal") + } + if _, statErr := os.Stat(escapedFile); !os.IsNotExist(statErr) { + t.Fatalf("escaped file exists or stat failed with unexpected error: %v", statErr) + } +} + +func TestValidatedRemoteRelPath(t *testing.T) { + timeDir := testTimeDir + catalogName := snapshot.CatalogName(commonv1.Catalog_CATALOG_STREAM) + validRemoteFile := filepath.ToSlash(filepath.Join(timeDir, catalogName, "nested", "test.txt")) + + relPath, err := validatedRemoteRelPath(timeDir, catalogName, validRemoteFile) + if err != nil { + t.Fatalf("validatedRemoteRelPath failed: %v", err) + } + if relPath != "nested/test.txt" { + t.Fatalf("relPath = %q, want %q", relPath, "nested/test.txt") + } +} + +func TestValidatedRemoteRelPathRejectsInvalidPaths(t *testing.T) { + timeDir := testTimeDir + catalogName := snapshot.CatalogName(commonv1.Catalog_CATALOG_STREAM) + tests := []struct { + remoteFile string + name string + }{ + { + name: "absolute path", + remoteFile: "/tmp/backup/file.txt", + }, + { + name: "catalog prefix only", + remoteFile: filepath.ToSlash(filepath.Join(timeDir, catalogName)), + }, + { + name: "outside catalog prefix", + remoteFile: filepath.ToSlash(filepath.Join(timeDir, "measure", "test.txt")), + }, + { + name: "parent traversal", + remoteFile: filepath.ToSlash(filepath.Join(timeDir, catalogName, "..", "escaped.txt")), + }, + { + name: "volume name", + remoteFile: `C:/backup/file.txt`, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + if _, err := validatedRemoteRelPath(timeDir, catalogName, testCase.remoteFile); err == nil { + t.Fatal("expected invalid remote path to be rejected") + } + }) + } +} + +type restoreTraversalFS struct { + files []string +} + +func (r *restoreTraversalFS) Upload(_ context.Context, _ string, _ io.Reader) error { + return nil +} + +func (r *restoreTraversalFS) Download(_ context.Context, _ string) (io.ReadCloser, error) { + return io.NopCloser(strings.NewReader("escape")), nil +} + +func (r *restoreTraversalFS) List(_ context.Context, _ string) ([]string, error) { + return r.files, nil +} + +func (r *restoreTraversalFS) Delete(_ context.Context, _ string) error { + return nil +} + +func (r *restoreTraversalFS) Close() error { + return nil +} diff --git a/pkg/fs/remote/local/local.go b/pkg/fs/remote/local/local.go index 01f204552..176e18e2f 100644 --- a/pkg/fs/remote/local/local.go +++ b/pkg/fs/remote/local/local.go @@ -20,11 +20,14 @@ package local import ( "context" + "fmt" "io" "os" "path/filepath" + "strings" "github.com/apache/skywalking-banyandb/pkg/fs/remote" + pathutil "github.com/apache/skywalking-banyandb/pkg/path" ) const dirPerm = 0o755 @@ -40,14 +43,28 @@ func NewFS(baseDir string) (remote.FS, error) { if err := os.MkdirAll(baseDir, dirPerm); err != nil { return nil, err } - return &fs{baseDir: baseDir}, nil + cleanBaseDir, err := filepath.Abs(baseDir) + if err != nil { + return nil, err + } + realBaseDir, err := filepath.EvalSymlinks(cleanBaseDir) + if err != nil { + return nil, err + } + return &fs{baseDir: realBaseDir}, nil } func (l *fs) Upload(_ context.Context, path string, data io.Reader) error { - fullPath := filepath.Join(l.baseDir, path) - if err := os.MkdirAll(filepath.Dir(fullPath), dirPerm); err != nil { + fullPath, err := l.fullPath(path, false) + if err != nil { return err } + if mkdirErr := os.MkdirAll(filepath.Dir(fullPath), dirPerm); mkdirErr != nil { + return mkdirErr + } + if err = l.ensureResolvedWithinBase(filepath.Dir(fullPath)); err != nil { + return fmt.Errorf("path %q escapes base directory: %w", path, err) + } file, err := os.Create(fullPath) if err != nil { @@ -60,15 +77,21 @@ func (l *fs) Upload(_ context.Context, path string, data io.Reader) error { } func (l *fs) Download(_ context.Context, path string) (io.ReadCloser, error) { - fullPath := filepath.Join(l.baseDir, path) + fullPath, err := l.fullPath(path, false) + if err != nil { + return nil, err + } return os.Open(fullPath) } func (l *fs) List(_ context.Context, prefix string) ([]string, error) { var files []string - fullPath := filepath.Join(l.baseDir, prefix) + fullPath, err := l.fullPath(prefix, true) + if err != nil { + return nil, err + } - err := filepath.Walk(fullPath, func(path string, info os.FileInfo, err error) error { + err = filepath.Walk(fullPath, func(path string, info os.FileInfo, err error) error { if err != nil { return err } @@ -89,10 +112,67 @@ func (l *fs) List(_ context.Context, prefix string) ([]string, error) { } func (l *fs) Delete(_ context.Context, path string) error { - fullPath := filepath.Join(l.baseDir, path) + fullPath, err := l.fullPath(path, false) + if err != nil { + return err + } return os.Remove(fullPath) } func (l *fs) Close() error { return nil } + +func (l *fs) fullPath(path string, allowRoot bool) (string, error) { + if filepath.IsAbs(path) || pathutil.HasVolumeName(path) { + return "", fmt.Errorf("path %q escapes base directory", path) + } + cleanPath := filepath.Clean(path) + if !allowRoot && cleanPath == "." { + return "", fmt.Errorf("path %q escapes base directory", path) + } + if pathutil.HasVolumeName(cleanPath) || cleanPath == ".." || strings.HasPrefix(cleanPath, ".."+string(filepath.Separator)) { + return "", fmt.Errorf("path %q escapes base directory", path) + } + fullPath := filepath.Join(l.baseDir, cleanPath) + relPath, err := filepath.Rel(l.baseDir, fullPath) + if err != nil { + return "", err + } + if relPath == ".." || strings.HasPrefix(relPath, ".."+string(filepath.Separator)) || filepath.IsAbs(relPath) { + return "", fmt.Errorf("path %q escapes base directory", path) + } + if err := l.ensureResolvedWithinBase(fullPath); err != nil { + return "", fmt.Errorf("path %q escapes base directory: %w", path, err) + } + return fullPath, nil +} + +func (l *fs) ensureResolvedWithinBase(path string) error { + existingPath := path + for { + if _, err := os.Lstat(existingPath); err == nil { + break + } else if !os.IsNotExist(err) { + return err + } + parentPath := filepath.Dir(existingPath) + if parentPath == existingPath { + return os.ErrNotExist + } + existingPath = parentPath + } + + realPath, err := filepath.EvalSymlinks(existingPath) + if err != nil { + return err + } + relPath, err := filepath.Rel(l.baseDir, realPath) + if err != nil { + return err + } + if relPath == ".." || strings.HasPrefix(relPath, ".."+string(filepath.Separator)) || filepath.IsAbs(relPath) { + return fmt.Errorf("resolved path %q is outside base directory %q", realPath, l.baseDir) + } + return nil +} diff --git a/pkg/fs/remote/local/local_test.go b/pkg/fs/remote/local/local_test.go new file mode 100644 index 000000000..8ad28f1ca --- /dev/null +++ b/pkg/fs/remote/local/local_test.go @@ -0,0 +1,307 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package local + +import ( + "context" + "errors" + "io" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestFSOperationsStayWithinBase(t *testing.T) { + baseDir := filepath.Join(t.TempDir(), "remote") + fs, err := NewFS(baseDir) + if err != nil { + t.Fatalf("NewFS failed: %v", err) + } + + const content = "hello" + filePath := filepath.Join("snapshot", "data", "test.txt") + if err = fs.Upload(context.Background(), filePath, strings.NewReader(content)); err != nil { + t.Fatalf("Upload failed: %v", err) + } + + files, err := fs.List(context.Background(), "snapshot") + if err != nil { + t.Fatalf("List failed: %v", err) + } + if len(files) != 1 || files[0] != filepath.ToSlash(filePath) { + t.Fatalf("files = %v, want [%s]", files, filepath.ToSlash(filePath)) + } + + reader, err := fs.Download(context.Background(), filePath) + if err != nil { + t.Fatalf("Download failed: %v", err) + } + got, err := io.ReadAll(reader) + closeErr := reader.Close() + if err != nil { + t.Fatalf("failed to read downloaded content: %v", err) + } + if closeErr != nil { + t.Fatalf("failed to close downloaded content: %v", closeErr) + } + if string(got) != content { + t.Fatalf("content = %q, want %q", string(got), content) + } + + if err = fs.Delete(context.Background(), filePath); err != nil { + t.Fatalf("Delete failed: %v", err) + } + if _, statErr := os.Stat(filepath.Join(baseDir, filePath)); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("deleted file exists or stat failed with unexpected error: %v", statErr) + } +} + +func TestFSListMissingPrefixReturnsEmpty(t *testing.T) { + fs, err := NewFS(t.TempDir()) + if err != nil { + t.Fatalf("NewFS failed: %v", err) + } + + files, err := fs.List(context.Background(), "missing") + if err != nil { + t.Fatalf("List failed: %v", err) + } + if len(files) != 0 { + t.Fatalf("files = %v, want empty", files) + } +} + +func TestFSListEmptyPrefixReturnsFiles(t *testing.T) { + baseDir := filepath.Join(t.TempDir(), "remote") + fs, err := NewFS(baseDir) + if err != nil { + t.Fatalf("NewFS failed: %v", err) + } + + const filePath = "snapshot/data/test.txt" + if err = fs.Upload(context.Background(), filePath, strings.NewReader("hello")); err != nil { + t.Fatalf("Upload failed: %v", err) + } + + files, err := fs.List(context.Background(), "") + if err != nil { + t.Fatalf("List failed: %v", err) + } + if len(files) != 1 || files[0] != filePath { + t.Fatalf("files = %v, want [%s]", files, filePath) + } +} + +func TestFSRejectsRootFileOperationPaths(t *testing.T) { + fs, err := NewFS(t.TempDir()) + if err != nil { + t.Fatalf("NewFS failed: %v", err) + } + + tests := []struct { + run func(string) error + name string + path string + }{ + { + name: "upload empty", + path: "", + run: func(path string) error { + return fs.Upload(context.Background(), path, strings.NewReader("root")) + }, + }, + { + name: "upload dot", + path: ".", + run: func(path string) error { + return fs.Upload(context.Background(), path, strings.NewReader("root")) + }, + }, + { + name: "download empty", + path: "", + run: func(path string) error { + reader, downloadErr := fs.Download(context.Background(), path) + if reader != nil { + reader.Close() + } + return downloadErr + }, + }, + { + name: "delete dot", + path: ".", + run: func(path string) error { + return fs.Delete(context.Background(), path) + }, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + if err := testCase.run(testCase.path); err == nil { + t.Fatal("expected root file operation path to be rejected") + } + }) + } +} + +func TestFSRejectsPathTraversal(t *testing.T) { + baseDir := filepath.Join(t.TempDir(), "remote") + fs, err := NewFS(baseDir) + if err != nil { + t.Fatalf("NewFS failed: %v", err) + } + + escapedPath := filepath.Join("..", "escaped.txt") + tests := []struct { + run func() error + name string + }{ + { + name: "upload", + run: func() error { + return fs.Upload(context.Background(), escapedPath, strings.NewReader("escape")) + }, + }, + { + name: "download", + run: func() error { + reader, downloadErr := fs.Download(context.Background(), escapedPath) + if reader != nil { + reader.Close() + } + return downloadErr + }, + }, + { + name: "list", + run: func() error { + _, listErr := fs.List(context.Background(), escapedPath) + return listErr + }, + }, + { + name: "delete", + run: func() error { + return fs.Delete(context.Background(), escapedPath) + }, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + if err := testCase.run(); err == nil { + t.Fatal("expected path traversal to be rejected") + } + if _, statErr := os.Stat(filepath.Join(filepath.Dir(baseDir), "escaped.txt")); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("escaped file exists or stat failed with unexpected error: %v", statErr) + } + }) + } +} + +func TestFSRejectsAbsolutePath(t *testing.T) { + baseDir := filepath.Join(t.TempDir(), "remote") + fs, err := NewFS(baseDir) + if err != nil { + t.Fatalf("NewFS failed: %v", err) + } + + absolutePath := filepath.Join(baseDir, "escaped.txt") + if err = fs.Upload(context.Background(), absolutePath, strings.NewReader("escape")); err == nil { + t.Fatal("expected absolute path to be rejected") + } + if _, statErr := os.Stat(absolutePath); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("absolute path was written or stat failed with unexpected error: %v", statErr) + } +} + +func TestFSRejectsVolumeName(t *testing.T) { + fs, err := NewFS(t.TempDir()) + if err != nil { + t.Fatalf("NewFS failed: %v", err) + } + + if err = fs.Upload(context.Background(), `C:escaped.txt`, strings.NewReader("escape")); err == nil { + t.Fatal("expected volume-name path to be rejected") + } +} + +func TestFSRejectsSymlinkEscape(t *testing.T) { + baseDir := filepath.Join(t.TempDir(), "remote") + outsideDir := filepath.Join(t.TempDir(), "outside") + if err := os.MkdirAll(outsideDir, dirPerm); err != nil { + t.Fatalf("failed to create outside directory: %v", err) + } + fs, err := NewFS(baseDir) + if err != nil { + t.Fatalf("NewFS failed: %v", err) + } + if err = os.Symlink(outsideDir, filepath.Join(baseDir, "link")); err != nil { + t.Skipf("symlinks are not available: %v", err) + } + + escapedPath := filepath.Join("link", "escaped.txt") + tests := []struct { + run func() error + name string + }{ + { + name: "upload", + run: func() error { + return fs.Upload(context.Background(), escapedPath, strings.NewReader("escape")) + }, + }, + { + name: "download", + run: func() error { + reader, downloadErr := fs.Download(context.Background(), escapedPath) + if reader != nil { + reader.Close() + } + return downloadErr + }, + }, + { + name: "list", + run: func() error { + _, listErr := fs.List(context.Background(), "link") + return listErr + }, + }, + { + name: "delete", + run: func() error { + return fs.Delete(context.Background(), escapedPath) + }, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + if err := testCase.run(); err == nil { + t.Fatal("expected symlink escape to be rejected") + } + if _, statErr := os.Stat(filepath.Join(outsideDir, "escaped.txt")); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("escaped file exists or stat failed with unexpected error: %v", statErr) + } + }) + } +} diff --git a/pkg/path/path.go b/pkg/path/path.go index 6cb49c1ed..b0590ed88 100644 --- a/pkg/path/path.go +++ b/pkg/path/path.go @@ -32,3 +32,17 @@ func Get(p string) (string, error) { } return filepath.Abs(p) } + +// HasVolumeName reports whether p starts with a platform volume name or a +// Windows drive prefix. The explicit drive-prefix check keeps validation +// portable on non-Windows platforms. +func HasVolumeName(p string) bool { + if filepath.VolumeName(p) != "" || filepath.VolumeName(filepath.FromSlash(p)) != "" { + return true + } + return len(p) >= 2 && isASCIILetter(p[0]) && p[1] == ':' +} + +func isASCIILetter(c byte) bool { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') +} diff --git a/pkg/path/path_test.go b/pkg/path/path_test.go index 472b85af0..f8f92dc6e 100644 --- a/pkg/path/path_test.go +++ b/pkg/path/path_test.go @@ -81,3 +81,39 @@ func TestGet(t *testing.T) { }) } } + +func TestHasVolumeName(t *testing.T) { + tests := []struct { + name string + path string + want bool + }{ + { + name: "relative path", + path: "snapshot/data/file.txt", + }, + { + name: "windows drive relative", + path: `C:snapshot\data.txt`, + want: true, + }, + { + name: "windows drive absolute slash", + path: "C:/snapshot/data.txt", + want: true, + }, + { + name: "windows drive absolute backslash", + path: `C:\snapshot\data.txt`, + want: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := HasVolumeName(test.path); got != test.want { + t.Fatalf("HasVolumeName(%q) = %t, want %t", test.path, got, test.want) + } + }) + } +}
