This is an automated email from the ASF dual-hosted git repository.

ccondit pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/yunikorn-core.git


The following commit(s) were added to refs/heads/master by this push:
     new ac32595a [YUNIKORN-2967] Cleanup REST response headers (#994)
ac32595a is described below

commit ac32595a9712fecfd5e983e0a822fb810f3b3f14
Author: Wilfred Spiegelenburg <[email protected]>
AuthorDate: Thu Nov 14 12:42:23 2024 -0600

    [YUNIKORN-2967] Cleanup REST response headers (#994)
    
    Only respond with the methods allowed for the request, not with a
    general set of all allowed methods. OPTIONS is supported via the generic
    config. Add a test to make sure a change in the router does not break that.
    
    Remove the Access-Control-Allow-Credentials header as recommended in the
    RFC. We also do not use cookies or authentication, so it is not relevant
    to set.
    
    Closes: #994
    
    Signed-off-by: Craig Condit <[email protected]>
---
 pkg/webservice/handlers.go        | 65 ++++++++++++++++++++-------------------
 pkg/webservice/handlers_test.go   | 20 +++++++-----
 pkg/webservice/state_dump.go      |  2 +-
 pkg/webservice/webservice_test.go | 49 +++++++++++++++++++++++++++--
 4 files changed, 94 insertions(+), 42 deletions(-)

diff --git a/pkg/webservice/handlers.go b/pkg/webservice/handlers.go
index 93468516..73afdafc 100644
--- a/pkg/webservice/handlers.go
+++ b/pkg/webservice/handlers.go
@@ -127,7 +127,7 @@ func redirectDebug(w http.ResponseWriter, r *http.Request) {
 }
 
 func getStackInfo(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        var stack = func() []byte {
                buf := make([]byte, 1024)
                for {
@@ -145,7 +145,7 @@ func getStackInfo(w http.ResponseWriter, r *http.Request) {
 }
 
 func getClusterInfo(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
 
        lists := schedulerContext.Load().GetPartitionMapClone()
        clustersInfo := getClusterDAO(lists)
@@ -167,7 +167,7 @@ func validateQueue(queuePath string) error {
 }
 
 func validateConf(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        requestBytes, err := io.ReadAll(r.Body)
        if err == nil {
                _, err = configs.LoadSchedulerConfigFromByteArray(requestBytes)
@@ -184,11 +184,14 @@ func validateConf(w http.ResponseWriter, r *http.Request) 
{
        }
 }
 
-func writeHeaders(w http.ResponseWriter) {
+func writeHeaders(w http.ResponseWriter, method string) {
        w.Header().Set("Content-Type", "application/json; charset=UTF-8")
        w.Header().Set("Access-Control-Allow-Origin", "*")
-       w.Header().Set("Access-Control-Allow-Credentials", "true")
-       w.Header().Set("Access-Control-Allow-Methods", "GET,POST,HEAD,OPTIONS")
+       methods := "GET, OPTIONS"
+       if method == http.MethodPost {
+               methods = "OPTIONS, POST"
+       }
+       w.Header().Set("Access-Control-Allow-Methods", methods)
        w.Header().Set("Access-Control-Allow-Headers", 
"X-Requested-With,Content-Type,Accept,Origin")
 }
 
@@ -233,7 +236,7 @@ func getClusterUtilJSON(partition 
*scheduler.PartitionContext) []*dao.ClusterUti
                        }
                        utils = append(utils, utilization)
                }
-       } else if !getResource {
+       } else {
                utilization := &dao.ClusterUtilDAOInfo{
                        ResourceType: "N/A",
                        Total:        int64(-1),
@@ -446,7 +449,7 @@ func getNodesDAO(entries []*objects.Node) 
[]*dao.NodeDAOInfo {
 // Only check the default partition
 // Deprecated - To be removed in next major release. Replaced with 
getNodesUtilisations
 func getNodeUtilisation(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        partitionContext := 
schedulerContext.Load().GetPartitionWithoutClusterID(configs.DefaultPartition)
        if partitionContext == nil {
                buildJSONErrorResponse(w, PartitionDoesNotExists, 
http.StatusInternalServerError)
@@ -510,7 +513,7 @@ func getNodesUtilJSON(partition 
*scheduler.PartitionContext, name string) *dao.N
 }
 
 func getNodeUtilisations(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        var result []*dao.PartitionNodesUtilDAOInfo
        for _, part := range schedulerContext.Load().GetPartitionMapClone() {
                result = append(result, getPartitionNodesUtilJSON(part))
@@ -583,7 +586,7 @@ func getPartitionNodesUtilJSON(partition 
*scheduler.PartitionContext) *dao.Parti
 }
 
 func getApplicationHistory(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
 
        // There is nothing to return but we did not really encounter a problem
        if imHistory == nil {
@@ -600,7 +603,7 @@ func getApplicationHistory(w http.ResponseWriter, r 
*http.Request) {
 }
 
 func getContainerHistory(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
 
        // There is nothing to return but we did not really encounter a problem
        if imHistory == nil {
@@ -617,7 +620,7 @@ func getContainerHistory(w http.ResponseWriter, r 
*http.Request) {
 }
 
 func getClusterConfig(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
 
        var marshalledConf []byte
        var err error
@@ -653,7 +656,7 @@ func getClusterConfigDAO() *dao.ConfigDAOInfo {
 }
 
 func checkHealthStatus(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
 
        // Fetch last healthCheck result
        result := schedulerContext.Load().GetLastHealthCheckResult()
@@ -675,8 +678,8 @@ func checkHealthStatus(w http.ResponseWriter, r 
*http.Request) {
        }
 }
 
-func getPartitions(w http.ResponseWriter, _ *http.Request) {
-       writeHeaders(w)
+func getPartitions(w http.ResponseWriter, r *http.Request) {
+       writeHeaders(w, r.Method)
 
        lists := schedulerContext.Load().GetPartitionMapClone()
        partitionsInfo := getPartitionInfoDAO(lists)
@@ -686,7 +689,7 @@ func getPartitions(w http.ResponseWriter, _ *http.Request) {
 }
 
 func getPartitionQueues(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        vars := httprouter.ParamsFromContext(r.Context())
        if vars == nil {
                buildJSONErrorResponse(w, MissingParamsName, 
http.StatusBadRequest)
@@ -707,7 +710,7 @@ func getPartitionQueues(w http.ResponseWriter, r 
*http.Request) {
 }
 
 func getPartitionQueue(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        vars := httprouter.ParamsFromContext(r.Context())
        if vars == nil {
                buildJSONErrorResponse(w, MissingParamsName, 
http.StatusBadRequest)
@@ -742,7 +745,7 @@ func getPartitionQueue(w http.ResponseWriter, r 
*http.Request) {
 }
 
 func getPartitionNodes(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        vars := httprouter.ParamsFromContext(r.Context())
        if vars == nil {
                buildJSONErrorResponse(w, MissingParamsName, 
http.StatusBadRequest)
@@ -761,7 +764,7 @@ func getPartitionNodes(w http.ResponseWriter, r 
*http.Request) {
 }
 
 func getPartitionNode(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        vars := httprouter.ParamsFromContext(r.Context())
        if vars == nil {
                buildJSONErrorResponse(w, MissingParamsName, 
http.StatusBadRequest)
@@ -786,7 +789,7 @@ func getPartitionNode(w http.ResponseWriter, r 
*http.Request) {
 }
 
 func getQueueApplications(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        vars := httprouter.ParamsFromContext(r.Context())
        if vars == nil {
                buildJSONErrorResponse(w, MissingParamsName, 
http.StatusBadRequest)
@@ -827,7 +830,7 @@ func getQueueApplications(w http.ResponseWriter, r 
*http.Request) {
 }
 
 func getPartitionApplicationsByState(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        vars := httprouter.ParamsFromContext(r.Context())
        if vars == nil {
                buildJSONErrorResponse(w, MissingParamsName, 
http.StatusBadRequest)
@@ -876,7 +879,7 @@ func getPartitionApplicationsByState(w http.ResponseWriter, 
r *http.Request) {
 }
 
 func getApplication(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        vars := httprouter.ParamsFromContext(r.Context())
        if vars == nil {
                buildJSONErrorResponse(w, MissingParamsName, 
http.StatusBadRequest)
@@ -924,7 +927,7 @@ func getApplication(w http.ResponseWriter, r *http.Request) 
{
 }
 
 func getPartitionRules(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        vars := httprouter.ParamsFromContext(r.Context())
        if vars == nil {
                buildJSONErrorResponse(w, MissingParamsName, 
http.StatusBadRequest)
@@ -943,7 +946,7 @@ func getPartitionRules(w http.ResponseWriter, r 
*http.Request) {
 }
 
 func getQueueApplicationsByState(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        vars := httprouter.ParamsFromContext(r.Context())
        if vars == nil {
                buildJSONErrorResponse(w, MissingParamsName, 
http.StatusBadRequest)
@@ -1180,8 +1183,8 @@ func getMetrics(w http.ResponseWriter, r *http.Request) {
        promhttp.Handler().ServeHTTP(w, r)
 }
 
-func getUsersResourceUsage(w http.ResponseWriter, _ *http.Request) {
-       writeHeaders(w)
+func getUsersResourceUsage(w http.ResponseWriter, r *http.Request) {
+       writeHeaders(w, r.Method)
        userManager := ugm.GetUserManager()
        trackers := userManager.GetUserTrackers()
        result := make([]*dao.UserResourceUsageDAOInfo, len(trackers))
@@ -1194,7 +1197,7 @@ func getUsersResourceUsage(w http.ResponseWriter, _ 
*http.Request) {
 }
 
 func getUserResourceUsage(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        vars := httprouter.ParamsFromContext(r.Context())
        if vars == nil {
                buildJSONErrorResponse(w, MissingParamsName, 
http.StatusBadRequest)
@@ -1222,7 +1225,7 @@ func getUserResourceUsage(w http.ResponseWriter, r 
*http.Request) {
 }
 
 func getGroupsResourceUsage(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        userManager := ugm.GetUserManager()
        trackers := userManager.GetGroupTrackers()
        result := make([]*dao.GroupResourceUsageDAOInfo, len(trackers))
@@ -1235,7 +1238,7 @@ func getGroupsResourceUsage(w http.ResponseWriter, r 
*http.Request) {
 }
 
 func getGroupResourceUsage(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        vars := httprouter.ParamsFromContext(r.Context())
        if vars == nil {
                buildJSONErrorResponse(w, MissingParamsName, 
http.StatusBadRequest)
@@ -1263,7 +1266,7 @@ func getGroupResourceUsage(w http.ResponseWriter, r 
*http.Request) {
 }
 
 func getEvents(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        eventSystem := events.GetEventSystem()
        if !eventSystem.IsEventTrackingEnabled() {
                buildJSONErrorResponse(w, "Event tracking is disabled", 
http.StatusInternalServerError)
@@ -1311,7 +1314,7 @@ func getEvents(w http.ResponseWriter, r *http.Request) {
 }
 
 func getStream(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        eventSystem := events.GetEventSystem()
        if !eventSystem.IsEventTrackingEnabled() {
                buildJSONErrorResponse(w, "Event tracking is disabled", 
http.StatusInternalServerError)
diff --git a/pkg/webservice/handlers_test.go b/pkg/webservice/handlers_test.go
index f262140b..19cb1328 100644
--- a/pkg/webservice/handlers_test.go
+++ b/pkg/webservice/handlers_test.go
@@ -1285,16 +1285,18 @@ func TestGetPartitionQueueHandler(t *testing.T) {
 func TestGetClusterInfo(t *testing.T) {
        schedulerContext.Store(&scheduler.ClusterContext{})
        resp := &MockResponseWriter{}
-       getClusterInfo(resp, nil)
+       req, err := http.NewRequest("GET", "/ws/v1/clusters", 
strings.NewReader(""))
+       assert.NilError(t, err, "error while creating http request")
+       getClusterInfo(resp, req)
        var data []*dao.ClusterDAOInfo
-       err := json.Unmarshal(resp.outputBytes, &data)
+       err = json.Unmarshal(resp.outputBytes, &data)
        assert.NilError(t, err)
        assert.Equal(t, 0, len(data))
 
        setup(t, configTwoLevelQueues, 2)
 
        resp = &MockResponseWriter{}
-       getClusterInfo(resp, nil)
+       getClusterInfo(resp, req)
        err = json.Unmarshal(resp.outputBytes, &data)
        assert.NilError(t, err)
        assert.Equal(t, 2, len(data))
@@ -1412,11 +1414,11 @@ func TestGetPartitionNode(t *testing.T) {
        _, allocCreated, err := partition.UpdateAllocation(alloc1)
        assert.NilError(t, err, "add alloc-1 should not have failed")
        assert.Check(t, allocCreated)
-       falloc1 := newForeignAlloc("foreign-1", "", node1ID, resAlloc1, 
siCommon.AllocTypeDefault, 0)
+       falloc1 := newForeignAlloc("foreign-1", node1ID, resAlloc1, 
siCommon.AllocTypeDefault, 0)
        _, allocCreated, err = partition.UpdateAllocation(falloc1)
        assert.NilError(t, err, "add falloc-1 should not have failed")
        assert.Check(t, allocCreated)
-       falloc2 := newForeignAlloc("foreign-2", "", node1ID, resAlloc2, 
siCommon.AllocTypeStatic, 123)
+       falloc2 := newForeignAlloc("foreign-2", node1ID, resAlloc2, 
siCommon.AllocTypeStatic, 123)
        _, allocCreated, err = partition.UpdateAllocation(falloc2)
        assert.NilError(t, err, "add falloc-2 should not have failed")
        assert.Check(t, allocCreated)
@@ -1746,6 +1748,7 @@ func checkGetQueueAppByState(t *testing.T, partition, 
queue, state, status strin
                url = 
fmt.Sprintf("/ws/v1/partition/%s/queue/%s/applications/%s?status=%s", 
partition, queue, state, status)
        }
        req, err := http.NewRequest("GET", url, strings.NewReader(""))
+       assert.NilError(t, err, "unexpected error creating request")
        req = req.WithContext(context.WithValue(req.Context(), 
httprouter.ParamsKey, httprouter.Params{
                httprouter.Param{Key: "partition", Value: partition},
                httprouter.Param{Key: "queue", Value: queue},
@@ -1780,6 +1783,7 @@ func checkGetQueueAppByIllegalStateOrStatus(t *testing.T, 
partition, queue, stat
                url = 
fmt.Sprintf("/ws/v1/partition/%s/queue/%s/applications/%s?status=%s", 
partition, queue, state, status)
        }
        req, err := http.NewRequest("GET", url, strings.NewReader(""))
+       assert.NilError(t, err, "unexpected error creating request")
        req = req.WithContext(context.WithValue(req.Context(), 
httprouter.ParamsKey, httprouter.Params{
                httprouter.Param{Key: "partition", Value: partition},
                httprouter.Param{Key: "queue", Value: queue},
@@ -2115,9 +2119,9 @@ func TestFullStateDumpPath(t *testing.T) {
        prepareSchedulerContext(t)
 
        partitionContext := schedulerContext.Load().GetPartitionMapClone()
-       context := partitionContext[normalizedPartitionName]
+       ctx := partitionContext[normalizedPartitionName]
        app := newApplication("appID", normalizedPartitionName, "root.default", 
rmID, security.UserGroup{})
-       err := context.AddApplication(app)
+       err := ctx.AddApplication(app)
        assert.NilError(t, err, "failed to add Application to partition")
 
        imHistory = history.NewInternalMetricsHistory(5)
@@ -3053,7 +3057,7 @@ func newAlloc(allocationKey string, appID string, nodeID 
string, resAlloc *resou
        })
 }
 
-func newForeignAlloc(allocationKey string, appID string, nodeID string, 
resAlloc *resources.Resource, fType string, priority int32) *objects.Allocation 
{
+func newForeignAlloc(allocationKey string, nodeID string, resAlloc 
*resources.Resource, fType string, priority int32) *objects.Allocation {
        return objects.NewAllocationFromSI(&si.Allocation{
                AllocationKey:    allocationKey,
                NodeID:           nodeID,
diff --git a/pkg/webservice/state_dump.go b/pkg/webservice/state_dump.go
index 5cb7efc0..a37ef2cb 100644
--- a/pkg/webservice/state_dump.go
+++ b/pkg/webservice/state_dump.go
@@ -54,7 +54,7 @@ type AggregatedStateInfo struct {
 }
 
 func getFullStateDump(w http.ResponseWriter, r *http.Request) {
-       writeHeaders(w)
+       writeHeaders(w, r.Method)
        if err := doStateDump(w); err != nil {
                buildJSONErrorResponse(w, err.Error(), 
http.StatusInternalServerError)
        }
diff --git a/pkg/webservice/webservice_test.go 
b/pkg/webservice/webservice_test.go
index 1b723f4a..5c8cb66b 100644
--- a/pkg/webservice/webservice_test.go
+++ b/pkg/webservice/webservice_test.go
@@ -30,6 +30,8 @@ import (
        "github.com/apache/yunikorn-core/pkg/scheduler"
 )
 
+const base = "http://localhost:9080";
+
 func Test_RedirectDebugHandler(t *testing.T) {
        defer ResetIMHistory()
        s := NewWebApp(&scheduler.ClusterContext{}, 
history.NewInternalMetricsHistory(5))
@@ -40,7 +42,6 @@ func Test_RedirectDebugHandler(t *testing.T) {
                        t.Fatal("failed to stop webapp")
                }
        }(s)
-       base := "http://localhost:9080";
        tests := []struct {
                name     string
                reqURL   string
@@ -76,7 +77,6 @@ func Test_RouterHandling(t *testing.T) {
                        t.Fatal("failed to stop webapp")
                }
        }(s)
-       base := "http://localhost:9080";
        client := &http.Client{}
        // unsupported POST
        resp, err := client.Post(base+"/ws/v1/clusters", "application/json; 
charset=UTF-8", nil)
@@ -105,3 +105,48 @@ func Test_RouterHandling(t *testing.T) {
        _ = resp.Body.Close()
        assert.Equal(t, resp.StatusCode, http.StatusOK, "expected OK")
 }
+
+func Test_HeaderChecks(t *testing.T) {
+       s := NewWebApp(&scheduler.ClusterContext{}, nil)
+       s.StartWebApp()
+       defer func(s *WebService) {
+               err := s.StopWebApp()
+               if err != nil {
+                       t.Fatal("failed to stop webapp")
+               }
+       }(s)
+       client := http.DefaultClient
+       tests := []struct {
+               name     string
+               reqURL   string
+               method   string
+               expected string
+       }{
+               {"get options", "/ws/v1/clusters", http.MethodOptions, "GET, 
OPTIONS"},
+               {"get", "/ws/v1/clusters", http.MethodGet, "GET, OPTIONS"},
+               {"post options", "/ws/v1/validate-conf", http.MethodOptions, 
"OPTIONS, POST"},
+               {"post", "/ws/v1/validate-conf", http.MethodPost, "OPTIONS, 
POST"},
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       req, err := http.NewRequest(tt.method, base+tt.reqURL, 
nil)
+                       assert.NilError(t, err, "unexpected error creating 
request")
+                       var resp *http.Response
+                       resp, err = client.Do(req)
+                       assert.NilError(t, err, "unexpected error executing 
request")
+                       assert.Equal(t, resp.StatusCode, http.StatusOK, 
"expected OK")
+                       switch tt.method {
+                       case http.MethodGet, http.MethodPost:
+                               assert.Equal(t, 
resp.Header.Get("Access-Control-Allow-Methods"), tt.expected, "wrong methods 
returned")
+                       case http.MethodOptions:
+                               // OPTIONS requests are handled by default via 
httprouter, not defined in the routes
+                               assert.Equal(t, resp.Header.Get("Allow"), 
tt.expected, "expected only get and options to be returned")
+                       }
+                       var body []byte
+                       body, err = io.ReadAll(resp.Body)
+                       _ = resp.Body.Close()
+                       assert.NilError(t, err, "unexpected error reading body")
+                       assert.Assert(t, body != nil, "expected body with 
status text")
+               })
+       }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to