hanahmily commented on code in PR #642: URL: https://github.com/apache/skywalking-banyandb/pull/642#discussion_r2041550089
########## pkg/tls/reloader_test.go: ########## @@ -0,0 +1,953 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tls + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/apache/skywalking-banyandb/pkg/logger" +) + +func generateSelfSignedCert(t *testing.T, commonName string) (certPEM, keyPEM []byte) { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: commonName, + }, + DNSNames: []string{commonName}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + require.NoError(t, err) + + certPEM = pem.EncodeToMemory(&pem.Block{ + Type: 
"CERTIFICATE", + Bytes: certDER, + }) + + keyPEM = pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + + return certPEM, keyPEM +} + +func TestReloader_CertificateRotation(t *testing.T) { + // Create temporary directory for test files + tempDir, err := os.MkdirTemp("", "tls-test") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + certFile := filepath.Join(tempDir, "cert.pem") + keyFile := filepath.Join(tempDir, "key.pem") + + // Generate initial certificate + certPEM1, keyPEM1 := generateSelfSignedCert(t, "test1.local") + err = os.WriteFile(certFile, certPEM1, 0o600) + require.NoError(t, err) + err = os.WriteFile(keyFile, keyPEM1, 0o600) + require.NoError(t, err) + + log := logger.GetLogger("tls-test") + reloader, err := NewReloader(certFile, keyFile, log) + require.NoError(t, err) + defer reloader.Stop() + + // Start reloader + reloader.Start() + + // Wait for initial certificate to be loaded + time.Sleep(100 * time.Millisecond) + + // Verify initial certificate + tlsConfig := reloader.GetTLSConfig() + cert, err := tlsConfig.GetCertificate(nil) + require.NoError(t, err) + leafCert, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Equal(t, "test1.local", leafCert.Subject.CommonName) + + // Create a new directory and move files there to trigger watcher + newDir := filepath.Join(tempDir, "new") + require.NoError(t, os.Mkdir(newDir, 0o755)) + + // Move files to new location Review Comment: Could you please explain why you're moving the files? It's not a use case we discussed before. ########## pkg/tls/reloader_test.go: ########## @@ -0,0 +1,953 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership.
Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tls + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/apache/skywalking-banyandb/pkg/logger" +) + +func generateSelfSignedCert(t *testing.T, commonName string) (certPEM, keyPEM []byte) { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: commonName, + }, + DNSNames: []string{commonName}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + require.NoError(t, err) + + certPEM = pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certDER, + }) + + keyPEM = pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + + return certPEM, keyPEM +} + +func TestReloader_CertificateRotation(t *testing.T) { + // Create temporary directory 
for test files + tempDir, err := os.MkdirTemp("", "tls-test") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + certFile := filepath.Join(tempDir, "cert.pem") + keyFile := filepath.Join(tempDir, "key.pem") + + // Generate initial certificate + certPEM1, keyPEM1 := generateSelfSignedCert(t, "test1.local") + err = os.WriteFile(certFile, certPEM1, 0o600) + require.NoError(t, err) + err = os.WriteFile(keyFile, keyPEM1, 0o600) + require.NoError(t, err) + + log := logger.GetLogger("tls-test") + reloader, err := NewReloader(certFile, keyFile, log) + require.NoError(t, err) + defer reloader.Stop() + + // Start reloader + reloader.Start() + + // Wait for initial certificate to be loaded + time.Sleep(100 * time.Millisecond) + + // Verify initial certificate + tlsConfig := reloader.GetTLSConfig() + cert, err := tlsConfig.GetCertificate(nil) + require.NoError(t, err) + leafCert, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Equal(t, "test1.local", leafCert.Subject.CommonName) + + // Create a new directory and move files there to trigger watcher + newDir := filepath.Join(tempDir, "new") + require.NoError(t, os.Mkdir(newDir, 0o755)) + + // Move files to new location + newCertFile := filepath.Join(newDir, "cert.pem") + newKeyFile := filepath.Join(newDir, "key.pem") + + // Move files to trigger watcher events + require.NoError(t, os.Rename(certFile, newCertFile)) + require.NoError(t, os.Rename(keyFile, newKeyFile)) + + // Wait for watcher to detect changes + time.Sleep(500 * time.Millisecond) + + // Verify certificate is still available (last known good state) + cert, err = tlsConfig.GetCertificate(nil) + require.NoError(t, err) + leafCert, err = x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Equal(t, "test1.local", leafCert.Subject.CommonName) +} + +func TestReloader_FileOperations(t *testing.T) { + // Each subtest needs its own reloader instance to avoid interference between tests + t.Run("Move 
files to new location", func(t *testing.T) { + tempDir := t.TempDir() + certFile := filepath.Join(tempDir, "cert.pem") + keyFile := filepath.Join(tempDir, "key.pem") + + // Create initial files + certPEM, keyPEM := generateSelfSignedCert(t, "initial.local") + require.NoError(t, os.WriteFile(certFile, certPEM, 0o600)) + require.NoError(t, os.WriteFile(keyFile, keyPEM, 0o600)) + + // Create reloader + log := logger.GetLogger("tls-test") + reloader, err := NewReloader(certFile, keyFile, log) + require.NoError(t, err) + require.NoError(t, reloader.Start()) + defer reloader.Stop() + + // Wait for initial certificate to be loaded + time.Sleep(100 * time.Millisecond) + + // Initial verification + tlsConfig := reloader.GetTLSConfig() + cert, err := tlsConfig.GetCertificate(nil) + require.NoError(t, err) + leafCert, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Equal(t, "initial.local", leafCert.Subject.CommonName) + + // Create a new directory for the files + newDir := filepath.Join(tempDir, "new") + require.NoError(t, os.Mkdir(newDir, 0o755)) + + // Move files to new location + newCertFile := filepath.Join(newDir, "cert.pem") + newKeyFile := filepath.Join(newDir, "key.pem") + + // Move files to trigger watcher events + require.NoError(t, os.Rename(certFile, newCertFile)) + require.NoError(t, os.Rename(keyFile, newKeyFile)) + + // Wait for watcher to detect changes + time.Sleep(500 * time.Millisecond) + + // Verify certificate is still available (last known good state) + cert, err = tlsConfig.GetCertificate(nil) + require.NoError(t, err) + leafCert, err = x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Equal(t, "initial.local", leafCert.Subject.CommonName) + }) + + t.Run("Remove files without recreation", func(t *testing.T) { + tempDir := t.TempDir() + certFile := filepath.Join(tempDir, "cert.pem") + keyFile := filepath.Join(tempDir, "key.pem") + + // Create initial files + certPEM, keyPEM := 
generateSelfSignedCert(t, "initial.local") + require.NoError(t, os.WriteFile(certFile, certPEM, 0o600)) + require.NoError(t, os.WriteFile(keyFile, keyPEM, 0o600)) + + // Create reloader + log := logger.GetLogger("tls-test") + reloader, err := NewReloader(certFile, keyFile, log) + require.NoError(t, err) + require.NoError(t, reloader.Start()) + defer reloader.Stop() + + // Wait for initial certificate to be loaded + time.Sleep(100 * time.Millisecond) + + // Initial verification + tlsConfig := reloader.GetTLSConfig() + cert, err := tlsConfig.GetCertificate(nil) + require.NoError(t, err) + leafCert, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Equal(t, "initial.local", leafCert.Subject.CommonName) + + // Remove files + require.NoError(t, os.Remove(certFile)) + require.NoError(t, os.Remove(keyFile)) + + // Wait for file system events to be processed + time.Sleep(500 * time.Millisecond) + + // Verify certificate is still available (last known good state) + cert, err = tlsConfig.GetCertificate(nil) + require.NoError(t, err) + leafCert, err = x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Equal(t, "initial.local", leafCert.Subject.CommonName) + }) + + t.Run("Create invalid files", func(t *testing.T) { + tempDir := t.TempDir() + certFile := filepath.Join(tempDir, "cert.pem") + keyFile := filepath.Join(tempDir, "key.pem") + + // Create initial files + certPEM, keyPEM := generateSelfSignedCert(t, "initial.local") + require.NoError(t, os.WriteFile(certFile, certPEM, 0o600)) + require.NoError(t, os.WriteFile(keyFile, keyPEM, 0o600)) + + // Create reloader + log := logger.GetLogger("tls-test") + reloader, err := NewReloader(certFile, keyFile, log) + require.NoError(t, err) + require.NoError(t, reloader.Start()) + defer reloader.Stop() + + // Wait for initial certificate to be loaded + time.Sleep(100 * time.Millisecond) + + // First ensure valid certificate is loaded through watcher + certPEM2, keyPEM2 := 
generateSelfSignedCert(t, "before.invalid.local") + require.NoError(t, os.WriteFile(certFile, certPEM2, 0o600)) + require.NoError(t, os.WriteFile(keyFile, keyPEM2, 0o600)) + time.Sleep(500 * time.Millisecond) + + // Verify certificate was updated + tlsConfig := reloader.GetTLSConfig() + cert, err := tlsConfig.GetCertificate(nil) + require.NoError(t, err) + leafCert, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + assert.Equal(t, "before.invalid.local", leafCert.Subject.CommonName) + + // Create invalid certificate and key files + require.NoError(t, os.WriteFile(certFile, []byte("invalid cert"), 0o600)) + require.NoError(t, os.WriteFile(keyFile, []byte("invalid key"), 0o600)) + time.Sleep(500 * time.Millisecond) + + // Should still return the last valid certificate + cert, err = tlsConfig.GetCertificate(nil) + require.NoError(t, err) + leafCert, err = x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + assert.Equal(t, "before.invalid.local", leafCert.Subject.CommonName) + }) + + t.Run("Recover from invalid files", func(t *testing.T) { + tempDir := t.TempDir() + certFile := filepath.Join(tempDir, "cert.pem") + keyFile := filepath.Join(tempDir, "key.pem") + + // Create initial files with a valid certificate + certPEM, keyPEM := generateSelfSignedCert(t, "before.invalid.local") + require.NoError(t, os.WriteFile(certFile, certPEM, 0o600)) + require.NoError(t, os.WriteFile(keyFile, keyPEM, 0o600)) + + // Create reloader + log := logger.GetLogger("tls-test") + reloader, err := NewReloader(certFile, keyFile, log) + require.NoError(t, err) + require.NoError(t, reloader.Start()) + defer reloader.Stop() + + // Wait for initial certificate to be loaded + time.Sleep(100 * time.Millisecond) + + // Initial verification + tlsConfig := reloader.GetTLSConfig() + cert, err := tlsConfig.GetCertificate(nil) + require.NoError(t, err) + leafCert, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + assert.Equal(t, 
"before.invalid.local", leafCert.Subject.CommonName) + + // Create invalid certificate and key files + require.NoError(t, os.WriteFile(certFile, []byte("invalid cert"), 0o600)) + require.NoError(t, os.WriteFile(keyFile, []byte("invalid key"), 0o600)) + time.Sleep(500 * time.Millisecond) + + // Create valid files after invalid ones + certPEM3, keyPEM3 := generateSelfSignedCert(t, "recovered.local") + require.NoError(t, os.WriteFile(certFile, certPEM3, 0o600)) + require.NoError(t, os.WriteFile(keyFile, keyPEM3, 0o600)) + time.Sleep(500 * time.Millisecond) + + // Certificate should be updated automatically + cert, err = tlsConfig.GetCertificate(nil) + require.NoError(t, err) + leafCert, err = x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + assert.Equal(t, "recovered.local", leafCert.Subject.CommonName) + }) +} + +func TestNewReloader_Errors(t *testing.T) { + t.Run("empty cert and key files", func(t *testing.T) { + log := logger.GetLogger("tls-test") + reloader, err := NewReloader("", "", log) + assert.Error(t, err) + assert.Nil(t, reloader) + assert.Contains(t, err.Error(), "certFile and keyFile must be provided") + }) + + t.Run("empty cert file", func(t *testing.T) { + log := logger.GetLogger("tls-test") + reloader, err := NewReloader("", "key.pem", log) + assert.Error(t, err) + assert.Nil(t, reloader) + assert.Contains(t, err.Error(), "certFile and keyFile must be provided") + }) + + t.Run("empty key file", func(t *testing.T) { + log := logger.GetLogger("tls-test") + reloader, err := NewReloader("cert.pem", "", log) + assert.Error(t, err) + assert.Nil(t, reloader) + assert.Contains(t, err.Error(), "certFile and keyFile must be provided") + }) + + t.Run("nil logger", func(t *testing.T) { + reloader, err := NewReloader("cert.pem", "key.pem", nil) + assert.Error(t, err) + assert.Nil(t, reloader) + assert.Contains(t, err.Error(), "logger must not be nil") + }) + + t.Run("invalid certificate files", func(t *testing.T) { + log := 
logger.GetLogger("tls-test") + reloader, err := NewReloader("nonexistent.pem", "nonexistent.pem", log) + assert.Error(t, err) + assert.Nil(t, reloader) + assert.Contains(t, err.Error(), "failed to load initial TLS certificate") + }) +} + +func TestReloader_InvalidCertificates(t *testing.T) { Review Comment: This test appears to duplicate the "Create invalid files" subtest in TestReloader_FileOperations. ########## banyand/liaison/grpc/server.go: ########## @@ -232,18 +242,20 @@ func (s *server) Validate() error { if s.keyFile == "" { return errServerKey } - creds, errTLS := credentials.NewServerTLSFromFile(s.certFile, s.keyFile) - if errTLS != nil { - return errors.Wrap(errTLS, "failed to load cert and key") - } - s.creds = creds return nil } func (s *server) Serve() run.StopNotify { var opts []grpclib.ServerOption if s.tls { - opts = []grpclib.ServerOption{grpclib.Creds(s.creds)} + if err := s.tlsReloader.Start(); err != nil { + s.log.Error().Err(err).Msg("Failed to start TLSReloader for gRPC") + close(s.stopCh) + return s.stopCh + } + s.log.Info().Str("certFile", s.certFile).Str("keyFile", s.keyFile).Msg("Starting TLS file monitoring") + creds := s.tlsReloader.GetGRPCTransportCredentials() Review Comment: It looks like you reverted these changes to the wrong commit, one that can't load certificates dynamically. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
