This is an automated email from the ASF dual-hosted git repository.

albumenj pushed a commit to branch refactor-with-go
in repository https://gitbox.apache.org/repos/asf/dubbo-admin.git


The following commit(s) were added to refs/heads/refactor-with-go by this push:
     new 468b884  Add some uts for ca
468b884 is described below

commit 468b884a7521dac6b028ad20a4f2b969c204f71b
Author: Albumen Kevin <[email protected]>
AuthorDate: Fri Feb 24 16:56:48 2023 +0800

    Add some uts for ca
---
 ca/main.go                      |   9 +-
 ca/pkg/cert/storage.go          |  44 ++++++-
 ca/pkg/cert/storage_test.go     | 225 ++++++++++++++++++++++++++++++++
 ca/pkg/cert/util.go             |  68 +++++++---
 ca/pkg/cert/util_test.go        | 161 +++++++++++++++++++++++
 ca/pkg/k8s/client.go            |  24 +++-
 ca/pkg/security/server.go       |  39 ++++--
 ca/pkg/v1alpha1/ca_impl.go      |  24 ++--
 ca/pkg/v1alpha1/ca_impl_test.go | 276 ++++++++++++++++++++++++++++++++++++++++
 9 files changed, 813 insertions(+), 57 deletions(-)

diff --git a/ca/main.go b/ca/main.go
index ddfe992..649df54 100644
--- a/ca/main.go
+++ b/ca/main.go
@@ -20,6 +20,8 @@ import (
        "github.com/apache/dubbo-admin/ca/pkg/logger"
        "github.com/apache/dubbo-admin/ca/pkg/security"
        "os"
+       "os/signal"
+       "syscall"
 )
 
 // TODO read namespace from env
@@ -38,13 +40,14 @@ func main() {
                CertValidity:     1 * 60 * 60 * 1000,       // 1 hour
        }
 
-       s := &security.Server{
-               Options: options,
-       }
+       s := security.NewServer(options)
 
        s.Init()
        s.Start()
 
        c := make(chan os.Signal)
+       signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
+       signal.Notify(s.StopChan, syscall.SIGINT, syscall.SIGTERM)
+
        <-c
 }
