This is an automated email from the ASF dual-hosted git repository.

pbacsko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/yunikorn-core.git


The following commit(s) were added to refs/heads/master by this push:
     new ec17e0b2 [YUNIKORN-2147] Limit the number of concurrent event streams 
(#777)
ec17e0b2 is described below

commit ec17e0b22a257e6a8c420955756adfd373acab23
Author: Peter Bacsko <[email protected]>
AuthorDate: Fri Feb 2 15:10:16 2024 +0100

    [YUNIKORN-2147] Limit the number of concurrent event streams (#777)
    
    Closes: #777
    
    Signed-off-by: Peter Bacsko <[email protected]>
---
 pkg/common/configs/configs.go          |   4 +
 pkg/webservice/handlers.go             |   9 +++
 pkg/webservice/handlers_test.go        |  49 +++++++++---
 pkg/webservice/streaming_limit.go      | 137 +++++++++++++++++++++++++++++++++
 pkg/webservice/streaming_limit_test.go | 124 +++++++++++++++++++++++++++++
 5 files changed, 314 insertions(+), 9 deletions(-)

diff --git a/pkg/common/configs/configs.go b/pkg/common/configs/configs.go
index 2cef08b6..4671c896 100644
--- a/pkg/common/configs/configs.go
+++ b/pkg/common/configs/configs.go
@@ -36,12 +36,16 @@ const (
        CMEventTrackingEnabled    = PrefixEvent + "trackingEnabled"    // 
Application Tracking
        CMEventRequestCapacity    = PrefixEvent + "requestCapacity"    // 
Request Capacity
        CMEventRingBufferCapacity = PrefixEvent + "ringBufferCapacity" // Ring 
Buffer Capacity
+       CMMaxEventStreams         = PrefixEvent + "maxStreams"
+       CMMaxEventStreamsPerHost  = PrefixEvent + "maxStreamsPerHost"
 
        // defaults
        DefaultHealthCheckInterval     = 30 * time.Second
        DefaultEventTrackingEnabled    = true
        DefaultEventRequestCapacity    = 1000
        DefaultEventRingBufferCapacity = 100000
+       DefaultMaxStreams              = uint64(100)
+       DefaultMaxStreamsPerHost       = uint64(15)
 )
 
 var ConfigContext *SchedulerConfigContext
diff --git a/pkg/webservice/handlers.go b/pkg/webservice/handlers.go
index 31d1dff7..809b49a6 100644
--- a/pkg/webservice/handlers.go
+++ b/pkg/webservice/handlers.go
@@ -63,6 +63,7 @@ const (
 
 var allowedActiveStatusMsg string
 var allowedAppActiveStatuses map[string]bool
+var streamingLimiter *StreamingLimiter
 
 func init() {
        allowedAppActiveStatuses = make(map[string]bool)
@@ -80,6 +81,8 @@ func init() {
                activeStatuses = append(activeStatuses, k)
        }
        allowedActiveStatusMsg = fmt.Sprintf("Only following active statuses 
are allowed: %s", strings.Join(activeStatuses, ","))
+
+       streamingLimiter = NewStreamingLimiter()
 }
 
 func getStackInfo(w http.ResponseWriter, r *http.Request) {
@@ -1102,6 +1105,12 @@ func getStream(w http.ResponseWriter, r *http.Request) {
                return
        }
 
+       if !streamingLimiter.AddHost(r.Host) {
+               buildJSONErrorResponse(w, "Too many streaming connections", 
http.StatusServiceUnavailable)
+               return
+       }
+       defer streamingLimiter.RemoveHost(r.Host)
+
        var count uint64
        if countStr := r.URL.Query().Get("count"); countStr != "" {
                var err error
diff --git a/pkg/webservice/handlers_test.go b/pkg/webservice/handlers_test.go
index f07b97e9..c598dc4f 100644
--- a/pkg/webservice/handlers_test.go
+++ b/pkg/webservice/handlers_test.go
@@ -2095,7 +2095,7 @@ func TestGetStream_TrackingDisabled(t *testing.T) {
        _, req := initEventsAndCreateRequest(t)
        resp := httptest.NewRecorder()
 
-       assertGetStreamError(t, req, resp, "Event tracking is disabled")
+       assertGetStreamError(t, req, resp, http.StatusInternalServerError, 
"Event tracking is disabled")
 }
 
 func TestGetStream_NoWriteDeadline(t *testing.T) {
@@ -2103,7 +2103,7 @@ func TestGetStream_NoWriteDeadline(t *testing.T) {
        defer ev.Stop()
        resp := httptest.NewRecorder() // does not have SetWriteDeadline()
 
-       assertGetStreamError(t, req, resp, "Cannot set write deadline: feature 
not supported")
+       assertGetStreamError(t, req, resp, http.StatusInternalServerError, 
"Cannot set write deadline: feature not supported")
 }
 
 func TestGetStream_SetWriteDeadlineFails(t *testing.T) {
@@ -2122,7 +2122,7 @@ func TestGetStream_SetWriteDeadlineFails(t *testing.T) {
        }()
 
        getStream(resp, req)
-       checkGetStreamErrorResult(t, resp.Result(), "Cannot set write deadline: 
SetWriteDeadline failed")
+       checkGetStreamErrorResult(t, resp.Result(), 
http.StatusInternalServerError, "Cannot set write deadline: SetWriteDeadline 
failed")
 }
 
 func TestGetStream_SetReadDeadlineFails(t *testing.T) {
@@ -2130,11 +2130,42 @@ func TestGetStream_SetReadDeadlineFails(t *testing.T) {
        resp := NewResponseRecorderWithDeadline()
        resp.setReadFails = true
 
-       assertGetStreamError(t, req, resp, "Cannot set read deadline: 
SetReadDeadline failed")
+       assertGetStreamError(t, req, resp, http.StatusInternalServerError, 
"Cannot set read deadline: SetReadDeadline failed")
 }
 
-func assertGetStreamError(t *testing.T, req *http.Request, resp interface{},
-       expectedMsg string) {
+func TestGetStream_Limit(t *testing.T) {
+       current := configs.GetConfigMap()
+       defer func() {
+               configs.SetConfigMap(current)
+       }()
+       configs.SetConfigMap(map[string]string{
+               configs.CMMaxEventStreams: "3",
+       })
+       resp := NewResponseRecorderWithDeadline()
+       ev, req := initEventsAndCreateRequest(t)
+       defer ev.Stop()
+
+       cancelCtx, cancel := context.WithCancel(context.Background())
+       req = req.Clone(cancelCtx)
+       defer cancel()
+       req.Host = "host-1"
+
+       // start simulated connections in the background
+       go getStream(NewResponseRecorderWithDeadline(), req)
+       go getStream(NewResponseRecorderWithDeadline(), req)
+       go getStream(NewResponseRecorderWithDeadline(), req)
+
+       // wait until the StreamingLimiter.AddHost() calls
+       err := common.WaitFor(time.Millisecond, time.Second, func() bool {
+               streamingLimiter.Lock()
+               defer streamingLimiter.Unlock()
+               return streamingLimiter.streams == 3
+       })
+       assert.NilError(t, err)
+       assertGetStreamError(t, req, resp, http.StatusServiceUnavailable, "Too 
many streaming connections")
+}
+
+func assertGetStreamError(t *testing.T, req *http.Request, resp interface{}, 
statusCode int, expectedMsg string) {
        t.Helper()
        var response *http.Response
 
@@ -2149,17 +2180,17 @@ func assertGetStreamError(t *testing.T, req 
*http.Request, resp interface{},
                t.Fatalf("unknown response recorder type")
        }
 
-       checkGetStreamErrorResult(t, response, expectedMsg)
+       checkGetStreamErrorResult(t, response, statusCode, expectedMsg)
 }
 
-func checkGetStreamErrorResult(t *testing.T, response *http.Response, 
expectedMsg string) {
+func checkGetStreamErrorResult(t *testing.T, response *http.Response, 
statusCode int, expectedMsg string) {
        t.Helper()
        output := make([]byte, 256)
        n, err := response.Body.Read(output)
        assert.NilError(t, err)
        line := string(output[:n])
        assertYunikornError(t, line, expectedMsg)
-       assert.Equal(t, http.StatusInternalServerError, response.StatusCode)
+       assert.Equal(t, statusCode, response.StatusCode)
 }
 
 func initEventsAndCreateRequest(t *testing.T) (*events.EventSystemImpl, 
*http.Request) {
diff --git a/pkg/webservice/streaming_limit.go 
b/pkg/webservice/streaming_limit.go
new file mode 100644
index 00000000..fe2df79c
--- /dev/null
+++ b/pkg/webservice/streaming_limit.go
@@ -0,0 +1,137 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements.  See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership.  The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License.  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package webservice
+
+import (
+       "fmt"
+       "strconv"
+       "sync"
+       "sync/atomic"
+
+       "go.uber.org/zap"
+
+       "github.com/apache/yunikorn-core/pkg/common/configs"
+       "github.com/apache/yunikorn-core/pkg/log"
+)
+
// idGen hands out process-wide sequence numbers so that every
// StreamingLimiter registers its configmap callback under a distinct name.
var idGen atomic.Uint64

// StreamingLimiter keeps track of the event-stream connections that are
// currently open, both in total and per host, together with the limits that
// AddHost enforces against those counters.
//
// The counters and limits are read and written while holding the embedded
// mutex.
type StreamingLimiter struct {
	// perHostStreams maps a host to the number of streams it currently holds.
	perHostStreams map[string]uint64
	// streams is the total number of open streams across all hosts.
	streams uint64
	// id is the unique key used for the configmap callback registration.
	id string

	// maxStreams is the upper bound on the total number of event streams.
	maxStreams uint64
	// maxPerHostStreams is the upper bound on event streams for a single host.
	maxPerHostStreams uint64

	sync.Mutex
}
+
+func NewStreamingLimiter() *StreamingLimiter {
+       sl := &StreamingLimiter{
+               perHostStreams: make(map[string]uint64),
+               id:             fmt.Sprintf("stream-limiter-%d", idGen.Add(1)),
+       }
+
+       configs.AddConfigMapCallback(sl.id, func() {
+               log.Log(log.REST).Info("Reloading streaming limit settings")
+               sl.setLimits()
+       })
+       sl.setLimits()
+
+       return sl
+}
+
+func (sl *StreamingLimiter) AddHost(host string) bool {
+       sl.Lock()
+       defer sl.Unlock()
+
+       if sl.streams >= sl.maxStreams {
+               log.Log(log.SchedHealth).Info("Number of maximum stream 
connections reached",
+                       zap.Uint64("limit", sl.maxStreams),
+                       zap.String("host", host))
+               return false
+       }
+       if sl.perHostStreams[host] >= sl.maxPerHostStreams {
+               log.Log(log.SchedHealth).Info("Per host connection limit 
reached",
+                       zap.Uint64("limit", sl.maxPerHostStreams),
+                       zap.String("host", host))
+               return false
+       }
+
+       sl.streams++
+       sl.perHostStreams[host]++
+       return true
+}
+
+func (sl *StreamingLimiter) RemoveHost(host string) {
+       sl.Lock()
+       defer sl.Unlock()
+
+       count, ok := sl.perHostStreams[host]
+       if !ok {
+               log.Log(log.REST).Warn("Tried to remove a non-existing host 
from tracking",
+                       zap.String("host", host))
+               return
+       }
+
+       sl.streams--
+       if count <= 1 {
+               delete(sl.perHostStreams, host)
+               return
+       }
+       sl.perHostStreams[host]--
+}
+
+func (sl *StreamingLimiter) setLimits() {
+       sl.Lock()
+       defer sl.Unlock()
+
+       maxStreams := configs.DefaultMaxStreams
+       configMap := configs.GetConfigMap()
+
+       if value, ok := configMap[configs.CMMaxEventStreams]; ok {
+               parsed, err := strconv.ParseUint(value, 10, 64)
+               if err != nil {
+                       log.Log(log.REST).Warn("Failed to parse configuration 
value",
+                               zap.String("key", configs.CMMaxEventStreams),
+                               zap.String("value", value),
+                               zap.Error(err))
+               } else {
+                       maxStreams = parsed
+               }
+       }
+
+       maxStreamsPerHost := configs.DefaultMaxStreamsPerHost
+       if value, ok := configMap[configs.CMMaxEventStreamsPerHost]; ok {
+               parsed, err := strconv.ParseUint(value, 10, 64)
+               if err != nil {
+                       log.Log(log.REST).Warn("Failed to parse configuration 
value",
+                               zap.String("key", 
configs.CMMaxEventStreamsPerHost),
+                               zap.String("value", value),
+                               zap.Error(err))
+               } else {
+                       maxStreamsPerHost = parsed
+               }
+       }
+
+       sl.maxStreams = maxStreams
+       sl.maxPerHostStreams = maxStreamsPerHost
+}
diff --git a/pkg/webservice/streaming_limit_test.go 
b/pkg/webservice/streaming_limit_test.go
new file mode 100644
index 00000000..dc1fa6cb
--- /dev/null
+++ b/pkg/webservice/streaming_limit_test.go
@@ -0,0 +1,124 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements.  See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership.  The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License.  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package webservice
+
+import (
+       "testing"
+
+       "gotest.tools/v3/assert"
+
+       "github.com/apache/yunikorn-core/pkg/common/configs"
+)
+
+func TestAddRemoveHost(t *testing.T) {
+       sl := NewStreamingLimiter()
+       defer sl.Stop()
+       assert.Assert(t, sl.AddHost("host-1"))
+       assert.Assert(t, sl.AddHost("host-1"))
+       assert.Assert(t, sl.AddHost("host-2"))
+       assert.Equal(t, 2, len(sl.perHostStreams))
+       assert.Equal(t, uint64(3), sl.streams)
+
+       sl.RemoveHost("host-3") // remove non-existing
+       assert.Equal(t, 2, len(sl.perHostStreams))
+       assert.Equal(t, uint64(3), sl.streams)
+
+       sl.RemoveHost("host-1")
+       assert.Equal(t, 2, len(sl.perHostStreams))
+       assert.Equal(t, uint64(2), sl.streams)
+
+       sl.RemoveHost("host-2")
+       assert.Equal(t, 1, len(sl.perHostStreams))
+       assert.Equal(t, uint64(1), sl.streams)
+
+       sl.RemoveHost("host-1")
+       assert.Equal(t, 0, len(sl.perHostStreams))
+       assert.Equal(t, uint64(0), sl.streams)
+}
+
+func TestAddHost_TotalLimitHit(t *testing.T) {
+       current := configs.GetConfigMap()
+       defer func() {
+               configs.SetConfigMap(current)
+       }()
+       configs.SetConfigMap(map[string]string{
+               configs.CMMaxEventStreams: "2",
+       })
+       sl := NewStreamingLimiter()
+       sl.Stop()
+
+       assert.Assert(t, sl.AddHost("host-1"))
+       assert.Assert(t, sl.AddHost("host-2"))
+       assert.Assert(t, !sl.AddHost("host-3"))
+}
+
+func TestAddHost_PerHostLimitHit(t *testing.T) {
+       current := configs.GetConfigMap()
+       defer func() {
+               configs.SetConfigMap(current)
+       }()
+       configs.SetConfigMap(map[string]string{
+               configs.CMMaxEventStreamsPerHost: "2",
+       })
+       sl := NewStreamingLimiter()
+       defer sl.Stop()
+
+       assert.Assert(t, sl.AddHost("host-1"))
+       assert.Assert(t, sl.AddHost("host-1"))
+       assert.Assert(t, !sl.AddHost("host-1"))
+}
+
+func TestGetLimits(t *testing.T) {
+       current := configs.GetConfigMap()
+       defer func() {
+               configs.SetConfigMap(current)
+       }()
+       sl := NewStreamingLimiter()
+       defer sl.Stop()
+
+       sl.setLimits()
+       assert.Equal(t, uint64(100), sl.maxStreams)
+       assert.Equal(t, uint64(15), sl.maxPerHostStreams)
+
+       configs.SetConfigMap(map[string]string{
+               configs.CMMaxEventStreams: "123",
+       })
+       sl.setLimits()
+       assert.Equal(t, uint64(123), sl.maxStreams)
+       assert.Equal(t, uint64(15), sl.maxPerHostStreams)
+
+       configs.SetConfigMap(map[string]string{
+               configs.CMMaxEventStreamsPerHost: "321",
+       })
+       sl.setLimits()
+       assert.Equal(t, uint64(100), sl.maxStreams)
+       assert.Equal(t, uint64(321), sl.maxPerHostStreams)
+
+       configs.SetConfigMap(map[string]string{
+               configs.CMMaxEventStreams:        "xxx",
+               configs.CMMaxEventStreamsPerHost: "yyy",
+       })
+       sl.setLimits()
+       assert.Equal(t, uint64(100), sl.maxStreams)
+       assert.Equal(t, uint64(15), sl.maxPerHostStreams)
+}
+
+func (sl *StreamingLimiter) Stop() {
+       configs.RemoveConfigMapCallback(sl.id)
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to