This is an automated email from the ASF dual-hosted git repository.

vishesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/cloudstack-kubernetes-provider.git


The following commit(s) were added to refs/heads/main by this push:
     new abddd58a Add support to update the loadbalancer rule when source cidr list is updated (#86)
abddd58a is described below

commit abddd58a8909e7d6c24334e225a5ffe3163584a3
Author: Vishesh <[email protected]>
AuthorDate: Thu Dec 11 17:44:57 2025 +0530

    Add support to update the loadbalancer rule when source cidr list is updated (#86)
---
 cloudstack.go                   |  27 ++++
 cloudstack_loadbalancer.go      |  97 ++++++++++----
 cloudstack_loadbalancer_test.go | 286 ++++++++++++++++++++++++++++++++++++++++
 cloudstack_test.go              | 156 ++++++++++++++++++++++
 go.mod                          |   4 +-
 5 files changed, 541 insertions(+), 29 deletions(-)

diff --git a/cloudstack.go b/cloudstack.go
index 5ea15071..242ff7dc 100644
--- a/cloudstack.go
+++ b/cloudstack.go
@@ -25,8 +25,10 @@ import (
        "fmt"
        "io"
        "os"
+       "strings"
 
        "github.com/apache/cloudstack-go/v2/cloudstack"
+       "github.com/blang/semver/v4"
        "gopkg.in/gcfg.v1"
        metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
        "k8s.io/apimachinery/pkg/types"
@@ -54,6 +56,7 @@ type CSCloud struct {
        client        *cloudstack.CloudStackClient
        projectID     string // If non-"", all resources will be created within 
this project
        zone          string
+       version       semver.Version
        clientBuilder cloudprovider.ControllerClientBuilder
 }
 
@@ -87,6 +90,7 @@ func newCSCloud(cfg *CSConfig) (*CSCloud, error) {
        cs := &CSCloud{
                projectID: cfg.Global.ProjectID,
                zone:      cfg.Global.Zone,
+               version:   semver.Version{},
        }
 
        if cfg.Global.APIURL != "" && cfg.Global.APIKey != "" && 
cfg.Global.SecretKey != "" {
@@ -97,9 +101,32 @@ func newCSCloud(cfg *CSConfig) (*CSCloud, error) {
                return nil, errors.New("no cloud provider config given")
        }
 
+       version, err := cs.getManagementServerVersion()
+       if err != nil {
+               return nil, err
+       }
+       cs.version = version
+
        return cs, nil
 }
 
+func (cs *CSCloud) getManagementServerVersion() (semver.Version, error) {
+       msServersResp, err := 
cs.client.Management.ListManagementServersMetrics(cs.client.Management.NewListManagementServersMetricsParams())
+       if err != nil {
+               return semver.Version{}, err
+       }
+       if msServersResp.Count == 0 {
+               return semver.Version{}, errors.New("no management servers 
found")
+       }
+       version := msServersResp.ManagementServersMetrics[0].Version
+       v, err := semver.ParseTolerant(strings.Join(strings.Split(version, 
".")[0:3], "."))
+       if err != nil {
+               klog.Errorf("failed to parse management server version: %v", 
err)
+               return semver.Version{}, err
+       }
+       return v, nil
+}
+
 // Initialize passes a Kubernetes clientBuilder interface to the cloud provider
 func (cs *CSCloud) Initialize(clientBuilder 
cloudprovider.ControllerClientBuilder, stop <-chan struct{}) {
        cs.clientBuilder = clientBuilder
diff --git a/cloudstack_loadbalancer.go b/cloudstack_loadbalancer.go
index 9d51cdde..98de2fbd 100644
--- a/cloudstack_loadbalancer.go
+++ b/cloudstack_loadbalancer.go
@@ -27,6 +27,7 @@ import (
        "strings"
 
        "github.com/apache/cloudstack-go/v2/cloudstack"
+       "github.com/blang/semver/v4"
        "k8s.io/klog/v2"
 
        corev1 "k8s.io/api/core/v1"
@@ -44,7 +45,12 @@ const (
        // CloudStack >= 4.6 is required for it to work.
        ServiceAnnotationLoadBalancerProxyProtocol        = 
"service.beta.kubernetes.io/cloudstack-load-balancer-proxy-protocol"
        ServiceAnnotationLoadBalancerLoadbalancerHostname = 
"service.beta.kubernetes.io/cloudstack-load-balancer-hostname"
-       ServiceAnnotationLoadBalancerSourceCidrs          = 
"service.beta.kubernetes.io/cloudstack-load-balancer-source-cidrs"
+
+       // ServiceAnnotationLoadBalancerSourceCidrs is the annotation used on 
the
+       // service to specify the source CIDR list for a CloudStack load 
balancer.
+       // The CIDR list is a comma-separated list of CIDR ranges (e.g., 
"10.0.0.0/8,192.168.1.0/24").
+       // If not specified, the default is to allow all sources ("0.0.0.0/0").
+       ServiceAnnotationLoadBalancerSourceCidrs = 
"service.beta.kubernetes.io/cloudstack-load-balancer-source-cidrs"
 )
 
 type loadBalancer struct {
@@ -143,7 +149,7 @@ func (cs *CSCloud) EnsureLoadBalancer(ctx context.Context, 
clusterName string, s
                lbRuleName := fmt.Sprintf("%s-%s-%d", lb.name, protocol, 
port.Port)
 
                // If the load balancer rule exists and is up-to-date, we move 
on to the next rule.
-               lbRule, needsUpdate, err := 
lb.checkLoadBalancerRule(lbRuleName, port, protocol)
+               lbRule, needsUpdate, err := 
lb.checkLoadBalancerRule(lbRuleName, port, protocol, service, cs.version)
                if err != nil {
                        return nil, err
                }
@@ -151,7 +157,7 @@ func (cs *CSCloud) EnsureLoadBalancer(ctx context.Context, 
clusterName string, s
                if lbRule != nil {
                        if needsUpdate {
                                klog.V(4).Infof("Updating load balancer rule: 
%v", lbRuleName)
-                               if err := lb.updateLoadBalancerRule(lbRuleName, 
protocol); err != nil {
+                               if err := lb.updateLoadBalancerRule(lbRuleName, 
protocol, service, cs.version); err != nil {
                                        return nil, err
                                }
                                // Delete the rule from the map, to prevent it 
being deleted.
@@ -561,37 +567,84 @@ func (lb *loadBalancer) releaseLoadBalancerIP() error {
        return nil
 }
 
+func (lb *loadBalancer) getCIDRList(service *corev1.Service) ([]string, error) 
{
+       sourceCIDRs := getStringFromServiceAnnotation(service, 
ServiceAnnotationLoadBalancerSourceCidrs, defaultAllowedCIDR)
+       var cidrList []string
+       if sourceCIDRs != "" {
+               cidrList = strings.Split(sourceCIDRs, ",")
+               for i, cidr := range cidrList {
+                       cidr = strings.TrimSpace(cidr)
+                       if _, _, err := net.ParseCIDR(cidr); err != nil {
+                               return nil, fmt.Errorf("invalid CIDR %s in 
annotation %s: %w", cidr, ServiceAnnotationLoadBalancerSourceCidrs, err)
+                       }
+                       cidrList[i] = cidr
+               }
+       }
+       return cidrList, nil
+}
+
 // checkLoadBalancerRule checks if the rule already exists and if it does, if 
it can be updated. If
 // it does exist but cannot be updated, it will delete the existing rule so it 
can be created again.
-func (lb *loadBalancer) checkLoadBalancerRule(lbRuleName string, port 
corev1.ServicePort, protocol LoadBalancerProtocol) 
(*cloudstack.LoadBalancerRule, bool, error) {
+func (lb *loadBalancer) checkLoadBalancerRule(lbRuleName string, port 
corev1.ServicePort, protocol LoadBalancerProtocol, service *corev1.Service, 
version semver.Version) (*cloudstack.LoadBalancerRule, bool, error) {
        lbRule, ok := lb.rules[lbRuleName]
        if !ok {
                return nil, false, nil
        }
 
-       // Check if any of the values we cannot update (those that require a 
new load balancer rule) are changed.
-       if lbRule.Publicip == lb.ipAddr && lbRule.Privateport == 
strconv.Itoa(int(port.NodePort)) && lbRule.Publicport == 
strconv.Itoa(int(port.Port)) {
-               updateAlgo := lbRule.Algorithm != lb.algorithm
-               updateProto := lbRule.Protocol != protocol.CSProtocol()
-               return lbRule, updateAlgo || updateProto, nil
+       cidrList, err := lb.getCIDRList(service)
+       if err != nil {
+               return nil, false, err
        }
 
-       // Delete the load balancer rule so we can create a new one using the 
new values.
-       if err := lb.deleteLoadBalancerRule(lbRule); err != nil {
-               return nil, false, err
+       var lbRuleCidrList []string
+       if lbRule.Cidrlist != "" {
+               lbRuleCidrList = strings.Split(lbRule.Cidrlist, " ")
+               for i, cidr := range lbRuleCidrList {
+                       cidr = strings.TrimSpace(cidr)
+                       lbRuleCidrList[i] = cidr
+               }
        }
 
-       return nil, false, nil
+       // Check if basic properties match (IP and ports). If not, we need to 
recreate the rule.
+       basicPropsMatch := lbRule.Publicip == lb.ipAddr &&
+               lbRule.Privateport == strconv.Itoa(int(port.NodePort)) &&
+               lbRule.Publicport == strconv.Itoa(int(port.Port))
+
+       cidrListChanged := len(cidrList) != len(lbRuleCidrList) || 
!compareStringSlice(cidrList, lbRuleCidrList)
+
+       // Check if CIDR list also changed and version < 4.22, then we must 
recreate the rule.
+       if !basicPropsMatch || (cidrListChanged && 
version.LT(semver.Version{Major: 4, Minor: 22, Patch: 0})) {
+               // Delete the load balancer rule so we can create a new one 
using the new values.
+               if err := lb.deleteLoadBalancerRule(lbRule); err != nil {
+                       return nil, false, err
+               }
+               return nil, false, nil
+       }
+
+       // Rule can be updated. Check what needs updating.
+       updateAlgo := lbRule.Algorithm != lb.algorithm
+       updateProto := lbRule.Protocol != protocol.CSProtocol()
+
+       return lbRule, updateAlgo || updateProto || cidrListChanged, nil
 }
 
 // updateLoadBalancerRule updates a load balancer rule.
-func (lb *loadBalancer) updateLoadBalancerRule(lbRuleName string, protocol 
LoadBalancerProtocol) error {
+func (lb *loadBalancer) updateLoadBalancerRule(lbRuleName string, protocol 
LoadBalancerProtocol, service *corev1.Service, version semver.Version) error {
        lbRule := lb.rules[lbRuleName]
 
        p := lb.LoadBalancer.NewUpdateLoadBalancerRuleParams(lbRule.Id)
        p.SetAlgorithm(lb.algorithm)
        p.SetProtocol(protocol.CSProtocol())
 
+       // If version >= 4.22, we can update the CIDR list.
+       if version.GTE(semver.Version{Major: 4, Minor: 22, Patch: 0}) {
+               cidrList, err := lb.getCIDRList(service)
+               if err != nil {
+                       return err
+               }
+               p.SetCidrlist(cidrList)
+       }
+
        _, err := lb.LoadBalancer.UpdateLoadBalancerRule(p)
        return err
 }
@@ -613,19 +666,9 @@ func (lb *loadBalancer) createLoadBalancerRule(lbRuleName 
string, port corev1.Se
        p.SetOpenfirewall(false)
 
        // Read the source CIDR annotation
-       sourceCIDRs, ok := 
service.Annotations[ServiceAnnotationLoadBalancerSourceCidrs]
-       var cidrList []string
-       if ok && sourceCIDRs != "" {
-               cidrList = strings.Split(sourceCIDRs, ",")
-               for i, cidr := range cidrList {
-                       cidr = strings.TrimSpace(cidr)
-                       if _, _, err := net.ParseCIDR(cidr); err != nil {
-                               return nil, fmt.Errorf("invalid CIDR in 
annotation %s: %s", ServiceAnnotationLoadBalancerSourceCidrs, cidr)
-                       }
-                       cidrList[i] = cidr
-               }
-       } else {
-               cidrList = []string{defaultAllowedCIDR}
+       cidrList, err := lb.getCIDRList(service)
+       if err != nil {
+               return nil, err
        }
 
        // Set the CIDR list in the parameters
diff --git a/cloudstack_loadbalancer_test.go b/cloudstack_loadbalancer_test.go
index bbd63066..847361a0 100644
--- a/cloudstack_loadbalancer_test.go
+++ b/cloudstack_loadbalancer_test.go
@@ -20,10 +20,14 @@
 package cloudstack
 
 import (
+       "reflect"
        "sort"
+       "strings"
        "testing"
 
        "github.com/apache/cloudstack-go/v2/cloudstack"
+       "github.com/blang/semver/v4"
+       "go.uber.org/mock/gomock"
        corev1 "k8s.io/api/core/v1"
        metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 )
@@ -478,3 +482,285 @@ func TestGetBoolFromServiceAnnotation(t *testing.T) {
                })
        }
 }
+
+func TestGetCIDRList(t *testing.T) {
+       tests := []struct {
+               name        string
+               annotations map[string]string
+               want        []string
+               wantErr     bool
+               errContains string
+               expectEmpty bool
+       }{
+               {
+                       name:        "defaults to allow all when annotation 
missing",
+                       annotations: nil,
+                       want:        []string{defaultAllowedCIDR},
+               },
+               {
+                       name: "trims and splits cidrs",
+                       annotations: map[string]string{
+                               ServiceAnnotationLoadBalancerSourceCidrs: 
"10.0.0.0/8, 192.168.0.0/16",
+                       },
+                       want: []string{"10.0.0.0/8", "192.168.0.0/16"},
+               },
+               {
+                       name: "empty annotation returns empty list",
+                       annotations: map[string]string{
+                               ServiceAnnotationLoadBalancerSourceCidrs: "",
+                       },
+                       expectEmpty: true,
+               },
+               {
+                       name: "invalid cidr returns error",
+                       annotations: map[string]string{
+                               ServiceAnnotationLoadBalancerSourceCidrs: 
"invalid-cidr",
+                       },
+                       wantErr:     true,
+                       errContains: "invalid CIDR",
+               },
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       lb := &loadBalancer{}
+                       svc := &corev1.Service{
+                               ObjectMeta: metav1.ObjectMeta{
+                                       Name:        "svc",
+                                       Namespace:   "default",
+                                       Annotations: tt.annotations,
+                               },
+                       }
+
+                       got, err := lb.getCIDRList(svc)
+                       if tt.wantErr {
+                               if err == nil {
+                                       t.Fatalf("expected error, got nil")
+                               }
+                               if tt.errContains != "" && 
!strings.Contains(err.Error(), tt.errContains) {
+                                       t.Fatalf("error = %v, expected to 
contain %q", err, tt.errContains)
+                               }
+                               return
+                       }
+
+                       if err != nil {
+                               t.Fatalf("unexpected error: %v", err)
+                       }
+
+                       if tt.expectEmpty {
+                               if len(got) != 0 {
+                                       t.Fatalf("expected empty CIDR list, got 
%v", got)
+                               }
+                               return
+                       }
+
+                       if !reflect.DeepEqual(got, tt.want) {
+                               t.Fatalf("getCIDRList() = %v, want %v", got, 
tt.want)
+                       }
+               })
+       }
+}
+
+// TestCheckLoadBalancerRule covers the outcomes of checkLoadBalancerRule: a missing
+// rule, a rule whose public IP differs (deleted for recreation), a CIDR change on
+// version 4.22 (rule kept and flagged for update), a CIDR change on an older
+// version (deleted for recreation), and an invalid CIDR annotation (error).
+func TestCheckLoadBalancerRule(t *testing.T) {
+       t.Run("rule not present returns nil", func(t *testing.T) {
+               lb := &loadBalancer{
+                       rules: map[string]*cloudstack.LoadBalancerRule{},
+               }
+               port := corev1.ServicePort{Port: 80, NodePort: 30000, Protocol: corev1.ProtocolTCP}
+               service := &corev1.Service{}
+
+               rule, needsUpdate, err := lb.checkLoadBalancerRule("missing", port, LoadBalancerProtocolTCP, service, semver.Version{})
+               if err != nil {
+                       t.Fatalf("unexpected error: %v", err)
+               }
+               if rule != nil {
+                       t.Fatalf("expected nil rule, got %v", rule)
+               }
+               if needsUpdate {
+                       t.Fatalf("expected needsUpdate to be false")
+               }
+       })
+
+       t.Run("basic property mismatch deletes rule", func(t *testing.T) {
+               ctrl := gomock.NewController(t)
+               t.Cleanup(ctrl.Finish)
+
+               mockLB := cloudstack.NewMockLoadBalancerServiceIface(ctrl)
+               deleteParams := &cloudstack.DeleteLoadBalancerRuleParams{}
+
+               // The rule's public IP (2.2.2.2) differs from lb.ipAddr (1.1.1.1), so exactly
+               // one delete of the existing rule is expected, in this order.
+               gomock.InOrder(
+                       mockLB.EXPECT().NewDeleteLoadBalancerRuleParams("rule-id").Return(deleteParams),
+                       mockLB.EXPECT().DeleteLoadBalancerRule(deleteParams).Return(&cloudstack.DeleteLoadBalancerRuleResponse{}, nil),
+               )
+
+               lb := &loadBalancer{
+                       CloudStackClient: &cloudstack.CloudStackClient{
+                               LoadBalancer: mockLB,
+                       },
+                       ipAddr: "1.1.1.1",
+                       rules: map[string]*cloudstack.LoadBalancerRule{
+                               "rule": {
+                                       Id:          "rule-id",
+                                       Name:        "rule",
+                                       Publicip:    "2.2.2.2",
+                                       Privateport: "30000",
+                                       Publicport:  "80",
+                                       Cidrlist:    defaultAllowedCIDR,
+                                       Algorithm:   "roundrobin",
+                                       Protocol:    LoadBalancerProtocolTCP.CSProtocol(),
+                               },
+                       },
+               }
+               port := corev1.ServicePort{Port: 80, NodePort: 30000, Protocol: corev1.ProtocolTCP}
+               service := &corev1.Service{}
+
+               rule, needsUpdate, err := lb.checkLoadBalancerRule("rule", port, LoadBalancerProtocolTCP, service, semver.Version{Major: 4, Minor: 21, Patch: 0})
+               if err != nil {
+                       t.Fatalf("unexpected error: %v", err)
+               }
+               if rule != nil {
+                       t.Fatalf("expected nil rule after deletion, got %v", rule)
+               }
+               if needsUpdate {
+                       t.Fatalf("expected needsUpdate to be false")
+               }
+               // The deleted rule must also be dropped from the local rule map.
+               if _, exists := lb.rules["rule"]; exists {
+                       t.Fatalf("expected rule entry to be removed from map")
+               }
+       })
+
+       t.Run("cidr change triggers update on supported version", func(t *testing.T) {
+               ctrl := gomock.NewController(t)
+               t.Cleanup(ctrl.Finish)
+
+               // No expectations on the mock; any delete call would fail the test.
+               // IP and ports match and version 4.22 allows an in-place CIDR update,
+               // so the existing rule must be returned and flagged for update.
+               mockLB := cloudstack.NewMockLoadBalancerServiceIface(ctrl)
+
+               lbRule := &cloudstack.LoadBalancerRule{
+                       Id:          "rule-id",
+                       Name:        "rule",
+                       Publicip:    "1.1.1.1",
+                       Privateport: "30000",
+                       Publicport:  "80",
+                       Cidrlist:    "10.0.0.0/8",
+                       Algorithm:   "roundrobin",
+                       Protocol:    LoadBalancerProtocolTCP.CSProtocol(),
+               }
+
+               lb := &loadBalancer{
+                       CloudStackClient: &cloudstack.CloudStackClient{
+                               LoadBalancer: mockLB,
+                       },
+                       ipAddr:    "1.1.1.1",
+                       algorithm: "roundrobin",
+                       rules: map[string]*cloudstack.LoadBalancerRule{
+                               "rule": lbRule,
+                       },
+               }
+               port := corev1.ServicePort{Port: 80, NodePort: 30000, Protocol: corev1.ProtocolTCP}
+               service := &corev1.Service{
+                       ObjectMeta: metav1.ObjectMeta{
+                               Annotations: map[string]string{
+                                       ServiceAnnotationLoadBalancerSourceCidrs: "10.0.0.0/8,192.168.0.0/16",
+                               },
+                       },
+               }
+
+               rule, needsUpdate, err := lb.checkLoadBalancerRule("rule", port, LoadBalancerProtocolTCP, service, semver.Version{Major: 4, Minor: 22, Patch: 0})
+               if err != nil {
+                       t.Fatalf("unexpected error: %v", err)
+               }
+               // Pointer identity: the very rule stored in lb.rules must come back.
+               if rule != lbRule {
+                       t.Fatalf("expected existing rule to be returned")
+               }
+               if !needsUpdate {
+                       t.Fatalf("expected needsUpdate to be true due to CIDR change")
+               }
+       })
+
+       t.Run("cidr change triggers delete with older version", func(t *testing.T) {
+               ctrl := gomock.NewController(t)
+               t.Cleanup(ctrl.Finish)
+
+               // Expect exactly one delete of the existing rule (set up below); any other
+               // call on the mock, such as a create, fails the test.
+               mockLB := cloudstack.NewMockLoadBalancerServiceIface(ctrl)
+
+               deleteParams := &cloudstack.DeleteLoadBalancerRuleParams{}
+
+               gomock.InOrder(
+                       mockLB.EXPECT().NewDeleteLoadBalancerRuleParams("rule-id").Return(deleteParams),
+                       mockLB.EXPECT().DeleteLoadBalancerRule(deleteParams).Return(&cloudstack.DeleteLoadBalancerRuleResponse{}, nil),
+               )
+
+               lbRule := &cloudstack.LoadBalancerRule{
+                       Id:          "rule-id",
+                       Name:        "rule",
+                       Publicip:    "1.1.1.1",
+                       Privateport: "30000",
+                       Publicport:  "80",
+                       Cidrlist:    "10.0.0.0/8",
+                       Algorithm:   "roundrobin",
+                       Protocol:    LoadBalancerProtocolTCP.CSProtocol(),
+               }
+
+               lb := &loadBalancer{
+                       CloudStackClient: &cloudstack.CloudStackClient{
+                               LoadBalancer: mockLB,
+                       },
+                       ipAddr:    "1.1.1.1",
+                       algorithm: "roundrobin",
+                       rules: map[string]*cloudstack.LoadBalancerRule{
+                               "rule": lbRule,
+                       },
+               }
+               port := corev1.ServicePort{Port: 80, NodePort: 30000, Protocol: corev1.ProtocolTCP}
+               service := &corev1.Service{
+                       ObjectMeta: metav1.ObjectMeta{
+                               Annotations: map[string]string{
+                                       ServiceAnnotationLoadBalancerSourceCidrs: "10.0.0.0/8,192.168.0.0/16",
+                               },
+                       },
+               }
+
+               // 4.12 is below 4.22, so the CIDR change must be handled by recreating the rule.
+               rule, needsUpdate, err := lb.checkLoadBalancerRule("rule", port, LoadBalancerProtocolTCP, service, semver.Version{Major: 4, Minor: 12, Patch: 0})
+               if err != nil {
+                       t.Fatalf("unexpected error: %v", err)
+               }
+               if rule != nil {
+                       t.Fatalf("expected nil rule after deletion, got %v", rule)
+               }
+               if needsUpdate {
+                       t.Fatalf("expected needsUpdate to be false due to CIDR change with older version")
+               }
+       })
+
+       t.Run("invalid cidr returns error", func(t *testing.T) {
+               // lb has no CloudStack client, so this also checks that the annotation
+               // error is returned before any API call is attempted.
+               lb := &loadBalancer{
+                       rules: map[string]*cloudstack.LoadBalancerRule{
+                               "rule": {
+                                       Id:          "rule-id",
+                                       Name:        "rule",
+                                       Publicip:    "1.1.1.1",
+                                       Privateport: "30000",
+                                       Publicport:  "80",
+                                       Cidrlist:    defaultAllowedCIDR,
+                                       Algorithm:   "roundrobin",
+                                       Protocol:    LoadBalancerProtocolTCP.CSProtocol(),
+                               },
+                       },
+               }
+               port := corev1.ServicePort{Port: 80, NodePort: 30000, Protocol: corev1.ProtocolTCP}
+               service := &corev1.Service{
+                       ObjectMeta: metav1.ObjectMeta{
+                               Annotations: map[string]string{
+                                       ServiceAnnotationLoadBalancerSourceCidrs: "bad-cidr",
+                               },
+                       },
+               }
+
+               _, _, err := lb.checkLoadBalancerRule("rule", port, LoadBalancerProtocolTCP, service, semver.Version{Major: 4, Minor: 22, Patch: 0})
+               if err == nil {
+                       t.Fatalf("expected error for invalid CIDR")
+               }
+       })
+}
diff --git a/cloudstack_test.go b/cloudstack_test.go
index 48bb8807..a83b45b8 100644
--- a/cloudstack_test.go
+++ b/cloudstack_test.go
@@ -21,11 +21,15 @@ package cloudstack
 
 import (
        "context"
+       "errors"
        "os"
        "strconv"
        "strings"
        "testing"
 
+       "github.com/apache/cloudstack-go/v2/cloudstack"
+       "github.com/blang/semver/v4"
+       "go.uber.org/mock/gomock"
        corev1 "k8s.io/api/core/v1"
        metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 )
@@ -119,3 +123,155 @@ func TestLoadBalancer(t *testing.T) {
                t.Fatalf("GetLoadBalancer(\"noexist\") returned exists")
        }
 }
+
+// TestGetManagementServerVersion checks that getManagementServerVersion parses
+// release and development version strings, and reports an error when the API
+// call fails, no management server is listed, or the version is unparsable.
+func TestGetManagementServerVersion(t *testing.T) {
+       // newCloud returns a CSCloud whose Management mock expects exactly one
+       // NewListManagementServersMetricsParams call followed by one
+       // ListManagementServersMetrics call, which returns resp and apiErr.
+       newCloud := func(t *testing.T, resp *cloudstack.ListManagementServersMetricsResponse, apiErr error) *CSCloud {
+               t.Helper()
+               ctrl := gomock.NewController(t)
+               t.Cleanup(ctrl.Finish)
+
+               mockMgmt := cloudstack.NewMockManagementServiceIface(ctrl)
+               params := &cloudstack.ListManagementServersMetricsParams{}
+               gomock.InOrder(
+                       mockMgmt.EXPECT().NewListManagementServersMetricsParams().Return(params),
+                       mockMgmt.EXPECT().ListManagementServersMetrics(params).Return(resp, apiErr),
+               )
+
+               return &CSCloud{
+                       client: &cloudstack.CloudStackClient{
+                               Management: mockMgmt,
+                       },
+               }
+       }
+
+       // withVersions builds a response listing one management server per version string.
+       withVersions := func(versions ...string) *cloudstack.ListManagementServersMetricsResponse {
+               metrics := make([]*cloudstack.ManagementServersMetric, 0, len(versions))
+               for _, v := range versions {
+                       metrics = append(metrics, &cloudstack.ManagementServersMetric{Version: v})
+               }
+               return &cloudstack.ListManagementServersMetricsResponse{
+                       Count:                    len(metrics),
+                       ManagementServersMetrics: metrics,
+               }
+       }
+
+       cases := []struct {
+               name       string
+               resp       *cloudstack.ListManagementServersMetricsResponse
+               apiErr     error
+               want       string // expected version; empty means an error is expected
+               missingErr string // failure message used when the expected error is absent
+       }{
+               {name: "returns parsed version", resp: withVersions("4.17.1.0"), want: "4.17.1"},
+               {name: "returns correct parsed version with development server", resp: withVersions("4.17.1.0-SNAPSHOT"), want: "4.17.1"},
+               {name: "returns error when api call fails", apiErr: errors.New("api failure"), missingErr: "expected error, got nil"},
+               {name: "returns error when no servers found", resp: withVersions(), missingErr: "expected error for zero management servers"},
+               {name: "returns error when version cannot be parsed", resp: withVersions("invalid.version.string"), missingErr: "expected parse error"},
+       }
+
+       for _, tc := range cases {
+               t.Run(tc.name, func(t *testing.T) {
+                       cs := newCloud(t, tc.resp, tc.apiErr)
+                       version, err := cs.getManagementServerVersion()
+
+                       if tc.want == "" {
+                               if err == nil {
+                                       t.Fatal(tc.missingErr)
+                               }
+                               return
+                       }
+
+                       if err != nil {
+                               t.Fatalf("unexpected error: %v", err)
+                       }
+                       if expected := semver.MustParse(tc.want); !version.Equals(expected) {
+                               t.Fatalf("version = %v, want %v", version, expected)
+                       }
+               })
+       }
+}
diff --git a/go.mod b/go.mod
index 1fe612dd..24e177ad 100644
--- a/go.mod
+++ b/go.mod
@@ -4,7 +4,9 @@ go 1.23.0
 
 require (
        github.com/apache/cloudstack-go/v2 v2.19.0
+       github.com/blang/semver/v4 v4.0.0
        github.com/spf13/pflag v1.0.5
+       go.uber.org/mock v0.5.0
        gopkg.in/gcfg.v1 v1.2.3
        k8s.io/api v0.24.17
        k8s.io/apimachinery v0.24.17
@@ -17,7 +19,6 @@ require (
        github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // 
indirect
        github.com/NYTimes/gziphandler v1.1.1 // indirect
        github.com/beorn7/perks v1.0.1 // indirect
-       github.com/blang/semver/v4 v4.0.0 // indirect
        github.com/cespare/xxhash/v2 v2.2.0 // indirect
        github.com/coreos/go-semver v0.3.0 // indirect
        github.com/coreos/go-systemd/v22 v22.3.2 // indirect
@@ -71,7 +72,6 @@ require (
        go.opentelemetry.io/otel/trace v0.20.0 // indirect
        go.opentelemetry.io/proto/otlp v0.7.0 // indirect
        go.uber.org/atomic v1.7.0 // indirect
-       go.uber.org/mock v0.5.0 // indirect
        go.uber.org/multierr v1.6.0 // indirect
        go.uber.org/zap v1.19.0 // indirect
        golang.org/x/crypto v0.36.0 // indirect

Reply via email to