diff --git a/ca/pkg/cert/storage.go b/ca/pkg/cert/storage.go
index 2236c74..c85bd7b 100644
--- a/ca/pkg/cert/storage.go
+++ b/ca/pkg/cert/storage.go
@@ -19,14 +19,19 @@ import (
        "crypto/rsa"
        "crypto/tls"
        "crypto/x509"
+       "github.com/apache/dubbo-admin/ca/pkg/config"
        "github.com/apache/dubbo-admin/ca/pkg/logger"
        "math"
+       "os"
+       "reflect"
        "sync"
        "time"
 )
 
 type Storage struct {
-       Mutex        *sync.Mutex
+       Mutex    *sync.Mutex
+       StopChan chan os.Signal
+
        CaValidity   int64
        CertValidity int64
 
@@ -46,6 +51,18 @@ type Cert struct {
        tlsCert *tls.Certificate
 }
 
+func NewStorage(options *config.Options) *Storage {
+       return &Storage{
+               Mutex:    &sync.Mutex{},
+               StopChan: make(chan os.Signal, 1),
+
+               AuthorityCert: &Cert{},
+               TrustedCert:   []*Cert{},
+               CertValidity:  options.CertValidity,
+               CaValidity:    options.CaValidity,
+       }
+}
+
 func (c *Cert) IsValid() bool {
        if c.Cert == nil || c.CertPem == "" || c.PrivateKey == nil {
                return false
@@ -53,9 +70,16 @@ func (c *Cert) IsValid() bool {
        if time.Now().Before(c.Cert.NotBefore) || 
time.Now().After(c.Cert.NotAfter) {
                return false
        }
-       if c.Cert.PublicKey == c.PrivateKey.Public() {
-               return false
+
+       if c.tlsCert == nil || !reflect.DeepEqual(c.tlsCert.PrivateKey, 
c.PrivateKey) {
+               tlsCert, err := tls.X509KeyPair([]byte(c.CertPem), 
[]byte(EncodePri(c.PrivateKey)))
+               if err != nil {
+                       return false
+               }
+
+               c.tlsCert = &tlsCert
        }
+
        return true
 }
 
@@ -70,19 +94,19 @@ func (c *Cert) NeedRefresh() bool {
        if time.Now().Add(time.Duration(math.Floor(float64(validity)*0.2)) * 
time.Millisecond).After(c.Cert.NotAfter) {
                return true
        }
-       if c.Cert.PublicKey == c.PrivateKey.Public() {
+       if !reflect.DeepEqual(c.Cert.PublicKey, c.PrivateKey.Public()) {
                return true
        }
        return false
 }
 
 func (c *Cert) GetTlsCert() *tls.Certificate {
-       if c.tlsCert != nil {
+       if c.tlsCert != nil && reflect.DeepEqual(c.tlsCert.PrivateKey, 
c.PrivateKey) {
                return c.tlsCert
        }
        tlsCert, err := tls.X509KeyPair([]byte(c.CertPem), 
[]byte(EncodePri(c.PrivateKey)))
        if err != nil {
-               logger.Sugar.Infof("Failed to load x509 cert. %v", err)
+               logger.Sugar.Warnf("Failed to load x509 cert. %v", err)
        }
        c.tlsCert = &tlsCert
        return c.tlsCert
@@ -104,7 +128,6 @@ func (s *Storage) GetServerCert(serverName string) 
*tls.Certificate {
        if !nameSigned {
                s.ServerNames = append(s.ServerNames, serverName)
        }
-       s.ServerNames = append(s.ServerNames, serverName)
        s.ServerCerts = SignServerCert(s.AuthorityCert, s.ServerNames, 
s.CertValidity)
        return s.ServerCerts.GetTlsCert()
 }
@@ -119,5 +142,12 @@ func (s *Storage) RefreshServerCert() {
                        s.ServerCerts = SignServerCert(s.AuthorityCert, 
s.ServerNames, s.CertValidity)
                }
                s.Mutex.Unlock()
+
+               select {
+               case <-s.StopChan:
+                       return
+               default:
+                       continue
+               }
        }
 }
diff --git a/ca/pkg/cert/storage_test.go b/ca/pkg/cert/storage_test.go
new file mode 100644
index 0000000..a34fc37
--- /dev/null
+++ b/ca/pkg/cert/storage_test.go
@@ -0,0 +1,225 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cert
+
+import (
+       "crypto/rand"
+       "crypto/rsa"
+       "crypto/x509"
+       "github.com/apache/dubbo-admin/ca/pkg/config"
+       "github.com/apache/dubbo-admin/ca/pkg/logger"
+       "os"
+       "reflect"
+       "sync"
+       "testing"
+       "time"
+)
+
+func TestIsValid(t *testing.T) {
+       c := &Cert{}
+       if c.IsValid() {
+               t.Errorf("cert is not valid")
+       }
+
+       c.Cert = &x509.Certificate{}
+       if c.IsValid() {
+               t.Errorf("cert is not valid")
+       }
+
+       c.CertPem = "test"
+       if c.IsValid() {
+               t.Errorf("cert is not valid")
+       }
+
+       c.PrivateKey, _ = rsa.GenerateKey(rand.Reader, 4096)
+       if c.IsValid() {
+               t.Errorf("cert is not valid")
+       }
+
+       c.Cert.NotBefore = time.Now().Add(-1 * time.Hour)
+       c.Cert.NotAfter = time.Now().Add(1 * time.Hour)
+       if c.IsValid() {
+               t.Errorf("cert is not valid")
+       }
+
+       c = GenerateAuthorityCert(nil, 2*60*60*1000)
+       if !c.IsValid() {
+               t.Errorf("cert is valid")
+       }
+}
+
+func TestNeedRefresh(t *testing.T) {
+       c := &Cert{}
+       if !c.NeedRefresh() {
+               t.Errorf("cert is need refresh")
+       }
+
+       c.Cert = &x509.Certificate{}
+       if !c.NeedRefresh() {
+               t.Errorf("cert is need refresh")
+       }
+
+       c.CertPem = "test"
+       if !c.NeedRefresh() {
+               t.Errorf("cert is need refresh")
+       }
+
+       c.PrivateKey, _ = rsa.GenerateKey(rand.Reader, 4096)
+       if !c.NeedRefresh() {
+               t.Errorf("cert is need refresh")
+       }
+
+       c.Cert.NotBefore = time.Now().Add(1 * time.Hour)
+       if !c.NeedRefresh() {
+               t.Errorf("cert is not need refresh")
+       }
+
+       c.Cert.NotBefore = time.Now().Add(-1 * time.Hour)
+       c.Cert.NotAfter = time.Now().Add(-1 * time.Hour)
+       if !c.NeedRefresh() {
+               t.Errorf("cert is not need refresh")
+       }
+
+       c.Cert.NotBefore = time.Now().Add(-1 * time.Hour).Add(2 * 60 * -0.3 * 
time.Minute)
+       c.Cert.NotAfter = time.Now().Add(-1 * time.Hour).Add(2 * 60 * 0.7 * 
time.Minute)
+       if !c.NeedRefresh() {
+               t.Errorf("cert is need refresh")
+       }
+
+       c.Cert.NotAfter = time.Now().Add(1 * time.Hour)
+       if !c.NeedRefresh() {
+               t.Errorf("cert is need refresh")
+       }
+
+       c = GenerateAuthorityCert(nil, 2*60*60*1000)
+       if c.NeedRefresh() {
+               t.Errorf("cert is valid")
+       }
+}
+
+func TestGetTlsCert(t *testing.T) {
+       cert := GenerateAuthorityCert(nil, 2*60*60*1000)
+
+       tlsCert := cert.GetTlsCert()
+       if !reflect.DeepEqual(tlsCert.PrivateKey, cert.PrivateKey) {
+               t.Errorf("cert is not equal")
+       }
+
+       if tlsCert != cert.GetTlsCert() {
+               t.Errorf("cert is not equal")
+       }
+}
+
+func TestGetServerCert(t *testing.T) {
+       cert := GenerateAuthorityCert(nil, 24*60*60*1000)
+
+       s := &Storage{
+               AuthorityCert: cert,
+               Mutex:         &sync.Mutex{},
+               CaValidity:    24 * 60 * 60 * 1000,
+               CertValidity:  2 * 60 * 60 * 1000,
+       }
+
+       c := s.GetServerCert("localhost")
+
+       pool := x509.NewCertPool()
+       pool.AddCert(cert.Cert)
+       certificate, err := x509.ParseCertificate(c.Certificate[0])
+       if err != nil {
+               t.Errorf(err.Error())
+               return
+       }
+
+       _, err = certificate.Verify(x509.VerifyOptions{
+               Roots:   pool,
+               DNSName: "localhost",
+       })
+
+       if err != nil {
+               t.Errorf(err.Error())
+               return
+       }
+
+       if c != s.GetServerCert("localhost") {
+               t.Errorf("cert is not equal")
+       }
+
+       if c != s.GetServerCert("") {
+               t.Errorf("cert is not equal")
+       }
+
+       c = s.GetServerCert("newhost")
+
+       pool = x509.NewCertPool()
+       pool.AddCert(cert.Cert)
+       certificate, err = x509.ParseCertificate(c.Certificate[0])
+       if err != nil {
+               t.Errorf(err.Error())
+               return
+       }
+
+       _, err = certificate.Verify(x509.VerifyOptions{
+               Roots:   pool,
+               DNSName: "localhost",
+       })
+
+       if err != nil {
+               t.Errorf(err.Error())
+               return
+       }
+
+       _, err = certificate.Verify(x509.VerifyOptions{
+               Roots:   pool,
+               DNSName: "newhost",
+       })
+
+       if err != nil {
+               t.Errorf(err.Error())
+               return
+       }
+}
+
+func TestRefreshServerCert(t *testing.T) {
+       logger.Init()
+       s := NewStorage(&config.Options{
+               CaValidity:   24 * 60 * 60 * 1000,
+               CertValidity: 10,
+       })
+       s.AuthorityCert = GenerateAuthorityCert(nil, 24*60*60*1000)
+
+       go s.RefreshServerCert()
+
+       c := s.GetServerCert("localhost")
+       origin := s.ServerCerts
+
+       for i := 0; i < 100; i++ {
+               // at most 10s
+               time.Sleep(100 * time.Millisecond)
+               if origin != s.ServerCerts {
+                       break
+               }
+       }
+
+       if c == s.GetServerCert("localhost") {
+               t.Errorf("cert is not equal")
+       }
+
+       if reflect.DeepEqual(c, s.GetServerCert("localhost")) {
+               t.Errorf("cert is not equal")
+       }
+
+       s.StopChan <- os.Kill
+}
diff --git a/ca/pkg/cert/util.go b/ca/pkg/cert/util.go
index 9c038d8..90f5fc4 100644
--- a/ca/pkg/cert/util.go
+++ b/ca/pkg/cert/util.go
@@ -30,16 +30,11 @@ import (
 
 func DecodeCert(cert string) *x509.Certificate {
        block, _ := pem.Decode([]byte(cert))
-       p, err := x509.ParseCertificate(block.Bytes)
-       if err != nil {
-               logger.Sugar.Warnf("Failed to parse public key. " + err.Error())
+       if block == nil {
+               logger.Sugar.Warnf("Failed to parse public key.")
                return nil
        }
-       return p
-}
-
-func DecodePub(cert string) *rsa.PublicKey {
-       p, err := x509.ParsePKCS1PublicKey([]byte(cert))
+       p, err := x509.ParseCertificate(block.Bytes)
        if err != nil {
                logger.Sugar.Warnf("Failed to parse public key. " + err.Error())
                return nil
@@ -47,9 +42,12 @@ func DecodePub(cert string) *rsa.PublicKey {
        return p
 }
 
-func DecodePri(cert string) *rsa.PrivateKey {
+func DecodePrivateKey(cert string) *rsa.PrivateKey {
        block, _ := pem.Decode([]byte(cert))
-
+       if block == nil {
+               logger.Sugar.Warnf("Failed to parse private key.")
+               return nil
+       }
        p, err := x509.ParsePKCS1PrivateKey(block.Bytes)
        if err != nil {
                logger.Sugar.Warnf("Failed to parse private key. " + 
err.Error())
@@ -58,11 +56,11 @@ func DecodePri(cert string) *rsa.PrivateKey {
        return p
 }
 
-func CreateCA(rootCert *Cert, caValidity int64) *Cert {
+func GenerateAuthorityCert(rootCert *Cert, caValidity int64) *Cert {
        cert := &x509.Certificate{
                SerialNumber: big.NewInt(2019),
                Subject: pkix.Name{
-                       CommonName:   "Dubbo",
+                       CommonName:   "Dubbo RA",
                        Organization: []string{"Apache Dubbo"},
                },
                Issuer: pkix.Name{
@@ -77,12 +75,12 @@ func CreateCA(rootCert *Cert, caValidity int64) *Cert {
                BasicConstraintsValid: true,
        }
 
-       caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
+       privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
        if err != nil {
                log.Fatal(err)
        }
 
-       caBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, 
&caPrivKey.PublicKey, caPrivKey)
+       caBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, 
&privateKey.PublicKey, privateKey)
        if err != nil {
                log.Fatal(err)
        }
@@ -98,14 +96,14 @@ func CreateCA(rootCert *Cert, caValidity int64) *Cert {
        }
 
        return &Cert{
-               Cert:       cert,
+               Cert:       DecodeCert(caPEM.String()),
                CertPem:    caPEM.String(),
-               PrivateKey: caPrivKey,
+               PrivateKey: privateKey,
        }
 }
 
 func SignServerCert(authorityCert *Cert, serverName []string, certValidity 
int64) *Cert {
-       privKey, err := rsa.GenerateKey(rand.Reader, 4096)
+       privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
        if err != nil {
                log.Fatal(err)
        }
@@ -124,7 +122,7 @@ func SignServerCert(authorityCert *Cert, serverName 
[]string, certValidity int64
        }
        cert.DNSNames = serverName
 
-       c, err := x509.CreateCertificate(rand.Reader, cert, authorityCert.Cert, 
&privKey.PublicKey, authorityCert.PrivateKey)
+       c, err := x509.CreateCertificate(rand.Reader, cert, authorityCert.Cert, 
&privateKey.PublicKey, authorityCert.PrivateKey)
 
        certPem := new(bytes.Buffer)
        err = pem.Encode(certPem, &pem.Block{
@@ -138,8 +136,40 @@ func SignServerCert(authorityCert *Cert, serverName 
[]string, certValidity int64
        return &Cert{
                Cert:       cert,
                CertPem:    certPem.String(),
-               PrivateKey: privKey,
+               PrivateKey: privateKey,
+       }
+}
+
+func GenerateCSR() (string, *rsa.PrivateKey, error) {
+       csrTemplate := x509.CertificateRequest{
+               Subject: pkix.Name{
+                       CommonName:   "Dubbo",
+                       Organization: []string{"Apache Dubbo"},
+               },
+       }
+
+       privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
+       if err != nil {
+               log.Fatal(err)
+               return "", nil, err
+       }
+
+       csrBytes, err := x509.CreateCertificateRequest(rand.Reader, 
&csrTemplate, privateKey)
+       if err != nil {
+               return "", nil, err
+       }
+
+       csr := new(bytes.Buffer)
+       err = pem.Encode(csr, &pem.Block{
+               Type:  "CERTIFICATE REQUEST",
+               Bytes: csrBytes,
+       })
+
+       if err != nil {
+               logger.Sugar.Warnf("Failed to encode certificate. " + 
err.Error())
+               return "", nil, err
        }
+       return csr.String(), privateKey, nil
 }
 
 func LoadCSR(csrString string) (*x509.CertificateRequest, error) {
diff --git a/ca/pkg/cert/util_test.go b/ca/pkg/cert/util_test.go
new file mode 100644
index 0000000..e07069d
--- /dev/null
+++ b/ca/pkg/cert/util_test.go
@@ -0,0 +1,161 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cert
+
+import "testing"
+
+func TestCSR(t *testing.T) {
+       csr, privateKey, err := GenerateCSR()
+       if err != nil {
+               t.Fatal(err)
+               return
+       }
+
+       request, err := LoadCSR(csr)
+       if err != nil {
+               t.Fatal(err)
+               return
+       }
+
+       cert := GenerateAuthorityCert(nil, 365*24*60*60*1000)
+
+       target, err := SignFromCSR(request, cert, 365*24*60*60*1000)
+       if err != nil {
+               t.Fatal(err)
+               return
+       }
+
+       certificate := DecodeCert(target)
+
+       check := &Cert{
+               Cert:       certificate,
+               PrivateKey: privateKey,
+               CertPem:    target,
+       }
+
+       if !check.IsValid() {
+               t.Fatal("Cert is not valid")
+               return
+       }
+}
+
+func TestDecodeCert(t *testing.T) {
+       if DecodeCert("") != nil {
+               t.Fatal("DecodeCert should return nil")
+               return
+       }
+
+       if DecodeCert("123") != nil {
+               t.Fatal("DecodeCert should return nil")
+               return
+       }
+
+       if DecodeCert("-----BEGIN CERTIFICATE-----\n"+
+               "123\n"+
+               "-----END CERTIFICATE-----") != nil {
+               t.Fatal("DecodeCert should return nil")
+               return
+       }
+
+       if DecodeCert("-----BEGIN CERTIFICATE-----\n"+
+               
"MIICSjCCAbOgAwIBAgIJAJHGGR4dGioHMA0GCSqGSIb3DQEBCwUAMFYxCzAJBgNV\n"+
+               
"BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX\n"+
+               
"aWRnaXRzIFB0eSBMdGQxDzANBgNVBAMTBnRlc3RjYTAeFw0xNDExMTEyMjMxMjla\n"+
+               
"Fw0yNDExMDgyMjMxMjlaMFYxCzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0\n"+
+               
"YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxDzANBgNVBAMT\n"+
+               
"BnRlc3RjYTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAwEDfBV5MYdlHVHJ7\n"+
+               
"+L4nxrZy7mBfAVXpOc5vMYztssUI7mL2/iYujiIXM+weZYNTEpLdjyJdu7R5gGUu\n"+
+               
"g1jSVK/EPHfc74O7AyZU34PNIP4Sh33N+/A5YexrNgJlPY+E3GdVYi4ldWJjgkAd\n"+
+               
"Qah2PH5ACLrIIC6tRka9hcaBlIECAwEAAaMgMB4wDAYDVR0TBAUwAwEB/zAOBgNV\n"+
+               
"HQ8BAf8EBAMCAgQwDQYJKoZIhvcNAQELBQADgYEAHzC7jdYlzAVmddi/gdAeKPau\n"+
+               
"sPBG/C2HCWqHzpCUHcKuvMzDVkY/MP2o6JIW2DBbY64bO/FceExhjcykgaYtCH/m\n"+
+               
"oIU63+CFOTtR7otyQAWHqXa7q4SbCDlG7DyRFxqG0txPtGvy12lgldA2+RgcigQG\n"+
+               "Dfcog5wrJytaQ6UA0wE=\n"+
+               "-----END CERTIFICATE-----\n") == nil {
+               t.Fatal("DecodeCert should not return nil")
+               return
+       }
+}
+
+func TestDecodePrivateKey(t *testing.T) {
+       if DecodePrivateKey("") != nil {
+               t.Fatal("DecodePrivateKey should return nil")
+               return
+       }
+
+       if DecodePrivateKey("123") != nil {
+               t.Fatal("DecodePrivateKey should return nil")
+               return
+       }
+
+       if DecodePrivateKey("-----BEGIN PRIVATE KEY-----\n"+
+               "123\n"+
+               "-----END PRIVATE KEY-----\n") != nil {
+               t.Fatal("DecodePrivateKey should return nil")
+               return
+       }
+
+       if DecodePrivateKey("-----BEGIN PRIVATE KEY-----\n"+
+               
"MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAMBA3wVeTGHZR1Ry\n"+
+               
"e/i+J8a2cu5gXwFV6TnObzGM7bLFCO5i9v4mLo4iFzPsHmWDUxKS3Y8iXbu0eYBl\n"+
+               
"LoNY0lSvxDx33O+DuwMmVN+DzSD+Eod9zfvwOWHsazYCZT2PhNxnVWIuJXViY4JA\n"+
+               
"HUGodjx+QAi6yCAurUZGvYXGgZSBAgMBAAECgYAxRi8i9BlFlufGSBVoGmydbJOm\n"+
+               
"bwLKl9dP3o33ODSP9hok5y6A0w5plWk3AJSF1hPLleK9VcSKYGYnt0clmPVHF35g\n"+
+               
"bx2rVK8dOT0mn7rz9Zr70jcSz1ETA2QonHZ+Y+niLmcic9At6hRtWiewblUmyFQm\n"+
+               
"GwggIzi7LOyEUHrEcQJBAOXxyQvnLvtKzXiqcsW/K6rExqVJVk+KF0fzzVyMzTJx\n"+
+               
"HRBxUVgvGdEJT7j+7P2kcTyafve0BBzDSPIaDyiJ+Y0CQQDWCb7jASFSbu5M3Zcd\n"+
+               
"Gkr4ZKN1XO3VLQX10b22bQYdF45hrTN2tnzRvVUR4q86VVnXmiGiTqmLkXcA2WWf\n"+
+               
"pHfFAkAhv9olUBo6MeF0i3frBEMRfm41hk0PwZHnMqZ6pgPcGnQMnMU2rzsXzkkQ\n"+
+               
"OwJnvAIOxhJKovZTjmofdqmw5odlAkBYVUdRWjsNUTjJwj3GRf6gyq/nFMYWz3EB\n"+
+               
"RWFdM1ttkDYzu45ctO2IhfHg4sPceDMO1s6AtKQmNI9/azkUjITdAkApNa9yFRzc\n"+
+               
"TBaDNPd5KVd58LVIzoPQ6i7uMHteLXJUWqSroji6S3s4gKMFJ/dO+ZXIlgQgfJJJ\n"+
+               "ZDL4cdrdkeoM\n"+
+               "-----END PRIVATE KEY-----\n") != nil {
+               t.Fatal("DecodePrivateKey should return nil")
+               return
+       }
+
+       if DecodePrivateKey("-----BEGIN RSA PRIVATE KEY-----\n"+
+               
"MIIEpgIBAAKCAQEAwQl8A5KYyOmXsz+Mk05NLWS9jHDhvJC1ekWgqOApwrb0Ecio\n"+
+               
"tv5dirqAtuEX+dGRVftxJdtZHWto+gKy3H6Ae866FBFt7TWgTZFkt0XW3tMmUmNG\n"+
+               
"bdzHAuZGK9+RlNNTNBTZJAx338kxM7/lqqOgEZig5SmX2Xt3u+DQjJPlsWB/lKDD\n"+
+               
"OKOc93lGo/8chdmMv70inE/xv6LQ9nugRvBe1XfXafuHEUVyj2rzF1v9y7yF5Tek\n"+
+               
"70wK/KV+O7ukBRc4SPwJ7YAWuofMhFneNtWGNHYaLShJBhvC+E7JXD+prJfHNdSc\n"+
+               
"ORnTz/LjMWsLbD1lhr/p7vrWXujDSGM6ZDR6EwIDAQABAoIBAQCjqjPwH4HUjmDl\n"+
+               
"RBMe7bt3qjsfcLGjm5mSQqh1piEiCtYioduR01ZiAcCRzYTzdWBg4x/Ktg/3ZpMJ\n"+
+               
"rfISCltLHTodO63U+auhOI2I6fjE0YdjQPJ8wTwmVDDYj+Qxp36a4LY93yhfn4hM\n"+
+               
"1P2XUMWtRZfc1AgAB7O7ol+PYPHVEX4n9ugbRDkn7/hpi05JPAOnGNimKDi61PpS\n"+
+               
"rWpkAKYCC6q2hLTOW+EKvfNqUjuK/YAzPQD14zP7KRQ9kkezAluwwVbwwaI2jJ4x\n"+
+               
"n6jHwPMOH1eKTQMtUg6Xxv59jBrcPmtD38dZvzzjBZDZYu4xcWJeeY4oP8/UE7uE\n"+
+               
"pTFACvBRAoGBAMlErLppeVZcgcW9gOP98qV2CbWK/+aB3oJQEQBlKB3UtQSjCnZ7\n"+
+               
"lLzxgMtDD+tcPocY5FY52MzJQ2UFgScSzW04JuBQPbsHcGmuzv/cahuB/S+xwB6m\n"+
+               
"I2RXbFkgPPirJ9mqTeuNMwcXgAhoVbPV3otMq45EsxHubATit7QvczabAoGBAPWH\n"+
+               
"yt0uxcf/j2k7EH3Ug5TkVKI8IftCM0fRs9eUzy2zPKTVRdTbQY75J/E2zkEQat2B\n"+
+               
"8hEONkkV/ScLV5Oon4oeBxCRq17h37H5znkW2yNYSMNLcqUN58ZcVxsRSPj/Eoq5\n"+
+               
"Ngotll+JmITrxtd6NpFcGqrDQ/KV9uM1AoqN4EXpAoGBALAXeLRD8dhAaX4TdgCD\n"+
+               
"v9dKNeZzLb+EYqRK3wUke/vVjWb4KwBM0W6aMWAlVXlLpJ1YhvZ1+Bv7/w4UydHg\n"+
+               
"3oCvfzwEmG3ZbV3ZhtxPATr9+QHQl9F49EAnSPGVhiLexKfpG/F6AWo0Al3Ywxrr\n"+
+               
"hKEFvJdlvfJzUmjX33gzh67/AoGBAMAnqBJ2On+NeFUozn1LxjbOg5X8bbPQWYXJ\n"+
+               
"jnAXnBTuA3YVG3O8rJASWroi5ERzbs8wlZvXfZCxTtAxxjZfb4yOd4T2HCJDr+f/\n"+
+               
"0yFdS99bhoahE3YtbckGF32th2inZ4F99db9WoQmkWDljVax5ObaKFygORsvVmr2\n"+
+               
"36hD5NORAoGBALKQZ6j9RYCC3NiV46K6GhN7RMu70u+/ET2skEdChn3Mu0OTYzqa\n"+
+               
"+qOCXvV+RWEiUoa2JX7UkSagEs+1O404afQv2+qnhdUOskxzUD+smQJBGOrXmdMq\n"+
+               "ubzSn24LsPYWYGWsgl3AJ+n8rmVMXgPaWZQD9qHkZD9Oe2wwI9W+4K74\n"+
+               "-----END RSA PRIVATE KEY-----\n") == nil {
+               t.Fatal("DecodePrivateKey should not return nil")
+               return
+       }
+}
diff --git a/ca/pkg/k8s/client.go b/ca/pkg/k8s/client.go
index 737610a..fa3636a 100644
--- a/ca/pkg/k8s/client.go
+++ b/ca/pkg/k8s/client.go
@@ -29,11 +29,23 @@ import (
        "path/filepath"
 )
 
-type Client struct {
+type Client interface {
+       Init() bool
+       GetAuthorityCert(namespace string) (string, string)
+       UpdateAuthorityCert(cert string, pri string, namespace string)
+       UpdateAuthorityPublicKey(cert string) bool
+       VerifyServiceAccount(token string) bool
+}
+
+type ClientImpl struct {
        kubeClient *kubernetes.Clientset
 }
 
-func (c *Client) Init() bool {
+func NewClient() Client {
+       return &ClientImpl{}
+}
+
+func (c *ClientImpl) Init() bool {
        config, err := rest.InClusterConfig()
        if err != nil {
                logger.Sugar.Infof("Failed to load config from Pod. Will fall 
back to kube config file.")
@@ -64,7 +76,7 @@ func (c *Client) Init() bool {
        return true
 }
 
-func (c *Client) GetAuthorityCert(namespace string) (string, string) {
+func (c *ClientImpl) GetAuthorityCert(namespace string) (string, string) {
        s, err := c.kubeClient.CoreV1().Secrets(namespace).Get(context.TODO(), 
"dubbo-ca-secret", metav1.GetOptions{})
        if err != nil {
                logger.Sugar.Warnf("Unable to get authority cert secret from 
kubernetes. " + err.Error())
@@ -72,7 +84,7 @@ func (c *Client) GetAuthorityCert(namespace string) (string, 
string) {
        return string(s.Data["cert.pem"]), string(s.Data["pri.pem"])
 }
 
-func (c *Client) UpdateAuthorityCert(cert string, pri string, namespace 
string) {
+func (c *ClientImpl) UpdateAuthorityCert(cert string, pri string, namespace 
string) {
        s, err := c.kubeClient.CoreV1().Secrets(namespace).Get(context.TODO(), 
"dubbo-ca-secret", metav1.GetOptions{})
        if err != nil {
                logger.Sugar.Warnf("Unable to get ca secret from kubernetes. 
Will try to create. " + err.Error())
@@ -106,7 +118,7 @@ func (c *Client) UpdateAuthorityCert(cert string, pri 
string, namespace string)
        }
 }
 
-func (c *Client) UpdateAuthorityPublicKey(cert string) bool {
+func (c *ClientImpl) UpdateAuthorityPublicKey(cert string) bool {
        ns, err := c.kubeClient.CoreV1().Namespaces().List(context.TODO(), 
metav1.ListOptions{})
        if err != nil {
                logger.Sugar.Warnf("Failed to get namespaces. " + err.Error())
@@ -149,7 +161,7 @@ func (c *Client) UpdateAuthorityPublicKey(cert string) bool 
{
        return true
 }
 
-func (c *Client) VerifyServiceAccount(token string) bool {
+func (c *ClientImpl) VerifyServiceAccount(token string) bool {
        tokenReview := &k8sauth.TokenReview{
                Spec: k8sauth.TokenReviewSpec{
                        Token: token,
diff --git a/ca/pkg/security/server.go b/ca/pkg/security/server.go
index 744be74..1b71ee4 100644
--- a/ca/pkg/security/server.go
+++ b/ca/pkg/security/server.go
@@ -28,36 +28,42 @@ import (
        "log"
        "math"
        "net"
+       "os"
        "strconv"
-       "sync"
        "time"
 )
 
 type Server struct {
+       StopChan chan os.Signal
+
        Options     *config.Options
        CertStorage *cert.Storage
 
-       KubeClient *k8s.Client
+       KubeClient k8s.Client
 
        CertificateServer *v1alpha1.DubboCertificateServiceServerImpl
        PlainServer       *grpc.Server
        SecureServer      *grpc.Server
 }
 
+func NewServer(options *config.Options) *Server {
+       return &Server{
+               Options:  options,
+               StopChan: make(chan os.Signal, 1),
+       }
+}
+
 func (s *Server) Init() {
        // TODO bypass k8s work
-       s.KubeClient = &k8s.Client{}
+       if s.KubeClient == nil {
+               s.KubeClient = k8s.NewClient()
+       }
        if !s.KubeClient.Init() {
                panic("Failed to create kubernetes client.")
        }
 
-       s.CertStorage = &cert.Storage{
-               AuthorityCert: &cert.Cert{},
-               TrustedCert:   []*cert.Cert{},
-               Mutex:         &sync.Mutex{},
-               CertValidity:  s.Options.CertValidity,
-               CaValidity:    s.Options.CaValidity,
-       }
+       s.CertStorage = cert.NewStorage(s.Options)
+       s.StopChan = s.StopChan
        go s.CertStorage.RefreshServerCert()
 
        // TODO inject pod based on Webhook
@@ -99,7 +105,7 @@ func (s *Server) LoadAuthorityCert() {
        if certStr != "" && priStr != "" {
                s.CertStorage.AuthorityCert.Cert = cert.DecodeCert(certStr)
                s.CertStorage.AuthorityCert.CertPem = certStr
-               s.CertStorage.AuthorityCert.PrivateKey = cert.DecodePri(priStr)
+               s.CertStorage.AuthorityCert.PrivateKey = 
cert.DecodePrivateKey(priStr)
        }
 
        s.RefreshAuthorityCert()
@@ -114,7 +120,7 @@ func (s *Server) ScheduleRefreshAuthorityCert() {
                        logger.Sugar.Infof("Authority cert is invalid, refresh 
it.")
                        // TODO lock if multi server
                        // TODO refresh signed cert
-                       s.CertStorage.AuthorityCert = 
cert.CreateCA(s.CertStorage.RootCert, s.Options.CaValidity)
+                       s.CertStorage.AuthorityCert = 
cert.GenerateAuthorityCert(s.CertStorage.RootCert, s.Options.CaValidity)
                        
s.KubeClient.UpdateAuthorityCert(s.CertStorage.AuthorityCert.CertPem, 
cert.EncodePri(s.CertStorage.AuthorityCert.PrivateKey), s.Options.Namespace)
                        if 
s.KubeClient.UpdateAuthorityPublicKey(s.CertStorage.AuthorityCert.CertPem) {
                                logger.Sugar.Infof("Write ca to config maps 
success.")
@@ -122,6 +128,13 @@ func (s *Server) ScheduleRefreshAuthorityCert() {
                                logger.Sugar.Warnf("Write ca to config maps 
failed.")
                        }
                }
+
+               select {
+               case <-s.StopChan:
+                       return
+               default:
+                       continue
+               }
        }
 }
 
@@ -130,7 +143,7 @@ func (s *Server) RefreshAuthorityCert() {
                logger.Sugar.Infof("Load authority cert from kubernetes secrect 
success.")
        } else {
                logger.Sugar.Warnf("Load authority cert from kubernetes secrect 
failed.")
-               s.CertStorage.AuthorityCert = 
cert.CreateCA(s.CertStorage.RootCert, s.Options.CaValidity)
+               s.CertStorage.AuthorityCert = 
cert.GenerateAuthorityCert(s.CertStorage.RootCert, s.Options.CaValidity)
 
                // TODO lock if multi server
                
s.KubeClient.UpdateAuthorityCert(s.CertStorage.AuthorityCert.CertPem, 
cert.EncodePri(s.CertStorage.AuthorityCert.PrivateKey), s.Options.Namespace)
diff --git a/ca/pkg/v1alpha1/ca_impl.go b/ca/pkg/v1alpha1/ca_impl.go
index ac23c89..1e8e218 100644
--- a/ca/pkg/v1alpha1/ca_impl.go
+++ b/ca/pkg/v1alpha1/ca_impl.go
@@ -31,11 +31,24 @@ type DubboCertificateServiceServerImpl struct {
        UnimplementedDubboCertificateServiceServer
        Options     *config.Options
        CertStorage *cert.Storage
-       KubeClient  *k8s.Client
+       KubeClient  k8s.Client
 }
 
 func (s *DubboCertificateServiceServerImpl) CreateCertificate(c 
context.Context, req *DubboCertificateRequest) (*DubboCertificateResponse, 
error) {
-       csr, _ := cert.LoadCSR(req.Csr)
+       if req.Csr == "" {
+               return &DubboCertificateResponse{
+                       Success: false,
+                       Message: "CSR is empty.",
+               }, nil
+       }
+
+       csr, err := cert.LoadCSR(req.Csr)
+       if csr == nil || err != nil {
+               return &DubboCertificateResponse{
+                       Success: false,
+                       Message: "Decode csr failed.",
+               }, nil
+       }
        p, _ := peer.FromContext(c)
 
        if s.Options.EnableKubernetes {
@@ -78,13 +91,6 @@ func (s *DubboCertificateServiceServerImpl) 
CreateCertificate(c context.Context,
        }
 
        // TODO check server token
-       if csr == nil {
-               logger.Sugar.Warnf("Failed to decode csr. RemoteAddr: %s", 
p.Addr.String())
-               return &DubboCertificateResponse{
-                       Success: false,
-                       Message: "Failed to read csr",
-               }, nil
-       }
        certPem, err := cert.SignFromCSR(csr, s.CertStorage.AuthorityCert, 
s.Options.CertValidity)
        if err != nil {
                logger.Sugar.Warnf("Failed to sign certificate from csr: %v. 
RemoteAddr: %s", err, p.Addr.String())
diff --git a/ca/pkg/v1alpha1/ca_impl_test.go b/ca/pkg/v1alpha1/ca_impl_test.go
new file mode 100644
index 0000000..e3de6bf
--- /dev/null
+++ b/ca/pkg/v1alpha1/ca_impl_test.go
@@ -0,0 +1,276 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package v1alpha1
+
+import (
+       "github.com/apache/dubbo-admin/ca/pkg/cert"
+       "github.com/apache/dubbo-admin/ca/pkg/config"
+       "github.com/apache/dubbo-admin/ca/pkg/k8s"
+       "github.com/apache/dubbo-admin/ca/pkg/logger"
+       "golang.org/x/net/context"
+       "google.golang.org/grpc/metadata"
+       "google.golang.org/grpc/peer"
+       "net"
+       "testing"
+)
+
+// MockKubeClient is a test double for k8s.Client that overrides only
+// VerifyServiceAccount. When constructed as a zero value its embedded
+// k8s.Client is a nil interface, so invoking any other method panics.
+type MockKubeClient struct {
+	k8s.Client
+}
+
+// VerifyServiceAccount reports whether token equals the single hard-coded
+// token this mock accepts; every other value is rejected.
+func (m MockKubeClient) VerifyServiceAccount(token string) bool {
+	accepted := "expceted-token"
+	return token == accepted
+}
+
+// fakeAddr is a minimal net.Addr used as the peer address in test contexts.
+// Both interface methods are implemented directly, so no nil embedded
+// net.Addr can be dereferenced (the old version panicked on Network()).
+type fakeAddr struct{}
+
+// Compile-time check that *fakeAddr satisfies net.Addr.
+var _ net.Addr = (*fakeAddr)(nil)
+
+// Network returns a placeholder network name.
+func (f *fakeAddr) Network() string {
+	return "fake"
+}
+
+// String returns an empty address; CreateCertificate includes it in its
+// log messages.
+func (f *fakeAddr) String() string {
+	return ""
+}
+
+// TestCSRFailed checks that CreateCertificate rejects requests whose CSR is
+// empty or cannot be decoded. Each rejection must be reported through
+// Success=false in the response, not as a gRPC error.
+func TestCSRFailed(t *testing.T) {
+	logger.Init()
+
+	c := metadata.NewIncomingContext(context.TODO(), metadata.MD{})
+	c = peer.NewContext(c, &peer.Peer{Addr: &fakeAddr{}})
+
+	options := &config.Options{
+		EnableKubernetes: false,
+		CertValidity:     24 * 60 * 60 * 1000,
+		CaValidity:       365 * 24 * 60 * 60 * 1000,
+	}
+	storage := cert.NewStorage(options)
+	storage.AuthorityCert = cert.GenerateAuthorityCert(nil, options.CaValidity)
+
+	// Inject the mock itself. Its embedded k8s.Client field is a nil
+	// interface and would panic if the server ever called through it.
+	impl := &DubboCertificateServiceServerImpl{
+		Options:     options,
+		CertStorage: storage,
+		KubeClient:  &MockKubeClient{},
+	}
+
+	cases := []struct {
+		name    string
+		csr     string
+		wantMsg string // checked only when non-empty
+	}{
+		{name: "empty", csr: "", wantMsg: "CSR is empty."},
+		{name: "not PEM", csr: "123"},
+		{
+			name: "malformed certificate PEM",
+			csr: "-----BEGIN CERTIFICATE-----\n" +
+				"123\n" +
+				"-----END CERTIFICATE-----",
+		},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			resp, err := impl.CreateCertificate(c, &DubboCertificateRequest{Csr: tc.csr})
+			if err != nil {
+				t.Fatalf("CreateCertificate returned error: %v", err)
+			}
+			if resp.Success {
+				t.Fatal("expected signing to fail")
+			}
+			if tc.wantMsg != "" && resp.Message != tc.wantMsg {
+				t.Fatalf("Message = %q, want %q", resp.Message, tc.wantMsg)
+			}
+		})
+	}
+}
+
+// TestTokenFailed checks CreateCertificate with Kubernetes enabled. Signing
+// must be refused when the authorization metadata is missing or carries a
+// token the KubeClient does not accept. It must succeed once MockKubeClient's
+// accepted token is supplied, and the issued certificate must pass IsValid.
+func TestTokenFailed(t *testing.T) {
+	logger.Init()
+
+	base := peer.NewContext(context.TODO(), &peer.Peer{Addr: &fakeAddr{}})
+	// withAuth attaches a single "authorization" metadata value to base.
+	withAuth := func(value string) context.Context {
+		return metadata.NewIncomingContext(base, metadata.MD{"authorization": []string{value}})
+	}
+
+	options := &config.Options{
+		EnableKubernetes: true,
+		CertValidity:     24 * 60 * 60 * 1000,
+		CaValidity:       365 * 24 * 60 * 60 * 1000,
+	}
+	storage := cert.NewStorage(options)
+	storage.AuthorityCert = cert.GenerateAuthorityCert(nil, options.CaValidity)
+
+	impl := &DubboCertificateServiceServerImpl{
+		Options:     options,
+		CertStorage: storage,
+		KubeClient:  &MockKubeClient{},
+	}
+
+	csr, privateKey, err := cert.GenerateCSR()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	rejected := []struct {
+		name string
+		ctx  context.Context
+	}{
+		{name: "no metadata", ctx: base},
+		{name: "token without Bearer scheme", ctx: withAuth("123")},
+		{name: "unaccepted bearer token", ctx: withAuth("Bearer 123")},
+	}
+	for _, tc := range rejected {
+		t.Run(tc.name, func(t *testing.T) {
+			resp, err := impl.CreateCertificate(tc.ctx, &DubboCertificateRequest{Csr: csr})
+			if err != nil {
+				t.Fatalf("CreateCertificate returned error: %v", err)
+			}
+			if resp.Success {
+				t.Fatal("expected signing to fail")
+			}
+		})
+	}
+
+	// The literal must match the token MockKubeClient accepts.
+	resp, err := impl.CreateCertificate(withAuth("Bearer expceted-token"), &DubboCertificateRequest{Csr: csr})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !resp.Success {
+		t.Fatalf("Sign failed: %s", resp.Message)
+	}
+
+	signed := &cert.Cert{
+		Cert:       cert.DecodeCert(resp.CertPem),
+		CertPem:    resp.CertPem,
+		PrivateKey: privateKey,
+	}
+	if !signed.IsValid() {
+		t.Fatal("Cert is not valid")
+	}
+}
+
+// TestSuccess checks that, with Kubernetes disabled, CreateCertificate signs a
+// freshly generated CSR. The returned certificate, paired with the CSR's
+// private key, must pass cert.Cert.IsValid.
+func TestSuccess(t *testing.T) {
+	logger.Init()
+
+	c := metadata.NewIncomingContext(context.TODO(), metadata.MD{})
+	c = peer.NewContext(c, &peer.Peer{Addr: &fakeAddr{}})
+
+	options := &config.Options{
+		EnableKubernetes: false,
+		CertValidity:     24 * 60 * 60 * 1000,
+		CaValidity:       365 * 24 * 60 * 60 * 1000,
+	}
+	storage := cert.NewStorage(options)
+	storage.AuthorityCert = cert.GenerateAuthorityCert(nil, options.CaValidity)
+
+	impl := &DubboCertificateServiceServerImpl{
+		Options:     options,
+		CertStorage: storage,
+		KubeClient:  &MockKubeClient{},
+	}
+
+	csr, privateKey, err := cert.GenerateCSR()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	resp, err := impl.CreateCertificate(c, &DubboCertificateRequest{Csr: csr})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !resp.Success {
+		t.Fatalf("Sign failed: %s", resp.Message)
+	}
+
+	signed := &cert.Cert{
+		Cert:       cert.DecodeCert(resp.CertPem),
+		CertPem:    resp.CertPem,
+		PrivateKey: privateKey,
+	}
+	if !signed.IsValid() {
+		t.Fatal("Cert is not valid")
+	}
+}

Reply via email to