This is an automated email from the ASF dual-hosted git repository.

pbacsko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/yunikorn-core.git


The following commit(s) were added to refs/heads/master by this push:
     new 2f343cc3 [YUNIKORN-1937] Extract URL query properly in getEvents() (#628)
2f343cc3 is described below

commit 2f343cc387e0eb8cc056d8e6433a08bd9817b806
Author: Peter Bacsko <[email protected]>
AuthorDate: Mon Aug 28 15:25:50 2023 +0200

    [YUNIKORN-1937] Extract URL query properly in getEvents() (#628)
    
    Closes: #628
    
    Signed-off-by: Peter Bacsko <[email protected]>
---
 pkg/webservice/handlers.go      | 50 ++++++++++++++++++--------------------
 pkg/webservice/handlers_test.go | 54 ++++++++++++++++++++++++++---------------
 2 files changed, 59 insertions(+), 45 deletions(-)

diff --git a/pkg/webservice/handlers.go b/pkg/webservice/handlers.go
index aea31c8d..6942b42c 100644
--- a/pkg/webservice/handlers.go
+++ b/pkg/webservice/handlers.go
@@ -874,40 +874,38 @@ func getGroupResourceUsage(w http.ResponseWriter, r *http.Request) {
 func getEvents(w http.ResponseWriter, r *http.Request) {
        writeHeaders(w)
        eventSystem := events.GetEventSystem()
-       if eventSystem == nil {
-               buildJSONErrorResponse(w, "Event system is disabled", http.StatusBadRequest)
+       if !eventSystem.IsEventTrackingEnabled() {
+               buildJSONErrorResponse(w, "Event tracking is disabled", http.StatusBadRequest)
                return
        }
 
        count := uint64(10000)
        var start uint64
-       vars := httprouter.ParamsFromContext(r.Context())
-       if vars != nil {
-               if countStr := vars.ByName("count"); countStr != "" {
-                       c, err := strconv.ParseInt(countStr, 10, 64)
-                       if err != nil {
-                               buildJSONErrorResponse(w, err.Error(), http.StatusBadRequest)
-                               return
-                       }
-                       if c <= 0 {
-                               buildJSONErrorResponse(w, fmt.Sprintf("Illegal number of events: %d", c), http.StatusBadRequest)
-                               return
-                       }
-                       count = uint64(c)
+
+       if countStr := r.URL.Query().Get("count"); countStr != "" {
+               c, err := strconv.ParseInt(countStr, 10, 64)
+               if err != nil {
+                       buildJSONErrorResponse(w, err.Error(), http.StatusBadRequest)
+                       return
+               }
+               if c <= 0 {
+                       buildJSONErrorResponse(w, fmt.Sprintf("Illegal number of events: %d", c), http.StatusBadRequest)
+                       return
                }
+               count = uint64(c)
+       }
 
-               if startStr := vars.ByName("start"); startStr != "" {
-                       i, err := strconv.ParseInt(startStr, 10, 64)
-                       if err != nil {
-                               buildJSONErrorResponse(w, err.Error(), http.StatusBadRequest)
-                               return
-                       }
-                       if i < 0 {
-                               buildJSONErrorResponse(w, fmt.Sprintf("Illegal id: %d", i), http.StatusBadRequest)
-                               return
-                       }
-                       start = uint64(i)
+       if startStr := r.URL.Query().Get("start"); startStr != "" {
+               i, err := strconv.ParseInt(startStr, 10, 64)
+               if err != nil {
+                       buildJSONErrorResponse(w, err.Error(), http.StatusBadRequest)
+                       return
+               }
+               if i < 0 {
+                       buildJSONErrorResponse(w, fmt.Sprintf("Illegal id: %d", i), http.StatusBadRequest)
+                       return
                }
+               start = uint64(i)
        }
 
        records, lowestID, highestID := eventSystem.GetEventsFromID(start, count)
diff --git a/pkg/webservice/handlers_test.go b/pkg/webservice/handlers_test.go
index 12e9fbf5..df8e23be 100644
--- a/pkg/webservice/handlers_test.go
+++ b/pkg/webservice/handlers_test.go
@@ -1337,18 +1337,34 @@ func TestGetEvents(t *testing.T) {
 
        checkAllEvents(t, []*si.EventRecord{appEvent, nodeEvent, queueEvent})
 
-       checkSingleEvent(t, appEvent, httprouter.Params{
-               httprouter.Param{Key: "count", Value: "1"},
-       })
-       checkSingleEvent(t, queueEvent, httprouter.Params{
-               httprouter.Param{Key: "start", Value: "2"},
-       })
+       checkSingleEvent(t, appEvent, "count=1")
+       checkSingleEvent(t, queueEvent, "start=2")
 
        // illegal requests
-       checkIllegalBatchRequest(t, "count", "xyz", "strconv.ParseInt: parsing \"xyz\": invalid syntax")
-       checkIllegalBatchRequest(t, "count", "-100", "Illegal number of events: -100")
-       checkIllegalBatchRequest(t, "start", "xyz", "strconv.ParseInt: parsing \"xyz\": invalid syntax")
-       checkIllegalBatchRequest(t, "start", "-100", "Illegal id: -100")
+       checkIllegalBatchRequest(t, "count=xyz", "strconv.ParseInt: parsing \"xyz\": invalid syntax")
+       checkIllegalBatchRequest(t, "count=-100", "Illegal number of events: -100")
+       checkIllegalBatchRequest(t, "start=xyz", "strconv.ParseInt: parsing \"xyz\": invalid syntax")
+       checkIllegalBatchRequest(t, "start=-100", "Illegal id: -100")
+}
+
+func TestGetEventsWhenTrackingDisabled(t *testing.T) {
+       original := configs.GetConfigMap()
+       defer func() {
+               ev := events.GetEventSystem().(*events.EventSystemImpl) //nolint:errcheck
+               ev.Stop()
+               configs.SetConfigMap(original)
+       }()
+       configMap := map[string]string{
+               configs.CMEventTrackingEnabled: "false",
+       }
+       configs.SetConfigMap(configMap)
+       events.Init()
+       ev := events.GetEventSystem().(*events.EventSystemImpl) //nolint:errcheck
+       ev.StartServiceWithPublisher(false)
+
+       req, err := http.NewRequest("GET", "/ws/v1/events/batch", strings.NewReader(""))
+       assert.NilError(t, err)
+       readIllegalRequest(t, req, "Event tracking is disabled")
 }
 
 func addEvents(t *testing.T) (appEvent, nodeEvent, queueEvent *si.EventRecord) {
@@ -1400,21 +1416,21 @@ func addEvents(t *testing.T) (appEvent, nodeEvent, queueEvent *si.EventRecord) {
        return appEvent, nodeEvent, queueEvent
 }
 
-func checkSingleEvent(t *testing.T, event *si.EventRecord, params httprouter.Params) {
-       req, err := http.NewRequest("GET", "/ws/v1/events/batch/", strings.NewReader(""))
+func checkSingleEvent(t *testing.T, event *si.EventRecord, query string) {
+       req, err := http.NewRequest("GET", "/ws/v1/events/batch?"+query, strings.NewReader(""))
        assert.NilError(t, err)
-       req = req.WithContext(context.WithValue(req.Context(), httprouter.ParamsKey, params))
        eventDao := getEventRecordDao(t, req)
        assert.Equal(t, 1, len(eventDao.EventRecords))
        compareEvents(t, event, eventDao.EventRecords[0])
 }
 
-func checkIllegalBatchRequest(t *testing.T, key, value, msg string) {
-       req, err := http.NewRequest("GET", "/ws/v1/events/batch/", strings.NewReader(""))
+func checkIllegalBatchRequest(t *testing.T, query, msg string) {
+       req, err := http.NewRequest("GET", "/ws/v1/events/batch?"+query, strings.NewReader(""))
        assert.NilError(t, err)
-       req = req.WithContext(context.WithValue(req.Context(), httprouter.ParamsKey, httprouter.Params{
-               httprouter.Param{Key: key, Value: value},
-       }))
+       readIllegalRequest(t, req, msg)
+}
+
+func readIllegalRequest(t *testing.T, req *http.Request, errMsg string) {
        rr := httptest.NewRecorder()
        handler := http.HandlerFunc(getEvents)
        handler.ServeHTTP(rr, req)
@@ -1425,7 +1441,7 @@ func checkIllegalBatchRequest(t *testing.T, key, value, msg string) {
        var errObject dao.YAPIError
        err = json.Unmarshal(jsonBytes[:n], &errObject)
        assert.NilError(t, err, "cannot unmarshal events dao")
-       assert.Assert(t, strings.Contains(errObject.Message, msg))
+       assert.Assert(t, strings.Contains(errObject.Message, errMsg))
 }
 
 func checkAllEvents(t *testing.T, events []*si.EventRecord) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to