This is an automated email from the ASF dual-hosted git repository.
pbacsko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/yunikorn-core.git
The following commit(s) were added to refs/heads/master by this push:
new ec17e0b2 [YUNIKORN-2147] Limit the number of concurrent event streams
(#777)
ec17e0b2 is described below
commit ec17e0b22a257e6a8c420955756adfd373acab23
Author: Peter Bacsko <[email protected]>
AuthorDate: Fri Feb 2 15:10:16 2024 +0100
[YUNIKORN-2147] Limit the number of concurrent event streams (#777)
Closes: #777
Signed-off-by: Peter Bacsko <[email protected]>
---
pkg/common/configs/configs.go | 4 +
pkg/webservice/handlers.go | 9 +++
pkg/webservice/handlers_test.go | 49 +++++++++---
pkg/webservice/streaming_limit.go | 137 +++++++++++++++++++++++++++++++++
pkg/webservice/streaming_limit_test.go | 124 +++++++++++++++++++++++++++++
5 files changed, 314 insertions(+), 9 deletions(-)
diff --git a/pkg/common/configs/configs.go b/pkg/common/configs/configs.go
index 2cef08b6..4671c896 100644
--- a/pkg/common/configs/configs.go
+++ b/pkg/common/configs/configs.go
@@ -36,12 +36,16 @@ const (
CMEventTrackingEnabled = PrefixEvent + "trackingEnabled" //
Application Tracking
CMEventRequestCapacity = PrefixEvent + "requestCapacity" //
Request Capacity
CMEventRingBufferCapacity = PrefixEvent + "ringBufferCapacity" // Ring
Buffer Capacity
+ CMMaxEventStreams = PrefixEvent + "maxStreams"
+ CMMaxEventStreamsPerHost = PrefixEvent + "maxStreamsPerHost"
// defaults
DefaultHealthCheckInterval = 30 * time.Second
DefaultEventTrackingEnabled = true
DefaultEventRequestCapacity = 1000
DefaultEventRingBufferCapacity = 100000
+ DefaultMaxStreams = uint64(100)
+ DefaultMaxStreamsPerHost = uint64(15)
)
var ConfigContext *SchedulerConfigContext
diff --git a/pkg/webservice/handlers.go b/pkg/webservice/handlers.go
index 31d1dff7..809b49a6 100644
--- a/pkg/webservice/handlers.go
+++ b/pkg/webservice/handlers.go
@@ -63,6 +63,7 @@ const (
var allowedActiveStatusMsg string
var allowedAppActiveStatuses map[string]bool
+var streamingLimiter *StreamingLimiter
func init() {
allowedAppActiveStatuses = make(map[string]bool)
@@ -80,6 +81,8 @@ func init() {
activeStatuses = append(activeStatuses, k)
}
allowedActiveStatusMsg = fmt.Sprintf("Only following active statuses
are allowed: %s", strings.Join(activeStatuses, ","))
+
+ streamingLimiter = NewStreamingLimiter()
}
func getStackInfo(w http.ResponseWriter, r *http.Request) {
@@ -1102,6 +1105,12 @@ func getStream(w http.ResponseWriter, r *http.Request) {
return
}
+ if !streamingLimiter.AddHost(r.Host) {
+ buildJSONErrorResponse(w, "Too many streaming connections",
http.StatusServiceUnavailable)
+ return
+ }
+ defer streamingLimiter.RemoveHost(r.Host)
+
var count uint64
if countStr := r.URL.Query().Get("count"); countStr != "" {
var err error
diff --git a/pkg/webservice/handlers_test.go b/pkg/webservice/handlers_test.go
index f07b97e9..c598dc4f 100644
--- a/pkg/webservice/handlers_test.go
+++ b/pkg/webservice/handlers_test.go
@@ -2095,7 +2095,7 @@ func TestGetStream_TrackingDisabled(t *testing.T) {
_, req := initEventsAndCreateRequest(t)
resp := httptest.NewRecorder()
- assertGetStreamError(t, req, resp, "Event tracking is disabled")
+ assertGetStreamError(t, req, resp, http.StatusInternalServerError,
"Event tracking is disabled")
}
func TestGetStream_NoWriteDeadline(t *testing.T) {
@@ -2103,7 +2103,7 @@ func TestGetStream_NoWriteDeadline(t *testing.T) {
defer ev.Stop()
resp := httptest.NewRecorder() // does not have SetWriteDeadline()
- assertGetStreamError(t, req, resp, "Cannot set write deadline: feature
not supported")
+ assertGetStreamError(t, req, resp, http.StatusInternalServerError,
"Cannot set write deadline: feature not supported")
}
func TestGetStream_SetWriteDeadlineFails(t *testing.T) {
@@ -2122,7 +2122,7 @@ func TestGetStream_SetWriteDeadlineFails(t *testing.T) {
}()
getStream(resp, req)
- checkGetStreamErrorResult(t, resp.Result(), "Cannot set write deadline:
SetWriteDeadline failed")
+ checkGetStreamErrorResult(t, resp.Result(),
http.StatusInternalServerError, "Cannot set write deadline: SetWriteDeadline
failed")
}
func TestGetStream_SetReadDeadlineFails(t *testing.T) {
@@ -2130,11 +2130,42 @@ func TestGetStream_SetReadDeadlineFails(t *testing.T) {
resp := NewResponseRecorderWithDeadline()
resp.setReadFails = true
- assertGetStreamError(t, req, resp, "Cannot set read deadline:
SetReadDeadline failed")
+ assertGetStreamError(t, req, resp, http.StatusInternalServerError,
"Cannot set read deadline: SetReadDeadline failed")
}
-func assertGetStreamError(t *testing.T, req *http.Request, resp interface{},
- expectedMsg string) {
+func TestGetStream_Limit(t *testing.T) {
+ current := configs.GetConfigMap()
+ defer func() {
+ configs.SetConfigMap(current)
+ }()
+ configs.SetConfigMap(map[string]string{
+ configs.CMMaxEventStreams: "3",
+ })
+ resp := NewResponseRecorderWithDeadline()
+ ev, req := initEventsAndCreateRequest(t)
+ defer ev.Stop()
+
+ cancelCtx, cancel := context.WithCancel(context.Background())
+ req = req.Clone(cancelCtx)
+ defer cancel()
+ req.Host = "host-1"
+
+ // start simulated connections in the background
+ go getStream(NewResponseRecorderWithDeadline(), req)
+ go getStream(NewResponseRecorderWithDeadline(), req)
+ go getStream(NewResponseRecorderWithDeadline(), req)
+
+ // wait until the StreamingLimiter.AddHost() calls
+ err := common.WaitFor(time.Millisecond, time.Second, func() bool {
+ streamingLimiter.Lock()
+ defer streamingLimiter.Unlock()
+ return streamingLimiter.streams == 3
+ })
+ assert.NilError(t, err)
+ assertGetStreamError(t, req, resp, http.StatusServiceUnavailable, "Too
many streaming connections")
+}
+
+func assertGetStreamError(t *testing.T, req *http.Request, resp interface{},
statusCode int, expectedMsg string) {
t.Helper()
var response *http.Response
@@ -2149,17 +2180,17 @@ func assertGetStreamError(t *testing.T, req
*http.Request, resp interface{},
t.Fatalf("unknown response recorder type")
}
- checkGetStreamErrorResult(t, response, expectedMsg)
+ checkGetStreamErrorResult(t, response, statusCode, expectedMsg)
}
-func checkGetStreamErrorResult(t *testing.T, response *http.Response,
expectedMsg string) {
+func checkGetStreamErrorResult(t *testing.T, response *http.Response,
statusCode int, expectedMsg string) {
t.Helper()
output := make([]byte, 256)
n, err := response.Body.Read(output)
assert.NilError(t, err)
line := string(output[:n])
assertYunikornError(t, line, expectedMsg)
- assert.Equal(t, http.StatusInternalServerError, response.StatusCode)
+ assert.Equal(t, statusCode, response.StatusCode)
}
func initEventsAndCreateRequest(t *testing.T) (*events.EventSystemImpl,
*http.Request) {
diff --git a/pkg/webservice/streaming_limit.go
b/pkg/webservice/streaming_limit.go
new file mode 100644
index 00000000..fe2df79c
--- /dev/null
+++ b/pkg/webservice/streaming_limit.go
@@ -0,0 +1,137 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package webservice
+
+import (
+ "fmt"
+ "strconv"
+ "sync"
+ "sync/atomic"
+
+ "go.uber.org/zap"
+
+ "github.com/apache/yunikorn-core/pkg/common/configs"
+ "github.com/apache/yunikorn-core/pkg/log"
+)
+
// idGen is a process-wide sequence used to derive a distinct config map
// callback id for every StreamingLimiter instance.
var idGen atomic.Uint64

// StreamingLimiter counts the currently open event-stream connections, in
// total and per host, and holds the upper bounds that AddHost enforces.
// The embedded mutex is taken by every method that reads or updates the
// counters and limits.
type StreamingLimiter struct {
	// open streams keyed by host
	perHostStreams map[string]uint64
	// open streams across all hosts
	streams uint64
	// key under which the config map callback is registered
	id string

	// upper bound on the total number of open streams
	maxStreams uint64
	// upper bound on the open streams of a single host
	maxPerHostStreams uint64

	sync.Mutex
}
+
+func NewStreamingLimiter() *StreamingLimiter {
+ sl := &StreamingLimiter{
+ perHostStreams: make(map[string]uint64),
+ id: fmt.Sprintf("stream-limiter-%d", idGen.Add(1)),
+ }
+
+ configs.AddConfigMapCallback(sl.id, func() {
+ log.Log(log.REST).Info("Reloading streaming limit settings")
+ sl.setLimits()
+ })
+ sl.setLimits()
+
+ return sl
+}
+
+func (sl *StreamingLimiter) AddHost(host string) bool {
+ sl.Lock()
+ defer sl.Unlock()
+
+ if sl.streams >= sl.maxStreams {
+ log.Log(log.SchedHealth).Info("Number of maximum stream
connections reached",
+ zap.Uint64("limit", sl.maxStreams),
+ zap.String("host", host))
+ return false
+ }
+ if sl.perHostStreams[host] >= sl.maxPerHostStreams {
+ log.Log(log.SchedHealth).Info("Per host connection limit
reached",
+ zap.Uint64("limit", sl.maxPerHostStreams),
+ zap.String("host", host))
+ return false
+ }
+
+ sl.streams++
+ sl.perHostStreams[host]++
+ return true
+}
+
+func (sl *StreamingLimiter) RemoveHost(host string) {
+ sl.Lock()
+ defer sl.Unlock()
+
+ count, ok := sl.perHostStreams[host]
+ if !ok {
+ log.Log(log.REST).Warn("Tried to remove a non-existing host
from tracking",
+ zap.String("host", host))
+ return
+ }
+
+ sl.streams--
+ if count <= 1 {
+ delete(sl.perHostStreams, host)
+ return
+ }
+ sl.perHostStreams[host]--
+}
+
+func (sl *StreamingLimiter) setLimits() {
+ sl.Lock()
+ defer sl.Unlock()
+
+ maxStreams := configs.DefaultMaxStreams
+ configMap := configs.GetConfigMap()
+
+ if value, ok := configMap[configs.CMMaxEventStreams]; ok {
+ parsed, err := strconv.ParseUint(value, 10, 64)
+ if err != nil {
+ log.Log(log.REST).Warn("Failed to parse configuration
value",
+ zap.String("key", configs.CMMaxEventStreams),
+ zap.String("value", value),
+ zap.Error(err))
+ } else {
+ maxStreams = parsed
+ }
+ }
+
+ maxStreamsPerHost := configs.DefaultMaxStreamsPerHost
+ if value, ok := configMap[configs.CMMaxEventStreamsPerHost]; ok {
+ parsed, err := strconv.ParseUint(value, 10, 64)
+ if err != nil {
+ log.Log(log.REST).Warn("Failed to parse configuration
value",
+ zap.String("key",
configs.CMMaxEventStreamsPerHost),
+ zap.String("value", value),
+ zap.Error(err))
+ } else {
+ maxStreamsPerHost = parsed
+ }
+ }
+
+ sl.maxStreams = maxStreams
+ sl.maxPerHostStreams = maxStreamsPerHost
+}
diff --git a/pkg/webservice/streaming_limit_test.go
b/pkg/webservice/streaming_limit_test.go
new file mode 100644
index 00000000..dc1fa6cb
--- /dev/null
+++ b/pkg/webservice/streaming_limit_test.go
@@ -0,0 +1,124 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package webservice
+
+import (
+ "testing"
+
+ "gotest.tools/v3/assert"
+
+ "github.com/apache/yunikorn-core/pkg/common/configs"
+)
+
+func TestAddRemoveHost(t *testing.T) {
+ sl := NewStreamingLimiter()
+ defer sl.Stop()
+ assert.Assert(t, sl.AddHost("host-1"))
+ assert.Assert(t, sl.AddHost("host-1"))
+ assert.Assert(t, sl.AddHost("host-2"))
+ assert.Equal(t, 2, len(sl.perHostStreams))
+ assert.Equal(t, uint64(3), sl.streams)
+
+ sl.RemoveHost("host-3") // remove non-existing
+ assert.Equal(t, 2, len(sl.perHostStreams))
+ assert.Equal(t, uint64(3), sl.streams)
+
+ sl.RemoveHost("host-1")
+ assert.Equal(t, 2, len(sl.perHostStreams))
+ assert.Equal(t, uint64(2), sl.streams)
+
+ sl.RemoveHost("host-2")
+ assert.Equal(t, 1, len(sl.perHostStreams))
+ assert.Equal(t, uint64(1), sl.streams)
+
+ sl.RemoveHost("host-1")
+ assert.Equal(t, 0, len(sl.perHostStreams))
+ assert.Equal(t, uint64(0), sl.streams)
+}
+
+func TestAddHost_TotalLimitHit(t *testing.T) {
+ current := configs.GetConfigMap()
+ defer func() {
+ configs.SetConfigMap(current)
+ }()
+ configs.SetConfigMap(map[string]string{
+ configs.CMMaxEventStreams: "2",
+ })
+ sl := NewStreamingLimiter()
+ sl.Stop()
+
+ assert.Assert(t, sl.AddHost("host-1"))
+ assert.Assert(t, sl.AddHost("host-2"))
+ assert.Assert(t, !sl.AddHost("host-3"))
+}
+
+func TestAddHost_PerHostLimitHit(t *testing.T) {
+ current := configs.GetConfigMap()
+ defer func() {
+ configs.SetConfigMap(current)
+ }()
+ configs.SetConfigMap(map[string]string{
+ configs.CMMaxEventStreamsPerHost: "2",
+ })
+ sl := NewStreamingLimiter()
+ defer sl.Stop()
+
+ assert.Assert(t, sl.AddHost("host-1"))
+ assert.Assert(t, sl.AddHost("host-1"))
+ assert.Assert(t, !sl.AddHost("host-1"))
+}
+
+func TestGetLimits(t *testing.T) {
+ current := configs.GetConfigMap()
+ defer func() {
+ configs.SetConfigMap(current)
+ }()
+ sl := NewStreamingLimiter()
+ defer sl.Stop()
+
+ sl.setLimits()
+ assert.Equal(t, uint64(100), sl.maxStreams)
+ assert.Equal(t, uint64(15), sl.maxPerHostStreams)
+
+ configs.SetConfigMap(map[string]string{
+ configs.CMMaxEventStreams: "123",
+ })
+ sl.setLimits()
+ assert.Equal(t, uint64(123), sl.maxStreams)
+ assert.Equal(t, uint64(15), sl.maxPerHostStreams)
+
+ configs.SetConfigMap(map[string]string{
+ configs.CMMaxEventStreamsPerHost: "321",
+ })
+ sl.setLimits()
+ assert.Equal(t, uint64(100), sl.maxStreams)
+ assert.Equal(t, uint64(321), sl.maxPerHostStreams)
+
+ configs.SetConfigMap(map[string]string{
+ configs.CMMaxEventStreams: "xxx",
+ configs.CMMaxEventStreamsPerHost: "yyy",
+ })
+ sl.setLimits()
+ assert.Equal(t, uint64(100), sl.maxStreams)
+ assert.Equal(t, uint64(15), sl.maxPerHostStreams)
+}
+
+func (sl *StreamingLimiter) Stop() {
+ configs.RemoveConfigMapCallback(sl.id)
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]