This is an automated email from the ASF dual-hosted git repository. ronething pushed a commit to branch feat/host_check in repository https://gitbox.apache.org/repos/asf/apisix-ingress-controller.git
commit 92f9396e22a7fb8c2139c98a8606ecd073bf2058 Author: Ashing Zheng <[email protected]> AuthorDate: Tue Oct 14 15:30:53 2025 +0800 feat: block when the same host using different cert Signed-off-by: Ashing Zheng <[email protected]> --- internal/adc/translator/apisixtls.go | 5 +- internal/adc/translator/apisixupstream.go | 3 +- internal/adc/translator/gateway.go | 70 +--- internal/adc/translator/ingress.go | 5 +- internal/controller/indexer/tlsroute.go | 3 +- internal/provider/init/init.go | 3 +- internal/provider/register.go | 3 +- internal/ssl/util.go | 130 ++++++ internal/webhook/v1/apisixtls_webhook.go | 33 +- internal/webhook/v1/gateway_webhook.go | 27 ++ internal/webhook/v1/ingress_webhook.go | 27 ++ internal/webhook/v1/ssl/conflict_detector.go | 473 ++++++++++++++++++++++ internal/webhook/v1/ssl/conflict_detector_test.go | 301 ++++++++++++++ 13 files changed, 1006 insertions(+), 77 deletions(-) diff --git a/internal/adc/translator/apisixtls.go b/internal/adc/translator/apisixtls.go index 2f05facf..ef46f255 100644 --- a/internal/adc/translator/apisixtls.go +++ b/internal/adc/translator/apisixtls.go @@ -27,6 +27,7 @@ import ( "github.com/apache/apisix-ingress-controller/internal/controller/label" "github.com/apache/apisix-ingress-controller/internal/id" "github.com/apache/apisix-ingress-controller/internal/provider" + sslutils "github.com/apache/apisix-ingress-controller/internal/ssl" ) func (t *Translator) TranslateApisixTls(tctx *provider.TranslateContext, tls *apiv2.ApisixTls) (*TranslateResult, error) { @@ -43,7 +44,7 @@ func (t *Translator) TranslateApisixTls(tctx *provider.TranslateContext, tls *ap } // Extract cert and key from secret - cert, key, err := extractKeyPair(secret, true) + cert, key, err := sslutils.ExtractKeyPair(secret, true) if err != nil { return nil, err } @@ -80,7 +81,7 @@ func (t *Translator) TranslateApisixTls(tctx *provider.TranslateContext, tls *ap return nil, fmt.Errorf("client CA secret %s not found", caSecretKey.String()) } - ca, 
_, err := extractKeyPair(caSecret, false) + ca, _, err := sslutils.ExtractKeyPair(caSecret, false) if err != nil { return nil, err } diff --git a/internal/adc/translator/apisixupstream.go b/internal/adc/translator/apisixupstream.go index 86a39e62..33e626fe 100644 --- a/internal/adc/translator/apisixupstream.go +++ b/internal/adc/translator/apisixupstream.go @@ -29,6 +29,7 @@ import ( "github.com/apache/apisix-ingress-controller/api/adc" apiv2 "github.com/apache/apisix-ingress-controller/api/v2" "github.com/apache/apisix-ingress-controller/internal/provider" + sslutils "github.com/apache/apisix-ingress-controller/internal/ssl" "github.com/apache/apisix-ingress-controller/internal/utils" ) @@ -187,7 +188,7 @@ func translateApisixUpstreamClientTLS(tctx *provider.TranslateContext, config *a return errors.Errorf("sercret %s not found", secretNN) } - cert, key, err := extractKeyPair(secret, true) + cert, key, err := sslutils.ExtractKeyPair(secret, true) if err != nil { return err } diff --git a/internal/adc/translator/gateway.go b/internal/adc/translator/gateway.go index db284845..d83507aa 100644 --- a/internal/adc/translator/gateway.go +++ b/internal/adc/translator/gateway.go @@ -18,14 +18,11 @@ package translator import ( - "crypto/x509" "encoding/json" - "encoding/pem" "fmt" "slices" "github.com/pkg/errors" - corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" gatewayv1 "sigs.k8s.io/gateway-api/apis/v1" @@ -34,6 +31,7 @@ import ( "github.com/apache/apisix-ingress-controller/internal/controller/label" "github.com/apache/apisix-ingress-controller/internal/id" "github.com/apache/apisix-ingress-controller/internal/provider" + sslutils "github.com/apache/apisix-ingress-controller/internal/ssl" internaltypes "github.com/apache/apisix-ingress-controller/internal/types" "github.com/apache/apisix-ingress-controller/internal/utils" ) @@ -99,7 +97,7 @@ func (t *Translator) translateSecret(tctx *provider.TranslateContext, listener g t.Log.Error(errors.New("secret data is 
nil"), "failed to get secret data", "secret", secretNN) return nil, fmt.Errorf("no secret data found for %s/%s", ns, name) } - cert, key, err := extractKeyPair(secret, true) + cert, key, err := sslutils.ExtractKeyPair(secret, true) if err != nil { t.Log.Error(err, "extract key pair", "secret", secretNN) return nil, err @@ -112,7 +110,7 @@ func (t *Translator) translateSecret(tctx *provider.TranslateContext, listener g if listener.Hostname != nil && *listener.Hostname != "" { sslObj.Snis = append(sslObj.Snis, string(*listener.Hostname)) } else { - hosts, err := extractHost(cert) + hosts, err := sslutils.ExtractHostsFromCertificate(cert) if err != nil { return nil, err } @@ -140,68 +138,6 @@ func (t *Translator) translateSecret(tctx *provider.TranslateContext, listener g return sslObjs, nil } -func extractHost(cert []byte) ([]string, error) { - block, _ := pem.Decode(cert) - if block == nil { - return nil, errors.New("parse certificate: not in PEM format") - } - der, err := x509.ParseCertificate(block.Bytes) - if err != nil { - return nil, errors.Wrap(err, "parse certificate") - } - hosts := make([]string, 0, len(der.DNSNames)) - for _, dnsName := range der.DNSNames { - if dnsName != "*" { - hosts = append(hosts, dnsName) - } - } - return hosts, nil -} - -func extractKeyPair(s *corev1.Secret, hasPrivateKey bool) ([]byte, []byte, error) { - if _, ok := s.Data["cert"]; ok { - return extractApisixSecretKeyPair(s, hasPrivateKey) - } else if _, ok := s.Data[corev1.TLSCertKey]; ok { - return extractKubeSecretKeyPair(s, hasPrivateKey) - } else if ca, ok := s.Data[corev1.ServiceAccountRootCAKey]; ok && !hasPrivateKey { - return ca, nil, nil - } else { - return nil, nil, errors.New("unknown secret format") - } -} - -func extractApisixSecretKeyPair(s *corev1.Secret, hasPrivateKey bool) (cert []byte, key []byte, err error) { - var ok bool - cert, ok = s.Data["cert"] - if !ok { - return nil, nil, errors.New("missing cert field") - } - - if hasPrivateKey { - key, ok = 
s.Data["key"] - if !ok { - return nil, nil, errors.New("missing key field") - } - } - return -} - -func extractKubeSecretKeyPair(s *corev1.Secret, hasPrivateKey bool) (cert []byte, key []byte, err error) { - var ok bool - cert, ok = s.Data[corev1.TLSCertKey] - if !ok { - return nil, nil, errors.New("missing cert field") - } - - if hasPrivateKey { - key, ok = s.Data[corev1.TLSPrivateKeyKey] - if !ok { - return nil, nil, errors.New("missing key field") - } - } - return -} - // fillPluginsFromGatewayProxy fill plugins from GatewayProxy to given plugins func (t *Translator) fillPluginsFromGatewayProxy(plugins adctypes.GlobalRule, gatewayProxy *v1alpha1.GatewayProxy) { if gatewayProxy == nil { diff --git a/internal/adc/translator/ingress.go b/internal/adc/translator/ingress.go index f17b159f..35c7e447 100644 --- a/internal/adc/translator/ingress.go +++ b/internal/adc/translator/ingress.go @@ -30,19 +30,20 @@ import ( "github.com/apache/apisix-ingress-controller/internal/controller/label" "github.com/apache/apisix-ingress-controller/internal/id" "github.com/apache/apisix-ingress-controller/internal/provider" + sslutils "github.com/apache/apisix-ingress-controller/internal/ssl" internaltypes "github.com/apache/apisix-ingress-controller/internal/types" ) func (t *Translator) translateIngressTLS(ingressTLS *networkingv1.IngressTLS, secret *corev1.Secret, labels map[string]string) (*adctypes.SSL, error) { // extract the key pair from the secret - cert, key, err := extractKeyPair(secret, true) + cert, key, err := sslutils.ExtractKeyPair(secret, true) if err != nil { return nil, err } hosts := ingressTLS.Hosts if len(hosts) == 0 { - certHosts, err := extractHost(cert) + certHosts, err := sslutils.ExtractHostsFromCertificate(cert) if err != nil { return nil, err } diff --git a/internal/controller/indexer/tlsroute.go b/internal/controller/indexer/tlsroute.go index 567131c4..acef5317 100644 --- a/internal/controller/indexer/tlsroute.go +++ 
b/internal/controller/indexer/tlsroute.go @@ -20,10 +20,11 @@ package indexer import ( "context" - internaltypes "github.com/apache/apisix-ingress-controller/internal/types" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" gatewayv1alpha2 "sigs.k8s.io/gateway-api/apis/v1alpha2" + + internaltypes "github.com/apache/apisix-ingress-controller/internal/types" ) func setupTLSRouteIndexer(mgr ctrl.Manager) error { diff --git a/internal/provider/init/init.go b/internal/provider/init/init.go index b6ed9e99..be21c07d 100644 --- a/internal/provider/init/init.go +++ b/internal/provider/init/init.go @@ -18,11 +18,12 @@ package init import ( + "github.com/go-logr/logr" + "github.com/apache/apisix-ingress-controller/internal/controller/status" "github.com/apache/apisix-ingress-controller/internal/manager/readiness" "github.com/apache/apisix-ingress-controller/internal/provider" "github.com/apache/apisix-ingress-controller/internal/provider/apisix" - "github.com/go-logr/logr" ) func init() { diff --git a/internal/provider/register.go b/internal/provider/register.go index fddb1af5..a9feb032 100644 --- a/internal/provider/register.go +++ b/internal/provider/register.go @@ -21,9 +21,10 @@ import ( "fmt" "net/http" + "github.com/go-logr/logr" + "github.com/apache/apisix-ingress-controller/internal/controller/status" "github.com/apache/apisix-ingress-controller/internal/manager/readiness" - "github.com/go-logr/logr" ) type RegisterHandler interface { diff --git a/internal/ssl/util.go b/internal/ssl/util.go new file mode 100644 index 00000000..64393ac9 --- /dev/null +++ b/internal/ssl/util.go @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssl + +import ( + "crypto/sha256" + "crypto/x509" + "encoding/hex" + "encoding/pem" + "errors" + + corev1 "k8s.io/api/core/v1" +) + +var ( + // ErrUnknownSecretFormat indicates the secret does not contain supported TLS data keys. + ErrUnknownSecretFormat = errors.New("unknown secret format") + // ErrMissingCert indicates the secret is missing the certificate part. + ErrMissingCert = errors.New("missing cert field") + // ErrMissingKey indicates the secret is missing the private key part when it is required. + ErrMissingKey = errors.New("missing key field") + // ErrInvalidPEM is returned when the provided certificate is not valid PEM encoded data. + ErrInvalidPEM = errors.New("certificate is not valid PEM data") +) + +// ExtractKeyPair extracts the certificate and, optionally, the private key from a Secret. +// +// Supported formats: +// 1. APISIX style: data keys `cert` and `key` +// 2. Kubernetes TLS secret: data keys `tls.crt` and `tls.key` +// 3. 
Kubernetes CA secret: data key `ca.crt` (without private key) +func ExtractKeyPair(secret *corev1.Secret, includePrivateKey bool) ([]byte, []byte, error) { + if secret == nil { + return nil, nil, ErrMissingCert + } + + if cert, ok := secret.Data["cert"]; ok { + if includePrivateKey { + key, ok := secret.Data["key"] + if !ok { + return nil, nil, ErrMissingKey + } + return cert, key, nil + } + return cert, nil, nil + } + + if cert, ok := secret.Data[corev1.TLSCertKey]; ok { + if includePrivateKey { + key, ok := secret.Data[corev1.TLSPrivateKeyKey] + if !ok { + return nil, nil, ErrMissingKey + } + return cert, key, nil + } + return cert, nil, nil + } + + if cert, ok := secret.Data[corev1.ServiceAccountRootCAKey]; ok && !includePrivateKey { + return cert, nil, nil + } + + return nil, nil, ErrUnknownSecretFormat +} + +// ExtractCertificate extracts only the certificate data from a Secret. +func ExtractCertificate(secret *corev1.Secret) ([]byte, error) { + cert, _, err := ExtractKeyPair(secret, false) + return cert, err +} + +// ExtractHostsFromCertificate parses the first PEM block in certPEM as an X.509 certificate and returns its DNS names. +// +// Only a bare "*" entry is skipped; all other DNS names (including wildcard patterns such as *.example.com) are returned unchanged, without case normalization. +func ExtractHostsFromCertificate(certPEM []byte) ([]string, error) { + block, _ := pem.Decode(certPEM) + if block == nil { + return nil, ErrInvalidPEM + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, err + } + + hosts := make([]string, 0, len(cert.DNSNames)) + for _, dnsName := range cert.DNSNames { + if dnsName != "*" { + hosts = append(hosts, dnsName) + } + } + return hosts, nil +} + +// NormalizeHosts removes duplicate entries while preserving first-seen order and returns nil for empty input. It does not lower-case hosts or filter out empty strings or wildcards. NOTE(review): hosts are therefore compared case-sensitively, although DNS names are case-insensitive.
+func NormalizeHosts(hosts []string) []string { + if len(hosts) == 0 { + return nil + } + + normalized := make([]string, 0, len(hosts)) + seen := make(map[string]struct{}, len(hosts)) + for _, host := range hosts { + if _, ok := seen[host]; ok { + continue + } + seen[host] = struct{}{} + normalized = append(normalized, host) + } + return normalized +} + +// CertificateHash returns the SHA-256 hash of the certificate PEM bytes in hexadecimal form. +func CertificateHash(cert []byte) string { + sum := sha256.Sum256(cert) + return hex.EncodeToString(sum[:]) +} diff --git a/internal/webhook/v1/apisixtls_webhook.go b/internal/webhook/v1/apisixtls_webhook.go index 16bcf88f..e45e2b84 100644 --- a/internal/webhook/v1/apisixtls_webhook.go +++ b/internal/webhook/v1/apisixtls_webhook.go @@ -30,6 +30,7 @@ import ( apisixv2 "github.com/apache/apisix-ingress-controller/api/v2" "github.com/apache/apisix-ingress-controller/internal/controller" "github.com/apache/apisix-ingress-controller/internal/webhook/v1/reference" + sslvalidator "github.com/apache/apisix-ingress-controller/internal/webhook/v1/ssl" ) var apisixTlsLog = logf.Log.WithName("apisixtls-resource") @@ -67,7 +68,21 @@ func (v *ApisixTlsCustomValidator) ValidateCreate(ctx context.Context, obj runti return nil, nil } - return v.collectWarnings(ctx, tls), nil + detector := sslvalidator.NewConflictDetector(v.Client) + mappings, mappingWarnings := detector.BuildApisixTlsMappings(ctx, tls) + warnings := v.collectWarnings(ctx, tls) + for _, warning := range mappingWarnings { + warnings = append(warnings, warning) + } + conflicts, err := detector.DetectConflicts(ctx, tls, mappings) + if err != nil { + return nil, err + } + if len(conflicts) > 0 { + return nil, fmt.Errorf("%s", sslvalidator.FormatConflicts(conflicts)) + } + + return warnings, nil } func (v *ApisixTlsCustomValidator) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (admission.Warnings, error) { @@ -80,7 +95,21 @@ func (v 
*ApisixTlsCustomValidator) ValidateUpdate(ctx context.Context, oldObj, n return nil, nil } - return v.collectWarnings(ctx, tls), nil + detector := sslvalidator.NewConflictDetector(v.Client) + mappings, mappingWarnings := detector.BuildApisixTlsMappings(ctx, tls) + warnings := v.collectWarnings(ctx, tls) + for _, warning := range mappingWarnings { + warnings = append(warnings, warning) + } + conflicts, err := detector.DetectConflicts(ctx, tls, mappings) + if err != nil { + return nil, err + } + if len(conflicts) > 0 { + return nil, fmt.Errorf("%s", sslvalidator.FormatConflicts(conflicts)) + } + + return warnings, nil } func (*ApisixTlsCustomValidator) ValidateDelete(context.Context, runtime.Object) (admission.Warnings, error) { diff --git a/internal/webhook/v1/gateway_webhook.go b/internal/webhook/v1/gateway_webhook.go index bb21b236..9b5096e1 100644 --- a/internal/webhook/v1/gateway_webhook.go +++ b/internal/webhook/v1/gateway_webhook.go @@ -33,6 +33,7 @@ import ( v1alpha1 "github.com/apache/apisix-ingress-controller/api/v1alpha1" internaltypes "github.com/apache/apisix-ingress-controller/internal/types" "github.com/apache/apisix-ingress-controller/internal/webhook/v1/reference" + sslvalidator "github.com/apache/apisix-ingress-controller/internal/webhook/v1/ssl" ) // nolint:unused @@ -89,6 +90,19 @@ func (v *GatewayCustomValidator) ValidateCreate(ctx context.Context, obj runtime warnings := v.warnIfMissingGatewayProxyForGateway(ctx, gateway) warnings = append(warnings, v.collectReferenceWarnings(ctx, gateway)...) 
+ detector := sslvalidator.NewConflictDetector(v.Client) + mappings, mappingWarnings := detector.BuildGatewayMappings(ctx, gateway) + for _, warning := range mappingWarnings { + warnings = append(warnings, warning) + } + conflicts, err := detector.DetectConflicts(ctx, gateway, mappings) + if err != nil { + return nil, err + } + if len(conflicts) > 0 { + return nil, fmt.Errorf("%s", sslvalidator.FormatConflicts(conflicts)) + } + return warnings, nil } @@ -112,6 +126,19 @@ func (v *GatewayCustomValidator) ValidateUpdate(ctx context.Context, oldObj, new warnings := v.warnIfMissingGatewayProxyForGateway(ctx, gateway) warnings = append(warnings, v.collectReferenceWarnings(ctx, gateway)...) + detector := sslvalidator.NewConflictDetector(v.Client) + mappings, mappingWarnings := detector.BuildGatewayMappings(ctx, gateway) + for _, warning := range mappingWarnings { + warnings = append(warnings, warning) + } + conflicts, err := detector.DetectConflicts(ctx, gateway, mappings) + if err != nil { + return nil, err + } + if len(conflicts) > 0 { + return nil, fmt.Errorf("%s", sslvalidator.FormatConflicts(conflicts)) + } + return warnings, nil } diff --git a/internal/webhook/v1/ingress_webhook.go b/internal/webhook/v1/ingress_webhook.go index 10e18ab4..b74073b6 100644 --- a/internal/webhook/v1/ingress_webhook.go +++ b/internal/webhook/v1/ingress_webhook.go @@ -31,6 +31,7 @@ import ( "github.com/apache/apisix-ingress-controller/internal/controller" "github.com/apache/apisix-ingress-controller/internal/webhook/v1/reference" + sslvalidator "github.com/apache/apisix-ingress-controller/internal/webhook/v1/ssl" ) var ingresslog = logf.Log.WithName("ingress-resource") @@ -146,6 +147,19 @@ func (v *IngressCustomValidator) ValidateCreate(ctx context.Context, obj runtime warnings := checkUnsupportedAnnotations(ingress) warnings = append(warnings, v.collectReferenceWarnings(ctx, ingress)...) 
+ detector := sslvalidator.NewConflictDetector(v.Client) + mappings, mappingWarnings := detector.BuildIngressMappings(ctx, ingress) + for _, warning := range mappingWarnings { + warnings = append(warnings, warning) + } + conflicts, err := detector.DetectConflicts(ctx, ingress, mappings) + if err != nil { + return nil, err + } + if len(conflicts) > 0 { + return nil, fmt.Errorf("%s", sslvalidator.FormatConflicts(conflicts)) + } + return warnings, nil } @@ -164,6 +178,19 @@ func (v *IngressCustomValidator) ValidateUpdate(ctx context.Context, oldObj, new warnings := checkUnsupportedAnnotations(ingress) warnings = append(warnings, v.collectReferenceWarnings(ctx, ingress)...) + detector := sslvalidator.NewConflictDetector(v.Client) + mappings, mappingWarnings := detector.BuildIngressMappings(ctx, ingress) + for _, warning := range mappingWarnings { + warnings = append(warnings, warning) + } + conflicts, err := detector.DetectConflicts(ctx, ingress, mappings) + if err != nil { + return nil, err + } + if len(conflicts) > 0 { + return nil, fmt.Errorf("%s", sslvalidator.FormatConflicts(conflicts)) + } + return warnings, nil } diff --git a/internal/webhook/v1/ssl/conflict_detector.go b/internal/webhook/v1/ssl/conflict_detector.go new file mode 100644 index 00000000..b2a0a796 --- /dev/null +++ b/internal/webhook/v1/ssl/conflict_detector.go @@ -0,0 +1,473 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssl + +import ( + "context" + "fmt" + "strings" + + corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/log" + gatewayv1 "sigs.k8s.io/gateway-api/apis/v1" + + v1alpha1 "github.com/apache/apisix-ingress-controller/api/v1alpha1" + apiv2 "github.com/apache/apisix-ingress-controller/api/v2" + "github.com/apache/apisix-ingress-controller/internal/controller" + "github.com/apache/apisix-ingress-controller/internal/controller/config" + "github.com/apache/apisix-ingress-controller/internal/controller/indexer" + sslutil "github.com/apache/apisix-ingress-controller/internal/ssl" + internaltypes "github.com/apache/apisix-ingress-controller/internal/types" +) + +var logger = log.Log.WithName("ssl-conflict-detector") + +// HostCertMapping represents the relationship between a host and its certificate hash. +type HostCertMapping struct { + Host string + CertificateHash string + ResourceRef string +} + +// SSLConflict exposes the conflict details to the admission webhook for reporting. +type SSLConflict struct { + Host string + ConflictingResource string + CertificateHash string +} + +// ConflictDetector detects SSL conflicts among Gateway, Ingress, and ApisixTls resources. +type ConflictDetector struct { + client client.Client + secretCache map[types.NamespacedName]*secretInfo +} + +type secretInfo struct { + hash string + hosts []string +} + +// NewConflictDetector creates a detector backed by the provided client. 
+func NewConflictDetector(c client.Client) *ConflictDetector { + return &ConflictDetector{ + client: c, + secretCache: make(map[types.NamespacedName]*secretInfo), + } +} + +// DetectConflicts returns the list of conflicts between the provided mappings and +// existing resources that are associated with the same GatewayProxy. Best-effort: +// failures while enumerating existing resources or reading Secrets will be logged +// and result in no conflicts instead of blocking the admission. +func (d *ConflictDetector) DetectConflicts(ctx context.Context, obj client.Object, newMappings []HostCertMapping) ([]SSLConflict, error) { + gatewayProxy, err := d.resolveGatewayProxy(ctx, obj) + if err != nil { + logger.Error(err, "failed to resolve GatewayProxy", "object", objectKey(obj)) + return nil, nil + } + if gatewayProxy == nil { + return nil, nil + } + + existingMappings, err := d.collectExistingMappings(ctx, gatewayProxy, obj.GetUID()) + if err != nil { + logger.Error(err, "failed to collect existing SSL mappings", "gatewayProxy", objectKey(gatewayProxy)) + return nil, nil + } + + conflicts := make([]SSLConflict, 0) + byHost := make(map[string]HostCertMapping, len(existingMappings)) + for _, mapping := range existingMappings { + if mapping.Host == "" || mapping.CertificateHash == "" { + continue + } + if existing, ok := byHost[mapping.Host]; ok { + if existing.CertificateHash == mapping.CertificateHash { + continue + } + // keep the first encountered mapping to surface a deterministic conflict reference + continue + } + byHost[mapping.Host] = mapping + } + + seen := make(map[string]string, len(newMappings)) + // TODO: need to check with self-referencing mappings + for _, mapping := range newMappings { + if mapping.Host == "" || mapping.CertificateHash == "" { + continue + } + if prev, ok := seen[mapping.Host]; ok { + // when the same host appears more than once inside this object with different hashes, the later hash overwrites the earlier one (last wins); the mismatch itself is not reported here + if prev != mapping.CertificateHash { + seen[mapping.Host] = mapping.CertificateHash +
} + continue + } + seen[mapping.Host] = mapping.CertificateHash + } + + for host, hash := range seen { + existing, ok := byHost[host] + if !ok { + continue + } + if existing.CertificateHash == hash { + continue + } + conflicts = append(conflicts, SSLConflict{ + Host: host, + ConflictingResource: existing.ResourceRef, + CertificateHash: existing.CertificateHash, + }) + } + + return conflicts, nil +} + +// FormatConflicts renders a human-readable error message for multiple conflicts. +func FormatConflicts(conflicts []SSLConflict) string { + if len(conflicts) == 0 { + return "" + } + var sb strings.Builder + sb.WriteString("SSL configuration conflicts detected:") + for _, conflict := range conflicts { + sb.WriteString(fmt.Sprintf("\n- Host '%s' is already configured with a different certificate in %s", conflict.Host, conflict.ConflictingResource)) + } + return sb.String() +} + +// BuildGatewayMappings calculates host-to-certificate mappings for a Gateway. +func (d *ConflictDetector) BuildGatewayMappings(ctx context.Context, gateway *gatewayv1.Gateway) ([]HostCertMapping, []string) { + mappings := make([]HostCertMapping, 0) + warnings := make([]string, 0) + + if gateway == nil { + return mappings, warnings + } + + for _, listener := range gateway.Spec.Listeners { + if listener.TLS == nil || listener.TLS.CertificateRefs == nil { + continue + } + for _, ref := range listener.TLS.CertificateRefs { + if ref.Kind != nil && *ref.Kind != internaltypes.KindSecret { + continue + } + if ref.Group != nil && string(*ref.Group) != corev1.GroupName { + continue + } + secretNN := types.NamespacedName{ + Namespace: gateway.Namespace, + Name: string(ref.Name), + } + if ref.Namespace != nil && *ref.Namespace != "" { + secretNN.Namespace = string(*ref.Namespace) + } + + info, err := d.getSecretInfo(ctx, secretNN) + if err != nil { + logger.Error(err, "failed to read secret for Gateway", "gateway", objectKey(gateway), "secret", secretNN) + warnings = append(warnings, fmt.Sprintf("failed 
to read Secret %s for Gateway %s/%s: %v", secretNN, gateway.Namespace, gateway.Name, err)) + continue + } + + hosts := make([]string, 0, 1) + if listener.Hostname != nil && *listener.Hostname != "" { + hosts = append(hosts, string(*listener.Hostname)) + } + hosts = sslutil.NormalizeHosts(hosts) + if len(hosts) == 0 { + hosts = info.hosts + } + for _, host := range hosts { + mappings = append(mappings, HostCertMapping{ + Host: host, + CertificateHash: info.hash, + ResourceRef: fmt.Sprintf("Gateway/%s/%s", gateway.Namespace, gateway.Name), + }) + } + } + } + + return mappings, warnings +} + +// BuildIngressMappings calculates host-to-certificate mappings for an Ingress. +func (d *ConflictDetector) BuildIngressMappings(ctx context.Context, ingress *networkingv1.Ingress) ([]HostCertMapping, []string) { + mappings := make([]HostCertMapping, 0) + warnings := make([]string, 0) + if ingress == nil { + return mappings, warnings + } + + for _, tls := range ingress.Spec.TLS { + if tls.SecretName == "" { + continue + } + secretNN := types.NamespacedName{Namespace: ingress.Namespace, Name: tls.SecretName} + info, err := d.getSecretInfo(ctx, secretNN) + if err != nil { + logger.Error(err, "failed to read secret for Ingress", "ingress", objectKey(ingress), "secret", secretNN) + warnings = append(warnings, fmt.Sprintf("failed to read Secret %s for Ingress %s/%s: %v", secretNN, ingress.Namespace, ingress.Name, err)) + continue + } + + hosts := sslutil.NormalizeHosts(tls.Hosts) + if len(hosts) == 0 { + hosts = info.hosts + } + for _, host := range hosts { + mappings = append(mappings, HostCertMapping{ + Host: host, + CertificateHash: info.hash, + ResourceRef: fmt.Sprintf("Ingress/%s/%s", ingress.Namespace, ingress.Name), + }) + } + } + + return mappings, warnings +} + +// BuildApisixTlsMappings calculates host-to-certificate mappings for an ApisixTls resource. 
+// BuildApisixTlsMappings returns one HostCertMapping per normalized host in
+// tls.Spec.Hosts, all pointing at the certificate stored in the Secret
+// referenced by tls.Spec.Secret. If that Secret cannot be read, the failure is
+// logged and returned as a warning and no mappings are produced. A nil tls
+// yields empty, non-nil slices.
+func (d *ConflictDetector) BuildApisixTlsMappings(ctx context.Context, tls *apiv2.ApisixTls) ([]HostCertMapping, []string) {
+	mappings := make([]HostCertMapping, 0)
+	warnings := make([]string, 0)
+	if tls == nil {
+		return mappings, warnings
+	}
+
+	secretNN := types.NamespacedName{
+		Namespace: tls.Spec.Secret.Namespace,
+		Name:      tls.Spec.Secret.Name,
+	}
+	info, err := d.getSecretInfo(ctx, secretNN)
+	if err != nil {
+		logger.Error(err, "failed to read secret for ApisixTls", "apisixtls", objectKey(tls), "secret", secretNN)
+		warnings = append(warnings, fmt.Sprintf("failed to read Secret %s for ApisixTls %s/%s: %v", secretNN, tls.Namespace, tls.Name, err))
+		return mappings, warnings
+	}
+
+	hosts := make([]string, 0, len(tls.Spec.Hosts))
+	for _, host := range tls.Spec.Hosts {
+		hosts = append(hosts, string(host))
+	}
+	// spec.hosts is required by the ApisixTls CRD, so the hosts embedded in the
+	// certificate (info.hosts) are intentionally not used as a fallback here.
+	hosts = sslutil.NormalizeHosts(hosts)
+
+	ref := fmt.Sprintf("ApisixTls/%s/%s", tls.Namespace, tls.Name)
+	for _, host := range hosts {
+		mappings = append(mappings, HostCertMapping{
+			Host:            host,
+			CertificateHash: info.hash,
+			ResourceRef:     ref,
+		})
+	}
+
+	return mappings, warnings
+}
+
+// getSecretInfo loads the Secret nn, extracts its certificate and returns the
+// certificate hash together with the normalized hosts found in that
+// certificate. Successful lookups are memoized in d.secretCache; failures are
+// not cached. An incomplete name (empty name or namespace) is rejected.
+//
+// NOTE(review): secretCache is a plain map without locking, so this assumes a
+// ConflictDetector is not shared between concurrent callers — confirm.
+func (d *ConflictDetector) getSecretInfo(ctx context.Context, nn types.NamespacedName) (*secretInfo, error) {
+	if nn.Name == "" || nn.Namespace == "" {
+		return nil, fmt.Errorf("secret namespaced name is incomplete: %s", nn)
+	}
+	if info, ok := d.secretCache[nn]; ok {
+		return info, nil
+	}
+
+	var secret corev1.Secret
+	if err := d.client.Get(ctx, nn, &secret); err != nil {
+		return nil, err
+	}
+
+	cert, err := sslutil.ExtractCertificate(&secret)
+	if err != nil {
+		return nil, err
+	}
+
+	hash := sslutil.CertificateHash(cert)
+	hosts, err := sslutil.ExtractHostsFromCertificate(cert)
+	if err != nil {
+		// Best effort: the certificate hash is still usable, so the entry is
+		// cached with no hosts instead of failing the whole lookup.
+		logger.Error(err, "failed to extract hosts from certificate", "secret", nn)
+		hosts = nil
+	}
+	info := &secretInfo{
+		hash:  hash,
+		hosts: sslutil.NormalizeHosts(hosts),
+	}
+	d.secretCache[nn] = info
+	return info, nil
+}
+
+// resolveGatewayProxy returns the GatewayProxy that obj is attached to.
+// Gateways are resolved directly from the Gateway; Ingresses and ApisixTls
+// are resolved through their matching IngressClass. It returns (nil, nil)
+// when no matching IngressClass exists, and an error for unsupported types.
+func (d *ConflictDetector) resolveGatewayProxy(ctx context.Context, obj client.Object) (*v1alpha1.GatewayProxy, error) {
+	switch resource := obj.(type) {
+	case *gatewayv1.Gateway:
+		return controller.GetGatewayProxyByGateway(ctx, d.client, resource)
+	case *networkingv1.Ingress:
+		ingressClass, err := controller.FindMatchingIngressClass(ctx, d.client, logger, resource)
+		if err != nil {
+			return nil, err
+		}
+		if ingressClass == nil {
+			return nil, nil
+		}
+		return controller.GetGatewayProxyByIngressClass(ctx, d.client, ingressClass)
+	case *apiv2.ApisixTls:
+		ingressClass, err := controller.FindMatchingIngressClass(ctx, d.client, logger, resource)
+		if err != nil {
+			return nil, err
+		}
+		if ingressClass == nil {
+			return nil, nil
+		}
+		return controller.GetGatewayProxyByIngressClass(ctx, d.client, ingressClass)
+	default:
+		return nil, fmt.Errorf("unsupported object type %T", obj)
+	}
+}
+
+// collectExistingMappings gathers host/certificate mappings from every
+// resource already attached to gatewayProxy:
+//   - Gateways whose parametersRef index points at it,
+//   - Ingresses and ApisixTls bound to an IngressClass whose parameters index
+//     points at it, and
+//   - if one of those IngressClasses is a default class for this controller,
+//     Ingresses and ApisixTls that do not name a class at all.
+//
+// The object whose UID equals excludeUID (the resource under validation) is
+// skipped; an empty excludeUID excludes nothing. Per-resource warnings from
+// the Build*Mappings helpers are discarded. A nil gatewayProxy yields an empty
+// result; list failures are returned as errors.
+func (d *ConflictDetector) collectExistingMappings(ctx context.Context, gatewayProxy *v1alpha1.GatewayProxy, excludeUID types.UID) ([]HostCertMapping, error) {
+	mappings := make([]HostCertMapping, 0)
+	if gatewayProxy == nil {
+		return mappings, nil
+	}
+
+	// Comparing against an empty excludeUID would otherwise also skip every
+	// object whose UID happens to be unset.
+	excluded := func(obj client.Object) bool {
+		return excludeUID != "" && obj.GetUID() == excludeUID
+	}
+
+	// Each Ingress / ApisixTls is recorded at most once, even if it is reached
+	// through several IngressClasses or through the default-class scan below.
+	// Keys are namespace/name so de-duplication does not depend on UIDs.
+	seenIngress := make(map[types.NamespacedName]struct{})
+	addIngress := func(ingress *networkingv1.Ingress) {
+		if excluded(ingress) {
+			return
+		}
+		key := objectKey(ingress)
+		if _, ok := seenIngress[key]; ok {
+			return
+		}
+		seenIngress[key] = struct{}{}
+		ingressMappings, _ := d.BuildIngressMappings(ctx, ingress)
+		mappings = append(mappings, ingressMappings...)
+	}
+	seenTLS := make(map[types.NamespacedName]struct{})
+	addTLS := func(tls *apiv2.ApisixTls) {
+		if excluded(tls) {
+			return
+		}
+		key := objectKey(tls)
+		if _, ok := seenTLS[key]; ok {
+			return
+		}
+		seenTLS[key] = struct{}{}
+		tlsMappings, _ := d.BuildApisixTlsMappings(ctx, tls)
+		mappings = append(mappings, tlsMappings...)
+	}
+
+	indexKey := indexer.GenIndexKey(gatewayProxy.Namespace, gatewayProxy.Name)
+
+	// A single List returns each Gateway once, so no de-duplication is needed.
+	var gatewayList gatewayv1.GatewayList
+	if err := d.client.List(ctx, &gatewayList, client.MatchingFields{indexer.ParametersRef: indexKey}); err != nil {
+		return nil, err
+	}
+	for i := range gatewayList.Items {
+		gateway := &gatewayList.Items[i]
+		if excluded(gateway) {
+			continue
+		}
+		gatewayMappings, _ := d.BuildGatewayMappings(ctx, gateway)
+		mappings = append(mappings, gatewayMappings...)
+	}
+
+	hasDefaultClass := false
+	var ingressClassList networkingv1.IngressClassList
+	if err := d.client.List(ctx, &ingressClassList, client.MatchingFields{indexer.IngressClassParametersRef: indexKey}); err != nil {
+		return nil, err
+	}
+	for i := range ingressClassList.Items {
+		ingressClass := &ingressClassList.Items[i]
+		if controller.IsDefaultIngressClass(ingressClass) && ingressClass.Spec.Controller == config.ControllerConfig.ControllerName {
+			hasDefaultClass = true
+		}
+
+		var ingressList networkingv1.IngressList
+		if err := d.client.List(ctx, &ingressList, client.MatchingFields{indexer.IngressClassRef: ingressClass.Name}); err != nil {
+			return nil, err
+		}
+		for j := range ingressList.Items {
+			addIngress(&ingressList.Items[j])
+		}
+
+		var tlsList apiv2.ApisixTlsList
+		if err := d.client.List(ctx, &tlsList, client.MatchingFields{indexer.IngressClassRef: ingressClass.Name}); err != nil {
+			return nil, err
+		}
+		for j := range tlsList.Items {
+			addTLS(&tlsList.Items[j])
+		}
+	}
+
+	if !hasDefaultClass {
+		return mappings, nil
+	}
+
+	// Resources that do not name an IngressClass fall back to the default
+	// class, which was found above to be attached to this GatewayProxy.
+	var allIngress networkingv1.IngressList
+	if err := d.client.List(ctx, &allIngress); err != nil {
+		return nil, err
+	}
+	for i := range allIngress.Items {
+		if ingress := &allIngress.Items[i]; ingress.Spec.IngressClassName == nil {
+			addIngress(ingress)
+		}
+	}
+
+	var allTls apiv2.ApisixTlsList
+	if err := d.client.List(ctx, &allTls); err != nil {
+		return nil, err
+	}
+	for i := range allTls.Items {
+		if tls := &allTls.Items[i]; tls.Spec.IngressClassName == "" {
+			addTLS(tls)
+		}
+	}
+
+	return mappings, nil
+}
+
+// objectKey returns the namespace/name of obj, or the zero NamespacedName for
+// a nil obj.
+func objectKey(obj client.Object) types.NamespacedName {
+	if obj == nil {
+		return types.NamespacedName{}
+	}
+	return types.NamespacedName{Namespace: obj.GetNamespace(), Name: obj.GetName()}
+}
diff --git a/internal/webhook/v1/ssl/conflict_detector_test.go b/internal/webhook/v1/ssl/conflict_detector_test.go
new file mode 100644
index 00000000..d29df23f
--- /dev/null
+++ b/internal/webhook/v1/ssl/conflict_detector_test.go
@@ -0,0 +1,301 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssl + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "testing" + "time" + + corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + gatewayv1 "sigs.k8s.io/gateway-api/apis/v1" + + v1alpha1 "github.com/apache/apisix-ingress-controller/api/v1alpha1" + apiv2 "github.com/apache/apisix-ingress-controller/api/v2" + "github.com/apache/apisix-ingress-controller/internal/controller/config" + "github.com/apache/apisix-ingress-controller/internal/controller/indexer" + internaltypes "github.com/apache/apisix-ingress-controller/internal/types" +) + +const ( + testNamespace = "default" + testIngressClass = "example-class" +) + +func TestConflictDetectorDetectsGatewayConflict(t *testing.T) { + scheme := buildScheme(t) + secretA := newTLSSecret(t, "cert-a", []string{"example.com"}) + secretB := newTLSSecret(t, "cert-b", []string{"example.com"}) + + gatewayProxy := &v1alpha1.GatewayProxy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "demo-gp", + Namespace: testNamespace, + }, + } + + modeTerminate := gatewayv1.TLSModeTerminate + hostname := gatewayv1.Hostname("example.com") + gateway := &gatewayv1.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "demo-gateway", + Namespace: testNamespace, + }, + Spec: gatewayv1.GatewaySpec{ + GatewayClassName: gatewayv1.ObjectName("demo-gc"), + Listeners: []gatewayv1.Listener{ 
+ { + Name: "tls", + Protocol: gatewayv1.HTTPSProtocolType, + Port: 443, + Hostname: &hostname, + TLS: &gatewayv1.GatewayTLSConfig{ + Mode: &modeTerminate, + CertificateRefs: []gatewayv1.SecretObjectReference{ + {Name: gatewayv1.ObjectName(secretA.Name)}, + }, + }, + }, + }, + }, + } + gateway.Spec.Infrastructure = &gatewayv1.GatewayInfrastructure{ + ParametersRef: &gatewayv1.LocalParametersReference{ + Group: gatewayv1.Group(v1alpha1.GroupVersion.Group), + Kind: gatewayv1.Kind(internaltypes.KindGatewayProxy), + Name: gatewayProxy.Name, + }, + } + + ingressClass := &networkingv1.IngressClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: testIngressClass, + }, + Spec: networkingv1.IngressClassSpec{ + Controller: config.ControllerConfig.ControllerName, + Parameters: &networkingv1.IngressClassParametersReference{ + APIGroup: ptr.To(v1alpha1.GroupVersion.Group), + Kind: internaltypes.KindGatewayProxy, + Name: gatewayProxy.Name, + Namespace: func() *string { + ns := testNamespace + return &ns + }(), + }, + }, + } + + client := fake.NewClientBuilder(). + WithScheme(scheme). + WithIndex(&gatewayv1.Gateway{}, indexer.ParametersRef, indexer.GatewayParametersRefIndexFunc). + WithIndex(&networkingv1.IngressClass{}, indexer.IngressClassParametersRef, indexer.IngressClassParametersRefIndexFunc). + WithIndex(&networkingv1.Ingress{}, indexer.IngressClassRef, indexer.IngressClassRefIndexFunc). + WithIndex(&apiv2.ApisixTls{}, indexer.IngressClassRef, indexer.ApisixTlsIngressClassIndexFunc). + WithObjects(secretA, secretB, gatewayProxy, gateway, ingressClass). 
+ Build() + + detector := NewConflictDetector(client) + ctx := context.Background() + + newTls := &apiv2.ApisixTls{ + ObjectMeta: metav1.ObjectMeta{ + Name: "incoming", + Namespace: testNamespace, + }, + Spec: apiv2.ApisixTlsSpec{ + IngressClassName: testIngressClass, + Hosts: []apiv2.HostType{"example.com"}, + Secret: apiv2.ApisixSecret{ + Name: secretB.Name, + Namespace: secretB.Namespace, + }, + }, + } + + mappings, warnings := detector.BuildApisixTlsMappings(ctx, newTls) + if len(warnings) != 0 { + t.Fatalf("expected no build warnings, got %v", warnings) + } + conflicts, err := detector.DetectConflicts(ctx, newTls, mappings) + if err != nil { + t.Fatalf("DetectConflicts returned error: %v", err) + } + if len(conflicts) != 1 { + t.Fatalf("expected 1 conflict, got %d", len(conflicts)) + } + conflict := conflicts[0] + if conflict.Host != "example.com" { + t.Fatalf("unexpected host: %s", conflict.Host) + } + expectedRef := fmt.Sprintf("Gateway/%s/%s", gateway.Namespace, gateway.Name) + if conflict.ConflictingResource != expectedRef { + t.Fatalf("unexpected conflicting resource: %s", conflict.ConflictingResource) + } +} + +func TestConflictDetectorAllowedWhenCertificateMatches(t *testing.T) { + scheme := buildScheme(t) + secret := newTLSSecret(t, "shared-cert", []string{"shared.example.com"}) + + gatewayProxy := &v1alpha1.GatewayProxy{ObjectMeta: metav1.ObjectMeta{Name: "gp", Namespace: testNamespace}} + modeTerminate := gatewayv1.TLSModeTerminate + listenerHostname := gatewayv1.Hostname("shared.example.com") + gateway := &gatewayv1.Gateway{ + ObjectMeta: metav1.ObjectMeta{Name: "gw", Namespace: testNamespace}, + Spec: gatewayv1.GatewaySpec{ + GatewayClassName: gatewayv1.ObjectName("gc"), + Listeners: []gatewayv1.Listener{ + { + Name: "tls", + Protocol: gatewayv1.HTTPSProtocolType, + Port: 443, + Hostname: &listenerHostname, + TLS: &gatewayv1.GatewayTLSConfig{ + Mode: &modeTerminate, + CertificateRefs: []gatewayv1.SecretObjectReference{{Name: 
gatewayv1.ObjectName(secret.Name)}}, + }, + }, + }, + }, + } + gateway.Spec.Infrastructure = &gatewayv1.GatewayInfrastructure{ + ParametersRef: &gatewayv1.LocalParametersReference{ + Group: gatewayv1.Group(v1alpha1.GroupVersion.Group), + Kind: gatewayv1.Kind(internaltypes.KindGatewayProxy), + Name: gatewayProxy.Name, + }, + } + + ingressClass := &networkingv1.IngressClass{ + ObjectMeta: metav1.ObjectMeta{Name: testIngressClass}, + Spec: networkingv1.IngressClassSpec{ + Controller: config.ControllerConfig.ControllerName, + Parameters: &networkingv1.IngressClassParametersReference{ + APIGroup: ptr.To(v1alpha1.GroupVersion.Group), + Kind: internaltypes.KindGatewayProxy, + Name: gatewayProxy.Name, + Namespace: func() *string { + ns := testNamespace + return &ns + }(), + }, + }, + } + + client := fake.NewClientBuilder(). + WithScheme(scheme). + WithIndex(&gatewayv1.Gateway{}, indexer.ParametersRef, indexer.GatewayParametersRefIndexFunc). + WithIndex(&networkingv1.IngressClass{}, indexer.IngressClassParametersRef, indexer.IngressClassParametersRefIndexFunc). + WithIndex(&networkingv1.Ingress{}, indexer.IngressClassRef, indexer.IngressClassRefIndexFunc). + WithIndex(&apiv2.ApisixTls{}, indexer.IngressClassRef, indexer.ApisixTlsIngressClassIndexFunc). + WithObjects(secret, gatewayProxy, gateway, ingressClass). 
+ Build() + + detector := NewConflictDetector(client) + ctx := context.Background() + + newTls := &apiv2.ApisixTls{ + ObjectMeta: metav1.ObjectMeta{Name: "allowed", Namespace: testNamespace}, + Spec: apiv2.ApisixTlsSpec{ + IngressClassName: testIngressClass, + Hosts: []apiv2.HostType{"shared.example.com"}, + Secret: apiv2.ApisixSecret{Name: secret.Name, Namespace: secret.Namespace}, + }, + } + + mappings, _ := detector.BuildApisixTlsMappings(ctx, newTls) + conflicts, err := detector.DetectConflicts(ctx, newTls, mappings) + if err != nil { + t.Fatalf("DetectConflicts returned error: %v", err) + } + if len(conflicts) != 0 { + t.Fatalf("expected no conflicts, got %v", conflicts) + } +} + +func buildScheme(t *testing.T) *runtime.Scheme { + scheme := runtime.NewScheme() + for _, add := range []func(*runtime.Scheme) error{ + corev1.AddToScheme, + networkingv1.AddToScheme, + gatewayv1.Install, + apiv2.AddToScheme, + v1alpha1.AddToScheme, + } { + if err := add(scheme); err != nil { + t.Fatalf("failed to add to scheme: %v", err) + } + } + return scheme +} + +func newTLSSecret(t *testing.T, name string, hosts []string) *corev1.Secret { + cert, key := generateCertificate(t, hosts) + return &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: testNamespace, + }, + Type: corev1.SecretTypeTLS, + Data: map[string][]byte{ + corev1.TLSCertKey: cert, + corev1.TLSPrivateKeyKey: key, + }, + } +} + +func generateCertificate(t *testing.T, hosts []string) ([]byte, []byte) { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate private key: %v", err) + } + serial, err := rand.Int(rand.Reader, big.NewInt(1<<62)) + if err != nil { + t.Fatalf("failed to generate serial: %v", err) + } + template := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: hosts[0], + }, + DNSNames: hosts, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: 
x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + t.Fatalf("failed to create certificate: %v", err) + } + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + return certPEM, keyPEM +}
