This is an automated email from the ASF dual-hosted git repository.

pbacsko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/yunikorn-core.git


The following commit(s) were added to refs/heads/master by this push:
     new 681a4cbf [YUNIKORN-1709] Add event streaming logic (#533)
681a4cbf is described below

commit 681a4cbf4b9929e579674e44e6309acf0e728e29
Author: Peter Bacsko <[email protected]>
AuthorDate: Wed Jan 24 10:28:17 2024 +0100

    [YUNIKORN-1709] Add event streaming logic (#533)
    
    Closes: #533
    
    Signed-off-by: Peter Bacsko <[email protected]>
---
 pkg/events/event_ringbuffer.go       |  28 +++-
 pkg/events/event_ringbuffer_test.go  |  33 +++++
 pkg/events/event_streaming.go        | 179 ++++++++++++++++++++++
 pkg/events/event_streaming_test.go   | 145 ++++++++++++++++++
 pkg/events/event_system.go           |  72 ++++++++-
 pkg/scheduler/objects/common_test.go |   8 +
 pkg/webservice/handler_mock_test.go  |   5 +
 pkg/webservice/handlers.go           |  80 ++++++++++
 pkg/webservice/handlers_test.go      | 278 +++++++++++++++++++++++++++++++++++
 pkg/webservice/routes.go             |   6 +
 10 files changed, 830 insertions(+), 4 deletions(-)

diff --git a/pkg/events/event_ringbuffer.go b/pkg/events/event_ringbuffer.go
index 479ea70f..3ae95331 100644
--- a/pkg/events/event_ringbuffer.go
+++ b/pkg/events/event_ringbuffer.go
@@ -70,7 +70,30 @@ func (e *eventRingBuffer) Add(event *si.EventRecord) {
        e.id++
 }
 
-// GetEventsFromID returns "count" number of event records from "id" if possible. The id can be determined from
+// GetRecentEvents returns the most recent "count" elements from the ring buffer.
+// It is allowed for "count" to be larger than the number of elements.
+func (e *eventRingBuffer) GetRecentEvents(count uint64) []*si.EventRecord {
+       e.RLock()
+       defer e.RUnlock()
+
+       lastID := e.getLastEventID()
+       var startID uint64
+       if lastID < count {
+               startID = 0
+       } else {
+               startID = lastID - count + 1
+       }
+       // Older events may already have been overwritten: never start below the lowest ID that is still in the
+       // buffer, otherwise getEventsFromID() finds no match for startID and no history is returned at all.
+       if lowest := e.getLowestID(); startID < lowest {
+               startID = lowest
+       }
+
+       history, _, _ := e.getEventsFromID(startID, count)
+       return history
+}
+
+// GetEventsFromID returns "count" number of event records from id if possible. The id can be determined from
 // the first call of the method - if it returns nothing because the id is not in the buffer, the lowest valid
 // identifier is returned which can be used to get the first batch.
 // If the caller does not want to pose limit on the number of events returned, "count" must be set to a high
@@ -78,6 +101,12 @@ func (e *eventRingBuffer) Add(event *si.EventRecord) {
 func (e *eventRingBuffer) GetEventsFromID(id uint64, count uint64) ([]*si.EventRecord, uint64, uint64) {
        e.RLock()
        defer e.RUnlock()
+
+       return e.getEventsFromID(id, count)
+}
+
+// getEventsFromID unlocked version of GetEventsFromID: the caller must already hold (at least) the read lock.
+func (e *eventRingBuffer) getEventsFromID(id uint64, count uint64) ([]*si.EventRecord, uint64, uint64) {
        lowest := e.getLowestID()
 
        pos, idFound := e.id2pos(id)
diff --git a/pkg/events/event_ringbuffer_test.go b/pkg/events/event_ringbuffer_test.go
index 6b390e69..d25fc864 100644
--- a/pkg/events/event_ringbuffer_test.go
+++ b/pkg/events/event_ringbuffer_test.go
@@ -277,6 +277,39 @@ func TestResize(t *testing.T) {
        assert.Equal(t, uint64(7), ringBuffer.resizeOffset)
 }
 
+func TestGetRecentEvents(t *testing.T) {
+       // empty
+       buffer := newEventRingBuffer(10)
+       records := buffer.GetRecentEvents(5)
+       assert.Equal(t, 0, len(records))
+
+       populate(buffer, 5)
+
+       // count < elements
+       records = buffer.GetRecentEvents(2)
+       assert.Equal(t, 2, len(records))
+       assert.Equal(t, int64(3), records[0].TimestampNano)
+       assert.Equal(t, int64(4), records[1].TimestampNano)
+
+       // count = elements
+       records = buffer.GetRecentEvents(5)
+       assert.Equal(t, 5, len(records))
+       assert.Equal(t, int64(0), records[0].TimestampNano)
+       assert.Equal(t, int64(1), records[1].TimestampNano)
+       assert.Equal(t, int64(2), records[2].TimestampNano)
+       assert.Equal(t, int64(3), records[3].TimestampNano)
+       assert.Equal(t, int64(4), records[4].TimestampNano)
+
+       // count > elements
+       records = buffer.GetRecentEvents(15)
+       assert.Equal(t, 5, len(records))
+       assert.Equal(t, int64(0), records[0].TimestampNano)
+       assert.Equal(t, int64(1), records[1].TimestampNano)
+       assert.Equal(t, int64(2), records[2].TimestampNano)
+       assert.Equal(t, int64(3), records[3].TimestampNano)
+       assert.Equal(t, int64(4), records[4].TimestampNano)
+}
+
 func populate(buffer *eventRingBuffer, count int) {
        for i := 0; i < count; i++ {
                buffer.Add(&si.EventRecord{
diff --git a/pkg/events/event_streaming.go b/pkg/events/event_streaming.go
new file mode 100644
index 00000000..4f7b9d26
--- /dev/null
+++ b/pkg/events/event_streaming.go
@@ -0,0 +1,203 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements.  See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership.  The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License.  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package events
+
+import (
+       "sync"
+       "time"
+
+       "go.uber.org/zap"
+
+       "github.com/apache/yunikorn-core/pkg/log"
+       "github.com/apache/yunikorn-scheduler-interface/lib/go/si"
+)
+
+const defaultChannelBufSize = 1000
+
+// EventStreaming implements the event streaming logic.
+// New events are immediately forwarded to all active consumers.
+type EventStreaming struct {
+       buffer       *eventRingBuffer
+       stopCh       chan struct{}
+       eventStreams map[*EventStream]eventConsumerDetails
+       sync.Mutex
+}
+
+type eventConsumerDetails struct {
+       local     chan *si.EventRecord
+       consumer  chan<- *si.EventRecord
+       stopCh    chan struct{}
+       name      string
+       createdAt time.Time
+}
+
+// EventStream handle type returned to the client that wants to capture the stream of events.
+type EventStream struct {
+       Events <-chan *si.EventRecord
+}
+
+// PublishEvent publishes an event to all event stream consumers.
+//
+// The streaming logic uses bridging to ensure proper ordering of existing and new events.
+// Events are sent to the "local" channel from where it is forwarded to the "consumer" channel.
+//
+// If "local" is full, it means that the consumer side has not processed the events at an appropriate pace.
+// Such a consumer is removed and the related channels are closed.
+func (e *EventStreaming) PublishEvent(event *si.EventRecord) {
+       e.Lock()
+       defer e.Unlock()
+
+       for consumer, details := range e.eventStreams {
+               if len(details.local) == defaultChannelBufSize {
+                       log.Log(log.Events).Warn("Listener buffer full due to potentially slow consumer, removing it")
+                       e.removeEventStream(consumer)
+                       continue
+               }
+
+               details.local <- event
+       }
+}
+
+// CreateEventStream sets up event streaming for a consumer. The returned EventStream object
+// contains a channel that can be used for reading.
+//
+// When a consumer is finished, it must call RemoveEventStream to free up resources.
+//
+// Consumers have an arbitrary name for logging purposes. The "count" parameter defines the number
+// of maximum historical events from the ring buffer. "0" is a valid value and means no past events.
+func (e *EventStreaming) CreateEventStream(name string, count uint64) *EventStream {
+       consumer := make(chan *si.EventRecord, defaultChannelBufSize)
+       stream := &EventStream{
+               Events: consumer,
+       }
+       local := make(chan *si.EventRecord, defaultChannelBufSize)
+       stop := make(chan struct{})
+       e.createEventStreamInternal(stream, local, consumer, stop, name, count)
+       history := e.buffer.GetRecentEvents(count)
+
+       go func(consumer chan<- *si.EventRecord, local <-chan *si.EventRecord, stop <-chan struct{}) {
+               // This goroutine is the only sender on "consumer", so it is the one that closes it.
+               defer close(consumer)
+
+               // send forwards an event to the consumer. It gives up and returns false if the stream
+               // is stopped while waiting: a consumer that no longer reads must not block this
+               // goroutine forever.
+               send := func(event *si.EventRecord) bool {
+                       select {
+                       case consumer <- event:
+                               return true
+                       case <-stop:
+                               return false
+                       case <-e.stopCh:
+                               return false
+                       }
+               }
+
+               // Store the refs of historical events; it's possible that some events are added to the
+               // ring buffer and also to "local" channel.
+               // It is because we use two separate locks, so event updates are not atomic.
+               // Example: an event has been just added to the ring buffer (before createEventStreamInternal()),
+               // and execution is about to enter PublishEvent(); at this point we have an updated "eventStreams"
+               // map, so "local" will also contain the new event.
+               seen := make(map[*si.EventRecord]bool)
+               for _, event := range history {
+                       if !send(event) {
+                               return
+                       }
+                       seen[event] = true
+               }
+               for {
+                       select {
+                       case <-e.stopCh:
+                               return
+                       case <-stop:
+                               return
+                       case event, ok := <-local:
+                               if !ok {
+                                       // "local" was closed by removeEventStream(): a closed channel yields nil forever
+                                       return
+                               }
+                               if seen[event] {
+                                       continue
+                               }
+                               // since events are processed in a single goroutine, doubling is no longer
+                               // possible at this point, so the lookup table is no longer needed
+                               seen = nil
+                               if !send(event) {
+                                       return
+                               }
+                       }
+               }
+       }(consumer, local, stop)
+
+       log.Log(log.Events).Info("Created event stream", zap.String("consumer name", name))
+       return stream
+}
+
+func (e *EventStreaming) createEventStreamInternal(stream *EventStream,
+       local chan *si.EventRecord,
+       consumer chan *si.EventRecord,
+       stop chan struct{},
+       name string,
+       count uint64) {
+       // stuff that needs locking
+       e.Lock()
+       defer e.Unlock()
+
+       e.eventStreams[stream] = eventConsumerDetails{
+               local:     local,
+               consumer:  consumer,
+               stopCh:    stop,
+               name:      name,
+               createdAt: time.Now(),
+       }
+}
+
+// RemoveEventStream stops the streaming for a given consumer. Must be called to avoid resource leaks.
+func (e *EventStreaming) RemoveEventStream(consumer *EventStream) {
+       e.Lock()
+       defer e.Unlock()
+
+       e.removeEventStream(consumer)
+}
+
+// removeEventStream is the unlocked version of RemoveEventStream: the caller must hold the lock.
+func (e *EventStreaming) removeEventStream(consumer *EventStream) {
+       if details, ok := e.eventStreams[consumer]; ok {
+               log.Log(log.Events).Info("Removing event stream consumer", zap.String("name", details.name),
+                       zap.Time("creation time", details.createdAt))
+               close(details.stopCh)
+               close(details.local)
+               delete(e.eventStreams, consumer)
+       }
+}
+
+// Close stops event streaming completely.
+func (e *EventStreaming) Close() {
+       close(e.stopCh)
+}
+
+// NewEventStreaming creates a new event streaming infrastructure.
+func NewEventStreaming(eventBuffer *eventRingBuffer) *EventStreaming {
+       return &EventStreaming{
+               buffer:       eventBuffer,
+               stopCh:       make(chan struct{}),
+               eventStreams: make(map[*EventStream]eventConsumerDetails),
+       }
+}
diff --git a/pkg/events/event_streaming_test.go b/pkg/events/event_streaming_test.go
new file mode 100644
index 00000000..8afc770a
--- /dev/null
+++ b/pkg/events/event_streaming_test.go
@@ -0,0 +1,145 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements.  See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership.  The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License.  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package events
+
+import (
+       "testing"
+       "time"
+
+       "gotest.tools/v3/assert"
+
+       "github.com/apache/yunikorn-scheduler-interface/lib/go/si"
+)
+
+var defaultCount = uint64(10000)
+
+func TestEventStreaming_WithoutHistory(t *testing.T) {
+       buffer := newEventRingBuffer(10)
+       streaming := NewEventStreaming(buffer)
+       es := streaming.CreateEventStream("test", defaultCount)
+       defer streaming.Close()
+
+       sent := &si.EventRecord{
+               Message: "testMessage",
+       }
+       streaming.PublishEvent(sent)
+       received := receive(t, es.Events)
+       streaming.RemoveEventStream(es)
+       assert.Equal(t, 0, len(streaming.eventStreams[es].local))
+       assert.Equal(t, 0, len(streaming.eventStreams[es].consumer))
+       assert.Equal(t, received.Message, sent.Message)
+       assert.Equal(t, 0, len(streaming.eventStreams))
+}
+
+func TestEventStreaming_WithHistory(t *testing.T) {
+       buffer := newEventRingBuffer(10)
+       streaming := NewEventStreaming(buffer)
+       defer streaming.Close()
+
+       buffer.Add(&si.EventRecord{TimestampNano: 1})
+       buffer.Add(&si.EventRecord{TimestampNano: 5})
+       buffer.Add(&si.EventRecord{TimestampNano: 6})
+       buffer.Add(&si.EventRecord{TimestampNano: 9})
+       es := streaming.CreateEventStream("test", defaultCount)
+
+       streaming.PublishEvent(&si.EventRecord{TimestampNano: 10})
+
+       received1 := receive(t, es.Events)
+       received2 := receive(t, es.Events)
+       received3 := receive(t, es.Events)
+       received4 := receive(t, es.Events)
+       received5 := receive(t, es.Events)
+       assert.Equal(t, 0, len(streaming.eventStreams[es].local))
+       assert.Equal(t, 0, len(streaming.eventStreams[es].consumer))
+       assert.Equal(t, int64(1), received1.TimestampNano)
+       assert.Equal(t, int64(5), received2.TimestampNano)
+       assert.Equal(t, int64(6), received3.TimestampNano)
+       assert.Equal(t, int64(9), received4.TimestampNano)
+       assert.Equal(t, int64(10), received5.TimestampNano)
+       streaming.RemoveEventStream(es)
+       assert.Equal(t, 0, len(streaming.eventStreams))
+}
+
+func TestEventStreaming_WithHistoryCount(t *testing.T) {
+       buffer := newEventRingBuffer(10)
+       streaming := NewEventStreaming(buffer)
+       defer streaming.Close()
+
+       buffer.Add(&si.EventRecord{TimestampNano: 1})
+       buffer.Add(&si.EventRecord{TimestampNano: 5})
+       buffer.Add(&si.EventRecord{TimestampNano: 6})
+       buffer.Add(&si.EventRecord{TimestampNano: 9})
+       es := streaming.CreateEventStream("test", 2)
+
+       streaming.PublishEvent(&si.EventRecord{TimestampNano: 10})
+
+       received1 := receive(t, es.Events)
+       received2 := receive(t, es.Events)
+       received3 := receive(t, es.Events)
+       assert.Equal(t, 0, len(streaming.eventStreams[es].local))
+       assert.Equal(t, 0, len(streaming.eventStreams[es].consumer))
+       assert.Equal(t, int64(6), received1.TimestampNano)
+       assert.Equal(t, int64(9), received2.TimestampNano)
+       assert.Equal(t, int64(10), received3.TimestampNano)
+}
+
+func TestEventStreaming_TwoConsumers(t *testing.T) {
+       buffer := newEventRingBuffer(10)
+       streaming := NewEventStreaming(buffer)
+       defer streaming.Close()
+
+       es1 := streaming.CreateEventStream("stream1", defaultCount)
+       es2 := streaming.CreateEventStream("stream2", defaultCount)
+       for i := 0; i < 5; i++ {
+               streaming.PublishEvent(&si.EventRecord{TimestampNano: int64(i)})
+       }
+
+       for i := 0; i < 5; i++ {
+               assert.Equal(t, int64(i), receive(t, es1.Events).TimestampNano)
+               assert.Equal(t, int64(i), receive(t, es2.Events).TimestampNano)
+       }
+       assert.Equal(t, 0, len(streaming.eventStreams[es1].local))
+       assert.Equal(t, 0, len(streaming.eventStreams[es1].consumer))
+       assert.Equal(t, 0, len(streaming.eventStreams[es2].local))
+       assert.Equal(t, 0, len(streaming.eventStreams[es2].consumer))
+}
+
+func TestEventStreaming_SlowConsumer(t *testing.T) {
+       // simulating a slow event consumer by ignoring events
+       buffer := newEventRingBuffer(10)
+       streaming := NewEventStreaming(buffer)
+       defer streaming.Close()
+       streaming.CreateEventStream("test", 10000)
+
+       for i := 0; i < 2500; i++ {
+               streaming.PublishEvent(&si.EventRecord{TimestampNano: int64(i)})
+       }
+
+       assert.Equal(t, 0, len(streaming.eventStreams))
+}
+
+func receive(t *testing.T, input <-chan *si.EventRecord) *si.EventRecord {
+       select {
+       case event := <-input:
+               return event
+       case <-time.After(time.Second):
+               t.Fatal("receive failed")
+               return nil
+       }
+}
diff --git a/pkg/events/event_system.go b/pkg/events/event_system.go
index ef8f7c00..e89f6289 100644
--- a/pkg/events/event_system.go
+++ b/pkg/events/event_system.go
@@ -38,18 +38,50 @@ var once sync.Once
 var ev EventSystem
 
 type EventSystem interface {
+       // AddEvent adds an event record to the event system for processing:
+       // 1. It is added to a slice from where it is periodically read by the shim publisher.
+       // 2. It is added to an internal ring buffer so that clients can retrieve the event history.
+       // 3. Streaming clients are updated.
        AddEvent(event *si.EventRecord)
+
+       // StartService starts the event system.
+       // This method does not block. Events are processed on a separate goroutine.
        StartService()
+
+       // Stop stops the event system.
        Stop()
+
+       // IsEventTrackingEnabled whether history tracking is currently enabled or not.
        IsEventTrackingEnabled() bool
-       GetEventsFromID(uint64, uint64) ([]*si.EventRecord, uint64, uint64)
+
+       // GetEventsFromID retrieves "count" number of elements from the history buffer from "id". Every
+       // event has a unique ID inside the ring buffer.
+       // If "id" is not in the buffer, then no record is returned, but the currently available range
+       // [low..high] is set.
+       GetEventsFromID(id, count uint64) ([]*si.EventRecord, uint64, uint64)
+
+       // CreateEventStream creates an event stream (channel) for a consumer.
+       // The "name" argument is an arbitrary string for a consumer, which is used for logging. It does not need to be unique.
+       // The "count" argument defines how many historical elements should be returned on the stream. Zero is a valid value for "count".
+       // The returned type contains a read-only channel which is updated as soon as there is a new event record.
+       // It is also used as a handle to stop the streaming.
+       // Consumers must read the channel and process the event objects as soon as they can to avoid
+       // events piling up inside the channel buffers.
+       CreateEventStream(name string, count uint64) *EventStream
+
+       // RemoveStream stops streaming for a given consumer.
+       // Consumers that no longer wish to be updated (e.g., a remote client
+       // disconnected) *must* call this method to gracefully stop the streaming.
+       RemoveStream(*EventStream)
 }
 
+// EventSystemImpl main implementation of the event system which is used for history tracking.
 type EventSystemImpl struct {
        eventSystemId string
        Store         *EventStore // storing eventChannel
        publisher     *EventPublisher
        eventBuffer   *eventRingBuffer
+       streaming     *EventStreaming
 
        channel chan *si.EventRecord // channelling input eventChannel
        stop    chan bool            // whether the service is stopped
@@ -62,10 +94,22 @@ type EventSystemImpl struct {
        sync.RWMutex
 }
 
+// CreateEventStream creates an event stream. See the interface for details.
+func (ec *EventSystemImpl) CreateEventStream(name string, count uint64) *EventStream {
+       return ec.streaming.CreateEventStream(name, count)
+}
+
+// RemoveStream graceful termination of an event streaming for a consumer. See the interface for details.
+func (ec *EventSystemImpl) RemoveStream(consumer *EventStream) {
+       ec.streaming.RemoveEventStream(consumer)
+}
+
+// GetEventsFromID retrieves historical elements. See the interface for details.
 func (ec *EventSystemImpl) GetEventsFromID(id, count uint64) ([]*si.EventRecord, uint64, uint64) {
        return ec.eventBuffer.GetEventsFromID(id, count)
 }
 
+// GetEventSystem returns the event system instance. Initialization happens during the first call.
 func GetEventSystem() EventSystem {
        once.Do(func() {
                Init()
@@ -73,42 +117,51 @@ func GetEventSystem() EventSystem {
        return ev
 }
 
+// IsEventTrackingEnabled whether history tracking is currently enabled or not.
 func (ec *EventSystemImpl) IsEventTrackingEnabled() bool {
        ec.RLock()
        defer ec.RUnlock()
        return ec.trackingEnabled
 }
 
+// GetRequestCapacity returns the capacity of an intermediate storage which is used by the shim publisher.
 func (ec *EventSystemImpl) GetRequestCapacity() int {
        ec.RLock()
        defer ec.RUnlock()
        return ec.requestCapacity
 }
 
+// GetRingBufferCapacity returns the capacity of the buffer which stores historical elements.
 func (ec *EventSystemImpl) GetRingBufferCapacity() uint64 {
        ec.RLock()
        defer ec.RUnlock()
        return ec.ringBufferCapacity
 }
 
-// VisibleForTesting
+// Init Initializes the event system.
+// Only exported for testing.
 func Init() {
        store := newEventStore()
+       buffer := newEventRingBuffer(defaultRingBufferSize)
        ev = &EventSystemImpl{
                Store:         store,
                channel:       make(chan *si.EventRecord, defaultEventChannelSize),
                stop:          make(chan bool),
                stopped:       false,
                publisher:     CreateShimPublisher(store),
-               eventBuffer:   newEventRingBuffer(defaultRingBufferSize),
+               eventBuffer:   buffer,
                eventSystemId: fmt.Sprintf("event-system-%d", time.Now().Unix()),
+               streaming:     NewEventStreaming(buffer),
        }
 }
 
+// StartService starts the event processing in the background. See the interface for details.
 func (ec *EventSystemImpl) StartService() {
        ec.StartServiceWithPublisher(true)
 }
 
+// StartServiceWithPublisher starts the event processing background routines.
+// Only exported for testing.
 func (ec *EventSystemImpl) StartServiceWithPublisher(withPublisher bool) {
        ec.Lock()
        defer ec.Unlock()
@@ -134,6 +187,7 @@ func (ec *EventSystemImpl) StartServiceWithPublisher(withPublisher bool) {
                                if event != nil {
                                        ec.Store.Store(event)
                                        ec.eventBuffer.Add(event)
+                                       ec.streaming.PublishEvent(event)
                                        metrics.GetEventMetrics().IncEventsProcessed()
                                }
                        }
@@ -144,6 +198,7 @@ func (ec *EventSystemImpl) StartServiceWithPublisher(withPublisher bool) {
        }
 }
 
+// Stop stops the event system.
 func (ec *EventSystemImpl) Stop() {
        ec.Lock()
        defer ec.Unlock()
@@ -163,6 +218,7 @@ func (ec *EventSystemImpl) Stop() {
        ec.stopped = true
 }
 
+// AddEvent adds an event record to the event system. See the interface for details.
 func (ec *EventSystemImpl) AddEvent(event *si.EventRecord) {
        metrics.GetEventMetrics().IncEventsCreated()
        select {
@@ -192,11 +248,21 @@ func (ec *EventSystemImpl) isRestartNeeded() bool {
        return ec.readIsTrackingEnabled() != ec.trackingEnabled
 }
 
+// Restart restarts the event system, used during config update.
 func (ec *EventSystemImpl) Restart() {
        ec.Stop()
        ec.StartServiceWithPublisher(true)
 }
 
+// CloseAllStreams removes every registered event stream while holding the streaming lock. VisibleForTesting
+func (ec *EventSystemImpl) CloseAllStreams() {
+       ec.streaming.Lock()
+       defer ec.streaming.Unlock()
+       for consumer := range ec.streaming.eventStreams {
+               ec.streaming.removeEventStream(consumer)
+       }
+}
+
 func (ec *EventSystemImpl) reloadConfig() {
        ec.updateRequestCapacity()
 
diff --git a/pkg/scheduler/objects/common_test.go b/pkg/scheduler/objects/common_test.go
index 72f0e805..290a32fa 100644
--- a/pkg/scheduler/objects/common_test.go
+++ b/pkg/scheduler/objects/common_test.go
@@ -20,6 +20,7 @@ import (
        "github.com/google/btree"
 
        "github.com/apache/yunikorn-core/pkg/common/resources"
+       "github.com/apache/yunikorn-core/pkg/events"
        "github.com/apache/yunikorn-scheduler-interface/lib/go/si"
 )
 
@@ -28,6 +29,13 @@ type EventSystemMock struct {
        enabled bool
 }
 
+func (m *EventSystemMock) CreateEventStream(_ string, _ uint64) *events.EventStream {
+       return nil
+}
+
+func (m *EventSystemMock) RemoveStream(_ *events.EventStream) {
+}
+
 func (m *EventSystemMock) AddEvent(event *si.EventRecord) {
        m.events = append(m.events, event)
 }
diff --git a/pkg/webservice/handler_mock_test.go b/pkg/webservice/handler_mock_test.go
index 9759afae..439efe4f 100644
--- a/pkg/webservice/handler_mock_test.go
+++ b/pkg/webservice/handler_mock_test.go
@@ -19,6 +19,7 @@ package webservice
 
 import (
        "net/http"
+       "time"
 )
 
 // InternalMetricHistory needs resetting between tests
@@ -49,3 +50,7 @@ func (trw *MockResponseWriter) Write(bytes []byte) (int, error) {
 func (trw *MockResponseWriter) WriteHeader(statusCode int) {
        trw.statusCode = statusCode
 }
+
+func (trw *MockResponseWriter) SetWriteDeadline(deadline time.Time) error {
+       return nil
+}
diff --git a/pkg/webservice/handlers.go b/pkg/webservice/handlers.go
index 69c781f0..d01f12cf 100644
--- a/pkg/webservice/handlers.go
+++ b/pkg/webservice/handlers.go
@@ -28,6 +28,7 @@ import (
        "sort"
        "strconv"
        "strings"
+       "time"
 
        "github.com/julienschmidt/httprouter"
        "github.com/prometheus/client_golang/prometheus/promhttp"
@@ -1084,3 +1085,85 @@ func getEvents(w http.ResponseWriter, r *http.Request) {
                buildJSONErrorResponse(w, err.Error(), http.StatusInternalServerError)
        }
 }
+
+// getStream streams event records to the client as newline-delimited JSON: first at most "count"
+// historical events (none when the "count" query parameter is absent), then every new event, until
+// the client disconnects or the event system closes the stream.
+func getStream(w http.ResponseWriter, r *http.Request) {
+       writeHeaders(w)
+       eventSystem := events.GetEventSystem()
+       if !eventSystem.IsEventTrackingEnabled() {
+               buildJSONErrorResponse(w, "Event tracking is disabled", http.StatusInternalServerError)
+               return
+       }
+
+       f, ok := w.(http.Flusher)
+       if !ok {
+               buildJSONErrorResponse(w, "Writer does not implement http.Flusher", http.StatusInternalServerError)
+               return
+       }
+
+       var count uint64
+       if countStr := r.URL.Query().Get("count"); countStr != "" {
+               var err error
+               count, err = strconv.ParseUint(countStr, 10, 64)
+               if err != nil {
+                       buildJSONErrorResponse(w, err.Error(), http.StatusBadRequest)
+                       return
+               }
+       }
+
+       rc := http.NewResponseController(w)
+       // make sure both deadlines can be set
+       if err := rc.SetWriteDeadline(time.Time{}); err != nil {
+               log.Log(log.REST).Error("Cannot set write deadline", zap.Error(err))
+               buildJSONErrorResponse(w, fmt.Sprintf("Cannot set write deadline: %v", err), http.StatusInternalServerError)
+               return
+       }
+       if err := rc.SetReadDeadline(time.Time{}); err != nil {
+               log.Log(log.REST).Error("Cannot set read deadline", zap.Error(err))
+               buildJSONErrorResponse(w, fmt.Sprintf("Cannot set read deadline: %v", err), http.StatusInternalServerError)
+               return
+       }
+       enc := json.NewEncoder(w)
+       stream := eventSystem.CreateEventStream(r.Host, count)
+       // Release the stream on every exit path, including when the producer closed the channel.
+       // EventStreaming ignores removal of a stream that is no longer registered.
+       defer eventSystem.RemoveStream(stream)
+
+       // Reading events in an infinite loop until either the client disconnects or Yunikorn closes the channel.
+       // This results in a persistent HTTP connection where the message body is never closed.
+       // Write deadline is adjusted before sending data to the client.
+       for {
+               select {
+               case <-r.Context().Done():
+                       log.Log(log.REST).Info("Connection closed for event stream client",
+                               zap.String("host", r.Host))
+                       return
+               case e, ok := <-stream.Events:
+                       err := rc.SetWriteDeadline(time.Now().Add(5 * time.Second))
+                       if err != nil {
+                               // should not fail at this point
+                               log.Log(log.REST).Error("Cannot set write deadline", zap.Error(err))
+                               buildJSONErrorResponse(w, fmt.Sprintf("Cannot set write deadline: %v", err), http.StatusInternalServerError)
+                               return
+                       }
+
+                       if !ok {
+                               // the channel was closed by the event system itself
+                               msg := "Event stream was closed by the producer"
+                               buildJSONErrorResponse(w, msg, http.StatusOK) // status code is 200 at this point, cannot be changed
+                               log.Log(log.REST).Error(msg)
+                               return
+                       }
+
+                       if err := enc.Encode(e); err != nil {
+                               log.Log(log.REST).Error("Marshalling error",
+                                       zap.String("host", r.Host), zap.Error(err))
+                               buildJSONErrorResponse(w, err.Error(), http.StatusOK) // status code is 200 at this point, cannot be changed
+                               return
+                       }
+                       f.Flush()
+               }
+       }
+}
diff --git a/pkg/webservice/handlers_test.go b/pkg/webservice/handlers_test.go
index 54bad552..f07b97e9 100644
--- a/pkg/webservice/handlers_test.go
+++ b/pkg/webservice/handlers_test.go
@@ -21,7 +21,9 @@ package webservice
 import (
        "context"
        "encoding/json"
+       "errors"
        "fmt"
+       "io"
        "net/http"
        "net/http/httptest"
        "reflect"
@@ -1944,6 +1946,253 @@ func TestGetEventsWhenTrackingDisabled(t *testing.T) {
        readIllegalRequest(t, req, http.StatusInternalServerError, "Event 
tracking is disabled")
 }
 
+// TestGetStream verifies that events added while a stream is open are written
+// to the response as newline-separated JSON records, in the order they were
+// added, and that the stream ends once the request context is cancelled.
+func TestGetStream(t *testing.T) {
+       ev, req := initEventsAndCreateRequest(t)
+       defer ev.Stop()
+       cancelCtx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+       req = req.Clone(cancelCtx)
+
+       resp := NewResponseRecorderWithDeadline() // MockResponseWriter does not implement http.Flusher
+
+       // Publish three events after a delay (presumably long enough for the
+       // stream to be registered), then cancel the request context so the
+       // blocking getStream call below returns.
+       go func() {
+               time.Sleep(200 * time.Millisecond)
+               ev.AddEvent(&si.EventRecord{
+                       TimestampNano: 111,
+                       ObjectID:      "app-1",
+               })
+               ev.AddEvent(&si.EventRecord{
+                       TimestampNano: 222,
+                       ObjectID:      "node-1",
+               })
+               ev.AddEvent(&si.EventRecord{
+                       TimestampNano: 333,
+                       ObjectID:      "app-2",
+               })
+               time.Sleep(200 * time.Millisecond)
+               cancel()
+       }()
+       getStream(resp, req)
+
+       // Read the whole body rather than one fixed 256-byte chunk, so longer
+       // records cannot be silently truncated.
+       output, err := io.ReadAll(resp.Body)
+       assert.NilError(t, err, "cannot read response body")
+
+       lines := strings.Split(string(output), "\n")
+       // Fail with a clear message instead of an index-out-of-range panic when
+       // fewer records than expected were streamed.
+       assert.Assert(t, len(lines) >= 3, "expected 3 streamed events, got output: %q", output)
+       assertEvent(t, lines[0], 111, "app-1")
+       assertEvent(t, lines[1], 222, "node-1")
+       assertEvent(t, lines[2], 333, "app-2")
+}
+
+func TestGetStream_StreamClosedByProducer(t *testing.T) {
+       ev, req := initEventsAndCreateRequest(t)
+       defer ev.Stop()
+       resp := NewResponseRecorderWithDeadline() // MockResponseWriter does 
not implement http.Flusher
+
+       go func() {
+               time.Sleep(200 * time.Millisecond)
+               ev.AddEvent(&si.EventRecord{
+                       TimestampNano: 111,
+                       ObjectID:      "app-1",
+               })
+               time.Sleep(100 * time.Millisecond)
+               ev.CloseAllStreams()
+       }()
+
+       getStream(resp, req)
+
+       output := make([]byte, 256)
+       n, err := resp.Body.Read(output)
+       assert.Equal(t, http.StatusOK, resp.Code)
+       assert.NilError(t, err, "cannot read response body")
+       lines := strings.Split(string(output[:n]), "\n")
+       assertEvent(t, lines[0], 111, "app-1")
+       assertYunikornError(t, lines[1], "Event stream was closed by the 
producer")
+}
+
+func TestGetStream_NotFlusherImpl(t *testing.T) {
+       var req *http.Request
+       req, err := http.NewRequest("GET", "/ws/v1/events/stream", 
strings.NewReader(""))
+       assert.NilError(t, err)
+       resp := &MockResponseWriter{}
+
+       getStream(resp, req)
+
+       assert.Assert(t, strings.Contains(string(resp.outputBytes), "Writer 
does not implement http.Flusher"))
+       assert.Equal(t, http.StatusInternalServerError, resp.statusCode)
+}
+
+// TestGetStream_Count exercises the optional "count" query parameter:
+//   - without it, no previously recorded events are replayed;
+//   - with count=2, the two most recent events are sent first;
+//   - a non-numeric value yields a JSON error carrying the strconv parse error.
+func TestGetStream_Count(t *testing.T) {
+       ev, req := initEventsAndCreateRequest(t)
+       defer ev.Stop()
+       cancelCtx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+       req = req.Clone(cancelCtx)
+       resp := NewResponseRecorderWithDeadline() // MockResponseWriter does not implement http.Flusher
+
+       // add some existing events
+       ev.AddEvent(&si.EventRecord{TimestampNano: 0})
+       ev.AddEvent(&si.EventRecord{TimestampNano: 1})
+       ev.AddEvent(&si.EventRecord{TimestampNano: 2})
+       time.Sleep(100 * time.Millisecond) // let the events propagate
+
+       // case #1: "count" not set
+       go func() {
+               time.Sleep(100 * time.Millisecond)
+               cancel()
+       }()
+       getStream(resp, req)
+       output := make([]byte, 256)
+       n, err := resp.Body.Read(output)
+       // An empty body makes Read return io.EOF. errors.Is also fails cleanly
+       // when err is nil, whereas calling err.Error() on nil would panic.
+       assert.Assert(t, errors.Is(err, io.EOF), "expected io.EOF, got: %v", err)
+       assert.Equal(t, 0, n)
+
+       // case #2: "count" is set to "2"
+       req, err = http.NewRequest("GET", "/ws/v1/events/stream", strings.NewReader(""))
+       assert.NilError(t, err)
+       cancelCtx, cancel = context.WithCancel(context.Background())
+       req = req.Clone(cancelCtx)
+       defer cancel()
+       req.URL.RawQuery = "count=2"
+       go func() {
+               time.Sleep(100 * time.Millisecond)
+               cancel()
+       }()
+       getStream(resp, req)
+       output = make([]byte, 256)
+       n, err = resp.Body.Read(output)
+       assert.NilError(t, err)
+       lines := strings.Split(string(output[:n]), "\n")
+       // Guard the indexing so missing records fail the test instead of panicking.
+       assert.Assert(t, len(lines) >= 2, "expected 2 replayed events, got output: %q", output[:n])
+       assertEvent(t, lines[0], 1, "")
+       assertEvent(t, lines[1], 2, "")
+
+       // case #3: illegal value
+       req, err = http.NewRequest("GET", "/ws/v1/events/stream", strings.NewReader(""))
+       assert.NilError(t, err)
+       cancelCtx, cancel = context.WithCancel(context.Background())
+       req = req.Clone(cancelCtx)
+       defer cancel()
+       req.URL.RawQuery = "count=xyz"
+       getStream(resp, req)
+       output = make([]byte, 256)
+       n, err = resp.Body.Read(output)
+       assert.NilError(t, err)
+       line := string(output[:n])
+       assertYunikornError(t, line, `strconv.ParseUint: parsing "xyz": invalid syntax`)
+}
+
+// TestGetStream_TrackingDisabled checks that the stream endpoint responds with
+// an error when event tracking is switched off in the configuration.
+func TestGetStream_TrackingDisabled(t *testing.T) {
+       savedConfig := configs.GetConfigMap()
+       defer func() {
+               // stop the event system started below, then restore the config
+               events.GetEventSystem().(*events.EventSystemImpl).Stop() //nolint:errcheck
+               configs.SetConfigMap(savedConfig)
+       }()
+
+       configs.SetConfigMap(map[string]string{
+               configs.CMEventTrackingEnabled: "false",
+       })
+       _, req := initEventsAndCreateRequest(t)
+
+       assertGetStreamError(t, req, httptest.NewRecorder(), "Event tracking is disabled")
+}
+
+func TestGetStream_NoWriteDeadline(t *testing.T) {
+       ev, req := initEventsAndCreateRequest(t)
+       defer ev.Stop()
+       resp := httptest.NewRecorder() // does not have SetWriteDeadline()
+
+       assertGetStreamError(t, req, resp, "Cannot set write deadline: feature 
not supported")
+}
+
+func TestGetStream_SetWriteDeadlineFails(t *testing.T) {
+       ev, req := initEventsAndCreateRequest(t)
+       defer ev.Stop()
+       resp := NewResponseRecorderWithDeadline()
+       resp.setWriteFailsAt = 2 // only the second SetWriteDeadline() will fail
+       resp.setWriteFails = true
+
+       go func() {
+               time.Sleep(200 * time.Millisecond)
+               ev.AddEvent(&si.EventRecord{
+                       TimestampNano: 111,
+                       ObjectID:      "app-1",
+               })
+       }()
+
+       getStream(resp, req)
+       checkGetStreamErrorResult(t, resp.Result(), "Cannot set write deadline: 
SetWriteDeadline failed")
+}
+
+func TestGetStream_SetReadDeadlineFails(t *testing.T) {
+       _, req := initEventsAndCreateRequest(t)
+       resp := NewResponseRecorderWithDeadline()
+       resp.setReadFails = true
+
+       assertGetStreamError(t, req, resp, "Cannot set read deadline: 
SetReadDeadline failed")
+}
+
+func assertGetStreamError(t *testing.T, req *http.Request, resp interface{},
+       expectedMsg string) {
+       t.Helper()
+       var response *http.Response
+
+       switch rec := resp.(type) {
+       case *ResponseRecorderWithDeadline:
+               getStream(rec, req)
+               response = rec.Result()
+       case *httptest.ResponseRecorder:
+               getStream(rec, req)
+               response = rec.Result()
+       default:
+               t.Fatalf("unknown response recorder type")
+       }
+
+       checkGetStreamErrorResult(t, response, expectedMsg)
+}
+
+// checkGetStreamErrorResult asserts that response has status 500. Its body must
+// be a YuniKorn JSON error whose message and description both equal expectedMsg.
+func checkGetStreamErrorResult(t *testing.T, response *http.Response, expectedMsg string) {
+       t.Helper()
+       // Read the whole body. A single Read into a fixed-size buffer would
+       // silently truncate longer error payloads and break the JSON decoding.
+       output, err := io.ReadAll(response.Body)
+       assert.NilError(t, err)
+       assert.NilError(t, response.Body.Close())
+       assertYunikornError(t, string(output), expectedMsg)
+       assert.Equal(t, http.StatusInternalServerError, response.StatusCode)
+}
+
+func initEventsAndCreateRequest(t *testing.T) (*events.EventSystemImpl, 
*http.Request) {
+       t.Helper()
+       events.Init()
+       ev := events.GetEventSystem().(*events.EventSystemImpl) 
//nolint:errcheck
+       ev.StartServiceWithPublisher(false)
+
+       var req *http.Request
+       req, err := http.NewRequest("GET", "/ws/v1/events/stream", 
strings.NewReader(""))
+       assert.NilError(t, err)
+
+       return ev, req
+}
+
+// assertEvent decodes one JSON line of stream output into an si.EventRecord
+// and checks its TimestampNano and ObjectID against the expected values.
+func assertEvent(t *testing.T, output string, tsNano int64, objectID string) {
+       t.Helper()
+       var record si.EventRecord
+       assert.NilError(t, json.Unmarshal([]byte(output), &record))
+       assert.Equal(t, tsNano, record.TimestampNano)
+       assert.Equal(t, objectID, record.ObjectID)
+}
+
+// assertYunikornError decodes output as a dao.YAPIError. It checks that both
+// the Description and the Message fields equal errMsg.
+func assertYunikornError(t *testing.T, output, errMsg string) {
+       t.Helper()
+       var apiErr dao.YAPIError
+       assert.NilError(t, json.Unmarshal([]byte(output), &apiErr))
+       assert.Equal(t, errMsg, apiErr.Description)
+       assert.Equal(t, errMsg, apiErr.Message)
+}
+
 func addEvents(t *testing.T) (appEvent, nodeEvent, queueEvent *si.EventRecord) 
{
        t.Helper()
        events.Init()
@@ -2186,3 +2435,32 @@ func runHealthCheckTest(t *testing.T, expected 
*dao.SchedulerHealthDAOInfo) {
                assert.Equal(t, expectedHealthCheck.DiagnosisMessage, 
actualHealthCheck.DiagnosisMessage)
        }
 }
+
+// ResponseRecorderWithDeadline wraps httptest.ResponseRecorder and adds
+// SetWriteDeadline and SetReadDeadline methods. This lets tests exercise the
+// deadline handling in getStream, presumably reached via http.ResponseController.
+// Deadline failures can be injected.
+type ResponseRecorderWithDeadline struct {
+       *httptest.ResponseRecorder
+       setWriteFails   bool // enables SetWriteDeadline failure injection
+       setWriteFailsAt int  // 1-based SetWriteDeadline call number that fails
+       setWriteCalls   int  // SetWriteDeadline calls seen so far
+       setReadFails    bool // makes every SetReadDeadline call fail
+}
+
+// SetWriteDeadline ignores the deadline and counts the call. It returns an
+// error only when failure injection is enabled and this is call setWriteFailsAt.
+func (rrd *ResponseRecorderWithDeadline) SetWriteDeadline(_ time.Time) error {
+       rrd.setWriteCalls++
+       if !rrd.setWriteFails || rrd.setWriteCalls != rrd.setWriteFailsAt {
+               return nil
+       }
+       return errors.New("SetWriteDeadline failed")
+}
+
+// SetReadDeadline ignores the deadline and fails only when setReadFails is set.
+func (rrd *ResponseRecorderWithDeadline) SetReadDeadline(_ time.Time) error {
+       if !rrd.setReadFails {
+               return nil
+       }
+       return errors.New("SetReadDeadline failed")
+}
+
+// NewResponseRecorderWithDeadline returns a recorder backed by a fresh
+// httptest.ResponseRecorder, with no failures injected.
+func NewResponseRecorderWithDeadline() *ResponseRecorderWithDeadline {
+       rec := httptest.NewRecorder()
+       return &ResponseRecorderWithDeadline{ResponseRecorder: rec}
+}
diff --git a/pkg/webservice/routes.go b/pkg/webservice/routes.go
index 633fa78c..6871cc79 100644
--- a/pkg/webservice/routes.go
+++ b/pkg/webservice/routes.go
@@ -188,6 +188,12 @@ var webRoutes = routes{
                "/ws/v1/events/batch",
                getEvents,
        },
+       route{
+               "Scheduler",
+               "GET",
+               "/ws/v1/events/stream",
+               getStream,
+       },
        // endpoint to retrieve CPU, Memory profiling data,
        // this works with pprof tool. By default, pprof endpoints
        // are only registered to http.DefaultServeMux. Here, we


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to