This is an automated email from the ASF dual-hosted git repository.
ccondit pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/yunikorn-core.git
The following commit(s) were added to refs/heads/master by this push:
new ac32595a [YUNIKORN-2967] Cleanup REST response headers (#994)
ac32595a is described below
commit ac32595a9712fecfd5e983e0a822fb810f3b3f14
Author: Wilfred Spiegelenburg <[email protected]>
AuthorDate: Thu Nov 14 12:42:23 2024 -0600
[YUNIKORN-2967] Cleanup REST response headers (#994)
Only respond with the methods allowed for the request, not with a
general set of all allowed methods. OPTIONS is supported via the generic router config.
Add a test to make sure a change in the router does not break that.
Remove the Access-Control-Allow-Credentials header as recommended in the RFC.
We also do not use cookies or authentication, so it is not relevant to set.
Closes: #994
Signed-off-by: Craig Condit <[email protected]>
---
pkg/webservice/handlers.go | 65 ++++++++++++++++++++-------------------
pkg/webservice/handlers_test.go | 20 +++++++-----
pkg/webservice/state_dump.go | 2 +-
pkg/webservice/webservice_test.go | 49 +++++++++++++++++++++++++++--
4 files changed, 94 insertions(+), 42 deletions(-)
diff --git a/pkg/webservice/handlers.go b/pkg/webservice/handlers.go
index 93468516..73afdafc 100644
--- a/pkg/webservice/handlers.go
+++ b/pkg/webservice/handlers.go
@@ -127,7 +127,7 @@ func redirectDebug(w http.ResponseWriter, r *http.Request) {
}
func getStackInfo(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
var stack = func() []byte {
buf := make([]byte, 1024)
for {
@@ -145,7 +145,7 @@ func getStackInfo(w http.ResponseWriter, r *http.Request) {
}
func getClusterInfo(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
lists := schedulerContext.Load().GetPartitionMapClone()
clustersInfo := getClusterDAO(lists)
@@ -167,7 +167,7 @@ func validateQueue(queuePath string) error {
}
func validateConf(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
requestBytes, err := io.ReadAll(r.Body)
if err == nil {
_, err = configs.LoadSchedulerConfigFromByteArray(requestBytes)
@@ -184,11 +184,14 @@ func validateConf(w http.ResponseWriter, r *http.Request)
{
}
}
-func writeHeaders(w http.ResponseWriter) {
+func writeHeaders(w http.ResponseWriter, method string) {
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.Header().Set("Access-Control-Allow-Origin", "*")
- w.Header().Set("Access-Control-Allow-Credentials", "true")
- w.Header().Set("Access-Control-Allow-Methods", "GET,POST,HEAD,OPTIONS")
+ methods := "GET, OPTIONS"
+ if method == http.MethodPost {
+ methods = "OPTIONS, POST"
+ }
+ w.Header().Set("Access-Control-Allow-Methods", methods)
w.Header().Set("Access-Control-Allow-Headers",
"X-Requested-With,Content-Type,Accept,Origin")
}
@@ -233,7 +236,7 @@ func getClusterUtilJSON(partition
*scheduler.PartitionContext) []*dao.ClusterUti
}
utils = append(utils, utilization)
}
- } else if !getResource {
+ } else {
utilization := &dao.ClusterUtilDAOInfo{
ResourceType: "N/A",
Total: int64(-1),
@@ -446,7 +449,7 @@ func getNodesDAO(entries []*objects.Node)
[]*dao.NodeDAOInfo {
// Only check the default partition
// Deprecated - To be removed in next major release. Replaced with
getNodesUtilisations
func getNodeUtilisation(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
partitionContext :=
schedulerContext.Load().GetPartitionWithoutClusterID(configs.DefaultPartition)
if partitionContext == nil {
buildJSONErrorResponse(w, PartitionDoesNotExists,
http.StatusInternalServerError)
@@ -510,7 +513,7 @@ func getNodesUtilJSON(partition
*scheduler.PartitionContext, name string) *dao.N
}
func getNodeUtilisations(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
var result []*dao.PartitionNodesUtilDAOInfo
for _, part := range schedulerContext.Load().GetPartitionMapClone() {
result = append(result, getPartitionNodesUtilJSON(part))
@@ -583,7 +586,7 @@ func getPartitionNodesUtilJSON(partition
*scheduler.PartitionContext) *dao.Parti
}
func getApplicationHistory(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
// There is nothing to return but we did not really encounter a problem
if imHistory == nil {
@@ -600,7 +603,7 @@ func getApplicationHistory(w http.ResponseWriter, r
*http.Request) {
}
func getContainerHistory(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
// There is nothing to return but we did not really encounter a problem
if imHistory == nil {
@@ -617,7 +620,7 @@ func getContainerHistory(w http.ResponseWriter, r
*http.Request) {
}
func getClusterConfig(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
var marshalledConf []byte
var err error
@@ -653,7 +656,7 @@ func getClusterConfigDAO() *dao.ConfigDAOInfo {
}
func checkHealthStatus(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
// Fetch last healthCheck result
result := schedulerContext.Load().GetLastHealthCheckResult()
@@ -675,8 +678,8 @@ func checkHealthStatus(w http.ResponseWriter, r
*http.Request) {
}
}
-func getPartitions(w http.ResponseWriter, _ *http.Request) {
- writeHeaders(w)
+func getPartitions(w http.ResponseWriter, r *http.Request) {
+ writeHeaders(w, r.Method)
lists := schedulerContext.Load().GetPartitionMapClone()
partitionsInfo := getPartitionInfoDAO(lists)
@@ -686,7 +689,7 @@ func getPartitions(w http.ResponseWriter, _ *http.Request) {
}
func getPartitionQueues(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName,
http.StatusBadRequest)
@@ -707,7 +710,7 @@ func getPartitionQueues(w http.ResponseWriter, r
*http.Request) {
}
func getPartitionQueue(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName,
http.StatusBadRequest)
@@ -742,7 +745,7 @@ func getPartitionQueue(w http.ResponseWriter, r
*http.Request) {
}
func getPartitionNodes(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName,
http.StatusBadRequest)
@@ -761,7 +764,7 @@ func getPartitionNodes(w http.ResponseWriter, r
*http.Request) {
}
func getPartitionNode(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName,
http.StatusBadRequest)
@@ -786,7 +789,7 @@ func getPartitionNode(w http.ResponseWriter, r
*http.Request) {
}
func getQueueApplications(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName,
http.StatusBadRequest)
@@ -827,7 +830,7 @@ func getQueueApplications(w http.ResponseWriter, r
*http.Request) {
}
func getPartitionApplicationsByState(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName,
http.StatusBadRequest)
@@ -876,7 +879,7 @@ func getPartitionApplicationsByState(w http.ResponseWriter,
r *http.Request) {
}
func getApplication(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName,
http.StatusBadRequest)
@@ -924,7 +927,7 @@ func getApplication(w http.ResponseWriter, r *http.Request)
{
}
func getPartitionRules(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName,
http.StatusBadRequest)
@@ -943,7 +946,7 @@ func getPartitionRules(w http.ResponseWriter, r
*http.Request) {
}
func getQueueApplicationsByState(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName,
http.StatusBadRequest)
@@ -1180,8 +1183,8 @@ func getMetrics(w http.ResponseWriter, r *http.Request) {
promhttp.Handler().ServeHTTP(w, r)
}
-func getUsersResourceUsage(w http.ResponseWriter, _ *http.Request) {
- writeHeaders(w)
+func getUsersResourceUsage(w http.ResponseWriter, r *http.Request) {
+ writeHeaders(w, r.Method)
userManager := ugm.GetUserManager()
trackers := userManager.GetUserTrackers()
result := make([]*dao.UserResourceUsageDAOInfo, len(trackers))
@@ -1194,7 +1197,7 @@ func getUsersResourceUsage(w http.ResponseWriter, _
*http.Request) {
}
func getUserResourceUsage(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName,
http.StatusBadRequest)
@@ -1222,7 +1225,7 @@ func getUserResourceUsage(w http.ResponseWriter, r
*http.Request) {
}
func getGroupsResourceUsage(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
userManager := ugm.GetUserManager()
trackers := userManager.GetGroupTrackers()
result := make([]*dao.GroupResourceUsageDAOInfo, len(trackers))
@@ -1235,7 +1238,7 @@ func getGroupsResourceUsage(w http.ResponseWriter, r
*http.Request) {
}
func getGroupResourceUsage(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
vars := httprouter.ParamsFromContext(r.Context())
if vars == nil {
buildJSONErrorResponse(w, MissingParamsName,
http.StatusBadRequest)
@@ -1263,7 +1266,7 @@ func getGroupResourceUsage(w http.ResponseWriter, r
*http.Request) {
}
func getEvents(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
eventSystem := events.GetEventSystem()
if !eventSystem.IsEventTrackingEnabled() {
buildJSONErrorResponse(w, "Event tracking is disabled",
http.StatusInternalServerError)
@@ -1311,7 +1314,7 @@ func getEvents(w http.ResponseWriter, r *http.Request) {
}
func getStream(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
eventSystem := events.GetEventSystem()
if !eventSystem.IsEventTrackingEnabled() {
buildJSONErrorResponse(w, "Event tracking is disabled",
http.StatusInternalServerError)
diff --git a/pkg/webservice/handlers_test.go b/pkg/webservice/handlers_test.go
index f262140b..19cb1328 100644
--- a/pkg/webservice/handlers_test.go
+++ b/pkg/webservice/handlers_test.go
@@ -1285,16 +1285,18 @@ func TestGetPartitionQueueHandler(t *testing.T) {
func TestGetClusterInfo(t *testing.T) {
schedulerContext.Store(&scheduler.ClusterContext{})
resp := &MockResponseWriter{}
- getClusterInfo(resp, nil)
+ req, err := http.NewRequest("GET", "/ws/v1/clusters",
strings.NewReader(""))
+ assert.NilError(t, err, "error while creating http request")
+ getClusterInfo(resp, req)
var data []*dao.ClusterDAOInfo
- err := json.Unmarshal(resp.outputBytes, &data)
+ err = json.Unmarshal(resp.outputBytes, &data)
assert.NilError(t, err)
assert.Equal(t, 0, len(data))
setup(t, configTwoLevelQueues, 2)
resp = &MockResponseWriter{}
- getClusterInfo(resp, nil)
+ getClusterInfo(resp, req)
err = json.Unmarshal(resp.outputBytes, &data)
assert.NilError(t, err)
assert.Equal(t, 2, len(data))
@@ -1412,11 +1414,11 @@ func TestGetPartitionNode(t *testing.T) {
_, allocCreated, err := partition.UpdateAllocation(alloc1)
assert.NilError(t, err, "add alloc-1 should not have failed")
assert.Check(t, allocCreated)
- falloc1 := newForeignAlloc("foreign-1", "", node1ID, resAlloc1,
siCommon.AllocTypeDefault, 0)
+ falloc1 := newForeignAlloc("foreign-1", node1ID, resAlloc1,
siCommon.AllocTypeDefault, 0)
_, allocCreated, err = partition.UpdateAllocation(falloc1)
assert.NilError(t, err, "add falloc-1 should not have failed")
assert.Check(t, allocCreated)
- falloc2 := newForeignAlloc("foreign-2", "", node1ID, resAlloc2,
siCommon.AllocTypeStatic, 123)
+ falloc2 := newForeignAlloc("foreign-2", node1ID, resAlloc2,
siCommon.AllocTypeStatic, 123)
_, allocCreated, err = partition.UpdateAllocation(falloc2)
assert.NilError(t, err, "add falloc-2 should not have failed")
assert.Check(t, allocCreated)
@@ -1746,6 +1748,7 @@ func checkGetQueueAppByState(t *testing.T, partition,
queue, state, status strin
url =
fmt.Sprintf("/ws/v1/partition/%s/queue/%s/applications/%s?status=%s",
partition, queue, state, status)
}
req, err := http.NewRequest("GET", url, strings.NewReader(""))
+ assert.NilError(t, err, "unexpected error creating request")
req = req.WithContext(context.WithValue(req.Context(),
httprouter.ParamsKey, httprouter.Params{
httprouter.Param{Key: "partition", Value: partition},
httprouter.Param{Key: "queue", Value: queue},
@@ -1780,6 +1783,7 @@ func checkGetQueueAppByIllegalStateOrStatus(t *testing.T,
partition, queue, stat
url =
fmt.Sprintf("/ws/v1/partition/%s/queue/%s/applications/%s?status=%s",
partition, queue, state, status)
}
req, err := http.NewRequest("GET", url, strings.NewReader(""))
+ assert.NilError(t, err, "unexpected error creating request")
req = req.WithContext(context.WithValue(req.Context(),
httprouter.ParamsKey, httprouter.Params{
httprouter.Param{Key: "partition", Value: partition},
httprouter.Param{Key: "queue", Value: queue},
@@ -2115,9 +2119,9 @@ func TestFullStateDumpPath(t *testing.T) {
prepareSchedulerContext(t)
partitionContext := schedulerContext.Load().GetPartitionMapClone()
- context := partitionContext[normalizedPartitionName]
+ ctx := partitionContext[normalizedPartitionName]
app := newApplication("appID", normalizedPartitionName, "root.default",
rmID, security.UserGroup{})
- err := context.AddApplication(app)
+ err := ctx.AddApplication(app)
assert.NilError(t, err, "failed to add Application to partition")
imHistory = history.NewInternalMetricsHistory(5)
@@ -3053,7 +3057,7 @@ func newAlloc(allocationKey string, appID string, nodeID
string, resAlloc *resou
})
}
-func newForeignAlloc(allocationKey string, appID string, nodeID string,
resAlloc *resources.Resource, fType string, priority int32) *objects.Allocation
{
+func newForeignAlloc(allocationKey string, nodeID string, resAlloc
*resources.Resource, fType string, priority int32) *objects.Allocation {
return objects.NewAllocationFromSI(&si.Allocation{
AllocationKey: allocationKey,
NodeID: nodeID,
diff --git a/pkg/webservice/state_dump.go b/pkg/webservice/state_dump.go
index 5cb7efc0..a37ef2cb 100644
--- a/pkg/webservice/state_dump.go
+++ b/pkg/webservice/state_dump.go
@@ -54,7 +54,7 @@ type AggregatedStateInfo struct {
}
func getFullStateDump(w http.ResponseWriter, r *http.Request) {
- writeHeaders(w)
+ writeHeaders(w, r.Method)
if err := doStateDump(w); err != nil {
buildJSONErrorResponse(w, err.Error(),
http.StatusInternalServerError)
}
diff --git a/pkg/webservice/webservice_test.go
b/pkg/webservice/webservice_test.go
index 1b723f4a..5c8cb66b 100644
--- a/pkg/webservice/webservice_test.go
+++ b/pkg/webservice/webservice_test.go
@@ -30,6 +30,8 @@ import (
"github.com/apache/yunikorn-core/pkg/scheduler"
)
+const base = "http://localhost:9080"
+
func Test_RedirectDebugHandler(t *testing.T) {
defer ResetIMHistory()
s := NewWebApp(&scheduler.ClusterContext{},
history.NewInternalMetricsHistory(5))
@@ -40,7 +42,6 @@ func Test_RedirectDebugHandler(t *testing.T) {
t.Fatal("failed to stop webapp")
}
}(s)
- base := "http://localhost:9080"
tests := []struct {
name string
reqURL string
@@ -76,7 +77,6 @@ func Test_RouterHandling(t *testing.T) {
t.Fatal("failed to stop webapp")
}
}(s)
- base := "http://localhost:9080"
client := &http.Client{}
// unsupported POST
resp, err := client.Post(base+"/ws/v1/clusters", "application/json;
charset=UTF-8", nil)
@@ -105,3 +105,48 @@ func Test_RouterHandling(t *testing.T) {
_ = resp.Body.Close()
assert.Equal(t, resp.StatusCode, http.StatusOK, "expected OK")
}
+
+func Test_HeaderChecks(t *testing.T) {
+ s := NewWebApp(&scheduler.ClusterContext{}, nil)
+ s.StartWebApp()
+ defer func(s *WebService) {
+ err := s.StopWebApp()
+ if err != nil {
+ t.Fatal("failed to stop webapp")
+ }
+ }(s)
+ client := http.DefaultClient
+ tests := []struct {
+ name string
+ reqURL string
+ method string
+ expected string
+ }{
+ {"get options", "/ws/v1/clusters", http.MethodOptions, "GET,
OPTIONS"},
+ {"get", "/ws/v1/clusters", http.MethodGet, "GET, OPTIONS"},
+ {"post options", "/ws/v1/validate-conf", http.MethodOptions,
"OPTIONS, POST"},
+ {"post", "/ws/v1/validate-conf", http.MethodPost, "OPTIONS,
POST"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ req, err := http.NewRequest(tt.method, base+tt.reqURL,
nil)
+ assert.NilError(t, err, "unexpected error creating
request")
+ var resp *http.Response
+ resp, err = client.Do(req)
+ assert.NilError(t, err, "unexpected error executing
request")
+ assert.Equal(t, resp.StatusCode, http.StatusOK,
"expected OK")
+ switch tt.method {
+ case http.MethodGet, http.MethodPost:
+ assert.Equal(t,
resp.Header.Get("Access-Control-Allow-Methods"), tt.expected, "wrong methods
returned")
+ case http.MethodOptions:
+ // OPTIONS requests are handled by default via
httprouter, not defined in the routes
+ assert.Equal(t, resp.Header.Get("Allow"),
tt.expected, "expected only get and options to be returned")
+ }
+ var body []byte
+ body, err = io.ReadAll(resp.Body)
+ _ = resp.Body.Close()
+ assert.NilError(t, err, "unexpected error reading body")
+ assert.Assert(t, body != nil, "expected body with
status text")
+ })
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]