This is an automated email from the ASF dual-hosted git repository.
xianjin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new cf9afab4 [#545][operator] feat: support setting runtime class name and
env for rss (#548)
cf9afab4 is described below
commit cf9afab471a9595d8298ebd65c4d25c080ad4bbf
Author: jasonawang <[email protected]>
AuthorDate: Mon Feb 6 10:17:12 2023 +0800
[#545][operator] feat: support setting runtime class name and env for rss
(#548)
### What changes were proposed in this pull request?
Support setting a custom runtime class name and environment variables for RSS objects.
### Why are the changes needed?
More flexibility when configuring coordinator and shuffle server pods.
Fix issue #545.
### Does this PR introduce _any_ user-facing change?
Yes. RSS cluster admins can set a custom runtime class name and environment
variables for the pods.
### How was this patch tested?
Verified manually and with newly added unit tests.
---
.../uniffle/v1alpha1/remoteshuffleservice_types.go | 17 +-
.../api/uniffle/v1alpha1/zz_generated.deepcopy.go | 19 +-
.../uniffle.apache.org_remoteshuffleservices.yaml | 16 ++
.../pkg/controller/sync/coordinator/coordinator.go | 9 +-
.../sync/coordinator/coordinator_test.go | 152 ++++++++++++++--
.../controller/sync/shuffleserver/shuffleserver.go | 27 ++-
.../sync/shuffleserver/shuffleserver_test.go | 158 +++++++++++++++--
.../operator/pkg/webhook/inspector/rss.go | 31 ++++
.../operator/pkg/webhook/inspector/rss_test.go | 196 ++++++++++++++-------
9 files changed, 520 insertions(+), 105 deletions(-)
diff --git
a/deploy/kubernetes/operator/api/uniffle/v1alpha1/remoteshuffleservice_types.go
b/deploy/kubernetes/operator/api/uniffle/v1alpha1/remoteshuffleservice_types.go
index e5025665..041dec9b 100644
---
a/deploy/kubernetes/operator/api/uniffle/v1alpha1/remoteshuffleservice_types.go
+++
b/deploy/kubernetes/operator/api/uniffle/v1alpha1/remoteshuffleservice_types.go
@@ -204,6 +204,18 @@ type RSSPodSpec struct {
// Labels represents labels to be added in coordinators or shuffle
servers' pods.
// +optional
Labels map[string]string `json:"labels,omitempty"`
+
+ // RuntimeClassName refers to a RuntimeClass object in the node.k8s.io
group, which should be used
+ // to run this pod. If no RuntimeClass resource matches the named
class, the pod will not be run.
+ // If unset or empty, the "legacy" RuntimeClass will be used, which is
an implicit class with an
+ // empty definition that uses the default runtime handler.
+ // +optional
+ RuntimeClassName *string `json:"runtimeClassName,omitempty"`
+
+ // NodeSelector is a selector which must be true for the pod to fit on
a node.
+ // Selector which must match a node's labels for the pod to be
scheduled on that node.
+ // +optional
+ NodeSelector map[string]string `json:"nodeSelector,omitempty"`
}
// MainContainer stores information of the main container of coordinators or
shuffle servers,
@@ -235,11 +247,6 @@ type MainContainer struct {
// VolumeMounts indicates describes mountings of volumes within shuffle
servers' container.
// +optional
VolumeMounts []corev1.VolumeMount `json:"volumeMounts,omitempty"`
-
- // NodeSelector is a selector which must be true for the pod to fit on
a node.
- // Selector which must match a node's labels for the pod to be
scheduled on that node.
- // +optional
- NodeSelector map[string]string `json:"nodeSelector,omitempty"`
}
// RemoteShuffleServiceStatus defines the observed state of
RemoteShuffleService
diff --git
a/deploy/kubernetes/operator/api/uniffle/v1alpha1/zz_generated.deepcopy.go
b/deploy/kubernetes/operator/api/uniffle/v1alpha1/zz_generated.deepcopy.go
index 4704e9bb..931cd6d9 100644
--- a/deploy/kubernetes/operator/api/uniffle/v1alpha1/zz_generated.deepcopy.go
+++ b/deploy/kubernetes/operator/api/uniffle/v1alpha1/zz_generated.deepcopy.go
@@ -142,13 +142,6 @@ func (in *MainContainer) DeepCopyInto(out *MainContainer) {
(*in)[i].DeepCopyInto(&(*out)[i])
}
}
- if in.NodeSelector != nil {
- in, out := &in.NodeSelector, &out.NodeSelector
- *out = make(map[string]string, len(*in))
- for key, val := range *in {
- (*out)[key] = val
- }
- }
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver,
creating a new MainContainer.
@@ -207,6 +200,18 @@ func (in *RSSPodSpec) DeepCopyInto(out *RSSPodSpec) {
(*out)[key] = val
}
}
+ if in.RuntimeClassName != nil {
+ in, out := &in.RuntimeClassName, &out.RuntimeClassName
+ *out = new(string)
+ **out = **in
+ }
+ if in.NodeSelector != nil {
+ in, out := &in.NodeSelector, &out.NodeSelector
+ *out = make(map[string]string, len(*in))
+ for key, val := range *in {
+ (*out)[key] = val
+ }
+ }
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver,
creating a new RSSPodSpec.
diff --git
a/deploy/kubernetes/operator/config/crd/bases/uniffle.apache.org_remoteshuffleservices.yaml
b/deploy/kubernetes/operator/config/crd/bases/uniffle.apache.org_remoteshuffleservices.yaml
index d45c2de1..ba86981d 100644
---
a/deploy/kubernetes/operator/config/crd/bases/uniffle.apache.org_remoteshuffleservices.yaml
+++
b/deploy/kubernetes/operator/config/crd/bases/uniffle.apache.org_remoteshuffleservices.yaml
@@ -344,6 +344,14 @@ spec:
description: RPCPort defines rpc port used by coordinators.
format: int32
type: integer
+ runtimeClassName:
+ description: RuntimeClassName refers to a RuntimeClass
object
+ in the node.k8s.io group, which should be used to run
this pod. If
+ no RuntimeClass resource matches the named class, the
pod will
+ not be run. If unset or empty, the "legacy" RuntimeClass
will
+ be used, which is an implicit class with an empty
definition
+ that uses the default runtime handler.
+ type: string
securityContext:
description: SecurityContext holds pod-level security
attributes
and common container settings.
@@ -3537,6 +3545,14 @@ spec:
description: RPCPort defines rpc port used by shuffle
servers.
format: int32
type: integer
+ runtimeClassName:
+ description: RuntimeClassName refers to a RuntimeClass
object
+ in the node.k8s.io group, which should be used to run
this pod. If
+ no RuntimeClass resource matches the named class, the
pod will
+ not be run. If unset or empty, the "legacy" RuntimeClass
will
+ be used, which is an implicit class with an empty
definition
+ that uses the default runtime handler.
+ type: string
securityContext:
description: SecurityContext holds pod-level security
attributes
and common container settings.
diff --git
a/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator.go
b/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator.go
index 45140a47..47e2d3ac 100644
--- a/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator.go
+++ b/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator.go
@@ -43,7 +43,9 @@ func init() {
defaultENVs.Insert(controllerconstants.CoordinatorRPCPortEnv,
controllerconstants.CoordinatorHTTPPortEnv,
controllerconstants.XmxSizeEnv,
- controllerconstants.ServiceNameEnv)
+ controllerconstants.ServiceNameEnv,
+ controllerconstants.NodeNameEnv,
+ controllerconstants.RssIPEnv)
}
// GenerateCoordinators generates objects related to coordinators
@@ -189,6 +191,11 @@ func GenerateDeploy(rss
*unifflev1alpha1.RemoteShuffleService, index int) *appsv
deploy.Spec.Template.Labels[k] = v
}
+ // set runtimeClassName
+ if rss.Spec.Coordinator.RuntimeClassName != nil {
+ deploy.Spec.Template.Spec.RuntimeClassName =
rss.Spec.Coordinator.RuntimeClassName
+ }
+
// add init containers, the main container and other containers.
deploy.Spec.Template.Spec.InitContainers =
util.GenerateInitContainers(rss.Spec.Coordinator.RSSPodSpec)
containers := []corev1.Container{*generateMainContainer(rss)}
diff --git
a/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator_test.go
b/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator_test.go
index 89ee7220..6c16eebc 100644
---
a/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator_test.go
+++
b/deploy/kubernetes/operator/pkg/controller/sync/coordinator/coordinator_test.go
@@ -18,48 +18,91 @@
package coordinator
import (
+ "encoding/json"
"fmt"
+ "reflect"
+ "strconv"
"testing"
appsv1 "k8s.io/api/apps/v1"
+ corev1 "k8s.io/api/core/v1"
+ "k8s.io/apimachinery/pkg/util/sets"
+ "k8s.io/utils/pointer"
- unifflev1alpha1
"github.com/apache/incubator-uniffle/deploy/kubernetes/operator/api/uniffle/v1alpha1"
+ uniffleapi
"github.com/apache/incubator-uniffle/deploy/kubernetes/operator/api/uniffle/v1alpha1"
+ controllerconstants
"github.com/apache/incubator-uniffle/deploy/kubernetes/operator/pkg/controller/constants"
"github.com/apache/incubator-uniffle/deploy/kubernetes/operator/pkg/utils"
)
+const (
+ testRuntimeClassName = "test-runtime"
+)
+
// IsValidDeploy checks generated deployment, returns whether it is valid and
error message.
-type IsValidDeploy func(*appsv1.Deployment) (bool, error)
+type IsValidDeploy func(*appsv1.Deployment, *uniffleapi.RemoteShuffleService)
(bool, error)
+
+var (
+ testLabels = map[string]string{
+ "key1": "value1",
+ "key2": "value2",
+ "key3": "value3",
+ }
+ testENVs = []corev1.EnvVar{
+ {
+ Name: "ENV1",
+ Value: "Value1",
+ },
+ {
+ Name: "ENV2",
+ Value: "Value2",
+ },
+ {
+ Name: "ENV3",
+ Value: "Value3",
+ },
+ {
+ Name: controllerconstants.RssIPEnv,
+ Value: "127.0.0.1",
+ },
+ }
+)
-var commonLabels = map[string]string{
- "key1": "value1",
- "key2": "value2",
- "key3": "value3",
+func buildRssWithLabels() *uniffleapi.RemoteShuffleService {
+ rss := utils.BuildRSSWithDefaultValue()
+ rss.Spec.Coordinator.Labels = testLabels
+ return rss
}
-func buildRssWithLabels() *unifflev1alpha1.RemoteShuffleService {
+func buildRssWithRuntimeClassName() *uniffleapi.RemoteShuffleService {
rss := utils.BuildRSSWithDefaultValue()
- rss.Spec.Coordinator.Labels = commonLabels
+ rss.Spec.Coordinator.RuntimeClassName =
pointer.String(testRuntimeClassName)
+ return rss
+}
+
+func buildRssWithCustomENVs() *uniffleapi.RemoteShuffleService {
+ rss := utils.BuildRSSWithDefaultValue()
+ rss.Spec.Coordinator.Env = testENVs
return rss
}
func TestGenerateDeploy(t *testing.T) {
for _, tt := range []struct {
name string
- rss *unifflev1alpha1.RemoteShuffleService
+ rss *uniffleapi.RemoteShuffleService
IsValidDeploy
}{
{
name: "add custom labels",
rss: buildRssWithLabels(),
- IsValidDeploy: func(deploy *appsv1.Deployment) (bool,
error) {
- var valid = true
- var err error
+ IsValidDeploy: func(deploy *appsv1.Deployment, rss
*uniffleapi.RemoteShuffleService) (
+ valid bool, err error) {
+ valid = true
expectedLabels := map[string]string{
"app": "rss-coordinator-rss-0",
}
- for k := range commonLabels {
- expectedLabels[k] = commonLabels[k]
+ for k := range testLabels {
+ expectedLabels[k] = testLabels[k]
}
currentLabels := deploy.Spec.Template.Labels
@@ -79,10 +122,89 @@ func TestGenerateDeploy(t *testing.T) {
return valid, err
},
},
+ {
+ name: "set custom runtime class name",
+ rss: buildRssWithRuntimeClassName(),
+ IsValidDeploy: func(deploy *appsv1.Deployment, rss
*uniffleapi.RemoteShuffleService) (
+ valid bool, err error) {
+ currentRuntimeClassName :=
deploy.Spec.Template.Spec.RuntimeClassName
+ if currentRuntimeClassName == nil {
+ return false, fmt.Errorf("unexpected
empty runtime class, expected: %v",
+ testRuntimeClassName)
+ }
+ if *currentRuntimeClassName !=
testRuntimeClassName {
+ return false, fmt.Errorf("unexpected
runtime class name: %v, expected: %v",
+ *currentRuntimeClassName,
testRuntimeClassName)
+ }
+ return true, nil
+ },
+ },
+ {
+ name: "set custom environment variables",
+ rss: buildRssWithCustomENVs(),
+ IsValidDeploy: func(deploy *appsv1.Deployment, rss
*uniffleapi.RemoteShuffleService) (
+ valid bool, err error) {
+ expectENVs := []corev1.EnvVar{
+ {
+ Name:
controllerconstants.CoordinatorRPCPortEnv,
+ Value:
strconv.FormatInt(int64(controllerconstants.ContainerCoordinatorRPCPort), 10),
+ },
+ {
+ Name:
controllerconstants.CoordinatorHTTPPortEnv,
+ Value:
strconv.FormatInt(int64(controllerconstants.ContainerCoordinatorHTTPPort), 10),
+ },
+ {
+ Name:
controllerconstants.XmxSizeEnv,
+ Value:
rss.Spec.Coordinator.XmxSize,
+ },
+ {
+ Name:
controllerconstants.ServiceNameEnv,
+ Value:
controllerconstants.CoordinatorServiceName,
+ },
+ {
+ Name:
controllerconstants.NodeNameEnv,
+ ValueFrom: &corev1.EnvVarSource{
+ FieldRef:
&corev1.ObjectFieldSelector{
+ APIVersion:
"v1",
+ FieldPath:
"spec.nodeName",
+ },
+ },
+ },
+ {
+ Name:
controllerconstants.RssIPEnv,
+ ValueFrom: &corev1.EnvVarSource{
+ FieldRef:
&corev1.ObjectFieldSelector{
+ APIVersion:
"v1",
+ FieldPath:
"status.podIP",
+ },
+ },
+ },
+ }
+ defaultEnvNames := sets.NewString()
+ for i := range expectENVs {
+
defaultEnvNames.Insert(expectENVs[i].Name)
+ }
+ for i := range testENVs {
+ if
!defaultEnvNames.Has(testENVs[i].Name) {
+ expectENVs = append(expectENVs,
testENVs[i])
+ }
+ }
+
+ actualENVs :=
deploy.Spec.Template.Spec.Containers[0].Env
+ valid = reflect.DeepEqual(expectENVs,
actualENVs)
+ if !valid {
+ actualEnvBody, _ :=
json.Marshal(actualENVs)
+ expectEnvBody, _ :=
json.Marshal(expectENVs)
+ err = fmt.Errorf("unexpected
ENVs:\n%v,\nexpected:\n%v",
+ string(actualEnvBody),
string(expectEnvBody))
+ }
+ return
+ },
+ },
} {
t.Run(tt.name, func(tc *testing.T) {
deploy := GenerateDeploy(tt.rss, 0)
- if valid, err := tt.IsValidDeploy(deploy); !valid {
+ if valid, err := tt.IsValidDeploy(deploy, tt.rss);
!valid {
tc.Error(err)
}
})
diff --git
a/deploy/kubernetes/operator/pkg/controller/sync/shuffleserver/shuffleserver.go
b/deploy/kubernetes/operator/pkg/controller/sync/shuffleserver/shuffleserver.go
index c461049d..e590f756 100644
---
a/deploy/kubernetes/operator/pkg/controller/sync/shuffleserver/shuffleserver.go
+++
b/deploy/kubernetes/operator/pkg/controller/sync/shuffleserver/shuffleserver.go
@@ -27,6 +27,7 @@ import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
+ "k8s.io/apimachinery/pkg/util/sets"
"k8s.io/utils/pointer"
unifflev1alpha1
"github.com/apache/incubator-uniffle/deploy/kubernetes/operator/api/uniffle/v1alpha1"
@@ -37,6 +38,19 @@ import (
"github.com/apache/incubator-uniffle/deploy/kubernetes/operator/pkg/utils"
)
+var defaultENVs sets.String
+
+func init() {
+ defaultENVs = sets.NewString()
+ defaultENVs.Insert(controllerconstants.ShuffleServerRPCPortEnv,
+ controllerconstants.ShuffleServerHTTPPortEnv,
+ controllerconstants.RSSCoordinatorQuorumEnv,
+ controllerconstants.XmxSizeEnv,
+ controllerconstants.ServiceNameEnv,
+ controllerconstants.NodeNameEnv,
+ controllerconstants.RssIPEnv)
+}
+
// GenerateShuffleServers generates objects related to shuffle servers.
func GenerateShuffleServers(rss *unifflev1alpha1.RemoteShuffleService) (
*corev1.ServiceAccount, []*corev1.Service, *appsv1.StatefulSet) {
@@ -217,6 +231,11 @@ func GenerateSts(rss
*unifflev1alpha1.RemoteShuffleService) *appsv1.StatefulSet
sts.Spec.Template.Labels[k] = v
}
+ // set runtimeClassName
+ if rss.Spec.ShuffleServer.RuntimeClassName != nil {
+ sts.Spec.Template.Spec.RuntimeClassName =
rss.Spec.ShuffleServer.RuntimeClassName
+ }
+
// add init containers, the main container and other containers.
sts.Spec.Template.Spec.InitContainers =
util.GenerateInitContainers(rss.Spec.ShuffleServer.RSSPodSpec)
containers := []corev1.Container{*generateMainContainer(rss)}
@@ -291,7 +310,7 @@ func generateMainContainerPorts(rss
*unifflev1alpha1.RemoteShuffleService) []cor
// generateMainContainerENV generates environment variables of main container
of shuffle servers.
func generateMainContainerENV(rss *unifflev1alpha1.RemoteShuffleService)
[]corev1.EnvVar {
- return []corev1.EnvVar{
+ env := []corev1.EnvVar{
{
Name: controllerconstants.ShuffleServerRPCPortEnv,
Value:
strconv.FormatInt(int64(controllerconstants.ContainerShuffleServerRPCPort), 10),
@@ -331,6 +350,12 @@ func generateMainContainerENV(rss
*unifflev1alpha1.RemoteShuffleService) []corev
},
},
}
+ for _, e := range rss.Spec.ShuffleServer.Env {
+ if !defaultENVs.Has(e.Name) {
+ env = append(env, e)
+ }
+ }
+ return env
}
// needGenerateNodePortSVC returns whether we need node port service for
shuffle servers.
diff --git
a/deploy/kubernetes/operator/pkg/controller/sync/shuffleserver/shuffleserver_test.go
b/deploy/kubernetes/operator/pkg/controller/sync/shuffleserver/shuffleserver_test.go
index ede041de..a1c6576c 100644
---
a/deploy/kubernetes/operator/pkg/controller/sync/shuffleserver/shuffleserver_test.go
+++
b/deploy/kubernetes/operator/pkg/controller/sync/shuffleserver/shuffleserver_test.go
@@ -18,54 +18,97 @@
package shuffleserver
import (
+ "encoding/json"
"fmt"
+ "reflect"
+ "strconv"
"testing"
appsv1 "k8s.io/api/apps/v1"
+ corev1 "k8s.io/api/core/v1"
+ "k8s.io/apimachinery/pkg/util/sets"
+ "k8s.io/utils/pointer"
- unifflev1alpha1
"github.com/apache/incubator-uniffle/deploy/kubernetes/operator/api/uniffle/v1alpha1"
+ uniffleapi
"github.com/apache/incubator-uniffle/deploy/kubernetes/operator/api/uniffle/v1alpha1"
+ controllerconstants
"github.com/apache/incubator-uniffle/deploy/kubernetes/operator/pkg/controller/constants"
+
"github.com/apache/incubator-uniffle/deploy/kubernetes/operator/pkg/controller/sync/coordinator"
"github.com/apache/incubator-uniffle/deploy/kubernetes/operator/pkg/utils"
)
+const (
+ testRuntimeClassName = "test-runtime"
+)
+
// IsValidSts checks generated statefulSet, returns whether it is valid and
error message.
-type IsValidSts func(*appsv1.StatefulSet) (bool, error)
+type IsValidSts func(*appsv1.StatefulSet, *uniffleapi.RemoteShuffleService)
(bool, error)
-var commonLabels = map[string]string{
- "key1": "value1",
- "key2": "value2",
- "key3": "value3",
-}
+var (
+ testLabels = map[string]string{
+ "key1": "value1",
+ "key2": "value2",
+ "key3": "value3",
+ }
+ testENVs = []corev1.EnvVar{
+ {
+ Name: "ENV1",
+ Value: "Value1",
+ },
+ {
+ Name: "ENV2",
+ Value: "Value2",
+ },
+ {
+ Name: "ENV3",
+ Value: "Value3",
+ },
+ {
+ Name: controllerconstants.XmxSizeEnv,
+ Value: "1G",
+ },
+ }
+)
-func buildRssWithLabels() *unifflev1alpha1.RemoteShuffleService {
+func buildRssWithLabels() *uniffleapi.RemoteShuffleService {
rss := utils.BuildRSSWithDefaultValue()
rss.Spec.ShuffleServer.Labels = map[string]string{
"uniffle.apache.org/shuffle-server": "change-test",
}
- for k := range commonLabels {
- rss.Spec.ShuffleServer.Labels[k] = commonLabels[k]
+ for k := range testLabels {
+ rss.Spec.ShuffleServer.Labels[k] = testLabels[k]
}
return rss
}
+func buildRssWithRuntimeClassName() *uniffleapi.RemoteShuffleService {
+ rss := utils.BuildRSSWithDefaultValue()
+ rss.Spec.ShuffleServer.RuntimeClassName =
pointer.String(testRuntimeClassName)
+ return rss
+}
+
+func buildRssWithCustomENVs() *uniffleapi.RemoteShuffleService {
+ rss := utils.BuildRSSWithDefaultValue()
+ rss.Spec.ShuffleServer.Env = testENVs
+ return rss
+}
+
func TestGenerateSts(t *testing.T) {
for _, tt := range []struct {
name string
- rss *unifflev1alpha1.RemoteShuffleService
+ rss *uniffleapi.RemoteShuffleService
IsValidSts
}{
{
name: "add custom labels",
rss: buildRssWithLabels(),
- IsValidSts: func(sts *appsv1.StatefulSet) (bool, error)
{
- var valid = true
- var err error
+ IsValidSts: func(sts *appsv1.StatefulSet, rss
*uniffleapi.RemoteShuffleService) (valid bool, err error) {
+ valid = true
expectedLabels := map[string]string{
"app":
"rss-shuffle-server-rss",
"uniffle.apache.org/shuffle-server":
"true",
}
- for k := range commonLabels {
- expectedLabels[k] = commonLabels[k]
+ for k := range testLabels {
+ expectedLabels[k] = testLabels[k]
}
currentLabels := sts.Spec.Template.Labels
@@ -85,10 +128,91 @@ func TestGenerateSts(t *testing.T) {
return valid, err
},
},
+ {
+ name: "set custom runtime class name",
+ rss: buildRssWithRuntimeClassName(),
+ IsValidSts: func(sts *appsv1.StatefulSet, rss
*uniffleapi.RemoteShuffleService) (valid bool, err error) {
+ currentRuntimeClassName :=
sts.Spec.Template.Spec.RuntimeClassName
+ if currentRuntimeClassName == nil {
+ return false, fmt.Errorf("unexpected
empty runtime class, expected: %v",
+ testRuntimeClassName)
+ }
+ if *currentRuntimeClassName !=
testRuntimeClassName {
+ return false, fmt.Errorf("unexpected
runtime class name: %v, expected: %v",
+ *currentRuntimeClassName,
testRuntimeClassName)
+ }
+ return true, nil
+ },
+ },
+ {
+ name: "set custom environment variables",
+ rss: buildRssWithCustomENVs(),
+ IsValidSts: func(sts *appsv1.StatefulSet, rss
*uniffleapi.RemoteShuffleService) (valid bool, err error) {
+ expectENVs := []corev1.EnvVar{
+ {
+ Name:
controllerconstants.ShuffleServerRPCPortEnv,
+ Value:
strconv.FormatInt(int64(controllerconstants.ContainerShuffleServerRPCPort), 10),
+ },
+ {
+ Name:
controllerconstants.ShuffleServerHTTPPortEnv,
+ Value:
strconv.FormatInt(int64(controllerconstants.ContainerShuffleServerHTTPPort),
10),
+ },
+ {
+ Name:
controllerconstants.RSSCoordinatorQuorumEnv,
+ Value:
coordinator.GenerateAddresses(rss),
+ },
+ {
+ Name:
controllerconstants.XmxSizeEnv,
+ Value:
rss.Spec.ShuffleServer.XmxSize,
+ },
+ {
+ Name:
controllerconstants.ServiceNameEnv,
+ Value:
controllerconstants.ShuffleServerServiceName,
+ },
+ {
+ Name:
controllerconstants.NodeNameEnv,
+ ValueFrom: &corev1.EnvVarSource{
+ FieldRef:
&corev1.ObjectFieldSelector{
+ APIVersion:
"v1",
+ FieldPath:
"spec.nodeName",
+ },
+ },
+ },
+ {
+ Name:
controllerconstants.RssIPEnv,
+ ValueFrom: &corev1.EnvVarSource{
+ FieldRef:
&corev1.ObjectFieldSelector{
+ APIVersion:
"v1",
+ FieldPath:
"status.podIP",
+ },
+ },
+ },
+ }
+ defaultEnvNames := sets.NewString()
+ for i := range expectENVs {
+
defaultEnvNames.Insert(expectENVs[i].Name)
+ }
+ for i := range testENVs {
+ if
!defaultEnvNames.Has(testENVs[i].Name) {
+ expectENVs = append(expectENVs,
testENVs[i])
+ }
+ }
+
+ actualENVs :=
sts.Spec.Template.Spec.Containers[0].Env
+ valid = reflect.DeepEqual(expectENVs,
actualENVs)
+ if !valid {
+ actualEnvBody, _ :=
json.Marshal(actualENVs)
+ expectEnvBody, _ :=
json.Marshal(expectENVs)
+ err = fmt.Errorf("unexpected
ENVs:\n%v,\nexpected:\n%v",
+ string(actualEnvBody),
string(expectEnvBody))
+ }
+ return
+ },
+ },
} {
t.Run(tt.name, func(tc *testing.T) {
sts := GenerateSts(tt.rss)
- if valid, err := tt.IsValidSts(sts); !valid {
+ if valid, err := tt.IsValidSts(sts, tt.rss); !valid {
tc.Error(err)
}
})
diff --git a/deploy/kubernetes/operator/pkg/webhook/inspector/rss.go
b/deploy/kubernetes/operator/pkg/webhook/inspector/rss.go
index 3e8f969f..ce1033c4 100644
--- a/deploy/kubernetes/operator/pkg/webhook/inspector/rss.go
+++ b/deploy/kubernetes/operator/pkg/webhook/inspector/rss.go
@@ -18,11 +18,14 @@
package inspector
import (
+ "context"
"encoding/json"
"fmt"
"gomodules.xyz/jsonpatch/v2"
admissionv1 "k8s.io/api/admission/v1"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/client-go/kubernetes"
"k8s.io/klog/v2"
"k8s.io/utils/pointer"
@@ -55,6 +58,11 @@ func (i *inspector) validateRSS(ar
*admissionv1.AdmissionReview) *admissionv1.Ad
string(ar.Request.Object.Raw), err)
return util.AdmissionReviewFailed(ar, err)
}
+ if err := validateRuntimeClassNames(newRSS, i.kubeClient); err != nil {
+ klog.Errorf("validate runtime class of rss (%v) failed: %v",
+ utils.UniqueName(newRSS), err)
+ return util.AdmissionReviewFailed(ar, err)
+ }
if err := validateCoordinator(newRSS.Spec.Coordinator); err != nil {
klog.Errorf("validate coordinator config of rss (%v) failed:
%v",
utils.UniqueName(newRSS), err)
@@ -163,3 +171,26 @@ func validateCoordinator(coordinator
*unifflev1alpha1.CoordinatorConfig) error {
}
return nil
}
+
+func validateRuntimeClassNames(rss *unifflev1alpha1.RemoteShuffleService,
kubeClient kubernetes.Interface) error {
+ if err :=
validateRuntimeClassName(rss.Spec.Coordinator.RuntimeClassName, kubeClient);
err != nil {
+ klog.Errorf("failed to get runtime class for coordinator: %v",
err)
+ return err
+ }
+ if err :=
validateRuntimeClassName(rss.Spec.ShuffleServer.RuntimeClassName, kubeClient);
err != nil {
+ klog.Errorf("failed to get runtime class for shuffleServer:
%v", err)
+ return err
+ }
+ return nil
+}
+
+func validateRuntimeClassName(runtimeClassName *string, kubeClient
kubernetes.Interface) error {
+ if runtimeClassName == nil {
+ return nil
+ }
+ _, err := kubeClient.NodeV1().RuntimeClasses().Get(context.TODO(),
*runtimeClassName, metav1.GetOptions{})
+ if err != nil {
+ klog.Errorf("failed to get runtime class %v: %v",
*runtimeClassName, err)
+ }
+ return err
+}
diff --git a/deploy/kubernetes/operator/pkg/webhook/inspector/rss_test.go
b/deploy/kubernetes/operator/pkg/webhook/inspector/rss_test.go
index 8c7b08bb..101a3e4c 100644
--- a/deploy/kubernetes/operator/pkg/webhook/inspector/rss_test.go
+++ b/deploy/kubernetes/operator/pkg/webhook/inspector/rss_test.go
@@ -22,25 +22,34 @@ import (
"testing"
admissionv1 "k8s.io/api/admission/v1"
- corev1 "k8s.io/api/core/v1"
+ nodev1 "k8s.io/api/node/v1"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
+ "k8s.io/client-go/kubernetes"
+ kubefake "k8s.io/client-go/kubernetes/fake"
"k8s.io/utils/pointer"
- unifflev1alpha1
"github.com/apache/incubator-uniffle/deploy/kubernetes/operator/api/uniffle/v1alpha1"
+ uniffleapi
"github.com/apache/incubator-uniffle/deploy/kubernetes/operator/api/uniffle/v1alpha1"
+
"github.com/apache/incubator-uniffle/deploy/kubernetes/operator/pkg/utils"
"github.com/apache/incubator-uniffle/deploy/kubernetes/operator/pkg/webhook/config"
)
-func wrapTestRssObj(rss *unifflev1alpha1.RemoteShuffleService)
*unifflev1alpha1.RemoteShuffleService {
- rss.Name = "test"
- rss.Namespace = corev1.NamespaceDefault
- rss.UID = "uid-test"
+const (
+ testRuntimeClassName = "test-runtime"
+)
+
+type wrapper func(rss *uniffleapi.RemoteShuffleService)
+
+func wrapRssObj(wrapperFunc wrapper) *uniffleapi.RemoteShuffleService {
+ rss := utils.BuildRSSWithDefaultValue()
+ wrapperFunc(rss)
return rss
}
// convertRssToRawExtension converts a rss object to runtime.RawExtension for
testing.
-func convertRssToRawExtension(rss *unifflev1alpha1.RemoteShuffleService)
(runtime.RawExtension, error) {
+func convertRssToRawExtension(rss *uniffleapi.RemoteShuffleService)
(runtime.RawExtension, error) {
if rss == nil {
- return
convertRssToRawExtension(&unifflev1alpha1.RemoteShuffleService{})
+ return
convertRssToRawExtension(&uniffleapi.RemoteShuffleService{})
}
body, err := json.Marshal(rss)
if err != nil {
@@ -54,7 +63,7 @@ func convertRssToRawExtension(rss
*unifflev1alpha1.RemoteShuffleService) (runtim
// buildTestAdmissionReview builds an AdmissionReview object for testing.
func buildTestAdmissionReview(op admissionv1.Operation,
- oldRss, newRss *unifflev1alpha1.RemoteShuffleService)
*admissionv1.AdmissionReview {
+ oldRss, newRss *uniffleapi.RemoteShuffleService)
*admissionv1.AdmissionReview {
oldObject, err := convertRssToRawExtension(oldRss)
if err != nil {
panic(err)
@@ -77,20 +86,17 @@ func buildTestAdmissionReview(op admissionv1.Operation,
func TestValidateRSS(t *testing.T) {
testInspector := newInspector(&config.Config{}, nil)
- rssWithCooNodePort := &unifflev1alpha1.RemoteShuffleService{
- Spec: unifflev1alpha1.RemoteShuffleServiceSpec{
- Coordinator: &unifflev1alpha1.CoordinatorConfig{
- Count: pointer.Int32(2),
- RPCNodePort: []int32{30001, 30002},
- HTTPNodePort: []int32{30011, 30012},
- },
- },
- }
+ rssWithCooNodePort := wrapRssObj(func(rss
*uniffleapi.RemoteShuffleService) {
+ rss.Spec.Coordinator.Count = pointer.Int32(2)
+ rss.Spec.Coordinator.RPCNodePort = []int32{30001, 30002}
+ rss.Spec.Coordinator.HTTPNodePort = []int32{30011, 30012}
+ rss.Spec.Coordinator.ExcludeNodesFilePath = ""
+ })
rssWithoutLogInCooMounts := rssWithCooNodePort.DeepCopy()
rssWithoutLogInCooMounts.Spec.Coordinator.ExcludeNodesFilePath =
"/exclude_nodes"
- rssWithoutLogInCooMounts.Spec.Coordinator.CommonConfig =
&unifflev1alpha1.CommonConfig{
- RSSPodSpec: &unifflev1alpha1.RSSPodSpec{
+ rssWithoutLogInCooMounts.Spec.Coordinator.CommonConfig =
&uniffleapi.CommonConfig{
+ RSSPodSpec: &uniffleapi.RSSPodSpec{
LogHostPath: "/data/logs",
HostPathMounts: map[string]string{},
},
@@ -98,9 +104,9 @@ func TestValidateRSS(t *testing.T) {
rssWithoutLogInServerMounts := rssWithoutLogInCooMounts.DeepCopy()
rssWithoutLogInServerMounts.Spec.Coordinator.CommonConfig.RSSPodSpec.HostPathMounts["/data/logs"]
= "/data/logs"
- rssWithoutLogInServerMounts.Spec.ShuffleServer =
&unifflev1alpha1.ShuffleServerConfig{
- CommonConfig: &unifflev1alpha1.CommonConfig{
- RSSPodSpec: &unifflev1alpha1.RSSPodSpec{
+ rssWithoutLogInServerMounts.Spec.ShuffleServer =
&uniffleapi.ShuffleServerConfig{
+ CommonConfig: &uniffleapi.CommonConfig{
+ RSSPodSpec: &uniffleapi.RSSPodSpec{
LogHostPath: "/data/logs",
HostPathMounts: map[string]string{},
},
@@ -109,26 +115,26 @@ func TestValidateRSS(t *testing.T) {
rssWithoutPartition := rssWithoutLogInServerMounts.DeepCopy()
rssWithoutPartition.Spec.ShuffleServer.CommonConfig.RSSPodSpec.HostPathMounts["/data/logs"]
= "/data/logs"
- rssWithoutPartition.Spec.ShuffleServer.UpgradeStrategy =
&unifflev1alpha1.ShuffleServerUpgradeStrategy{
- Type: unifflev1alpha1.PartitionUpgrade,
+ rssWithoutPartition.Spec.ShuffleServer.UpgradeStrategy =
&uniffleapi.ShuffleServerUpgradeStrategy{
+ Type: uniffleapi.PartitionUpgrade,
}
rssWithInvalidPartition := rssWithoutLogInServerMounts.DeepCopy()
rssWithInvalidPartition.Spec.ShuffleServer.CommonConfig.RSSPodSpec.HostPathMounts["/data/logs"]
= "/data/logs"
- rssWithInvalidPartition.Spec.ShuffleServer.UpgradeStrategy =
&unifflev1alpha1.ShuffleServerUpgradeStrategy{
- Type: unifflev1alpha1.PartitionUpgrade,
+ rssWithInvalidPartition.Spec.ShuffleServer.UpgradeStrategy =
&uniffleapi.ShuffleServerUpgradeStrategy{
+ Type: uniffleapi.PartitionUpgrade,
Partition: pointer.Int32(-1),
}
rssWithoutSpecificNames := rssWithoutLogInServerMounts.DeepCopy()
rssWithoutSpecificNames.Spec.ShuffleServer.CommonConfig.RSSPodSpec.HostPathMounts["/data/logs"]
= "/data/logs"
- rssWithoutSpecificNames.Spec.ShuffleServer.UpgradeStrategy =
&unifflev1alpha1.ShuffleServerUpgradeStrategy{
- Type: unifflev1alpha1.SpecificUpgrade,
+ rssWithoutSpecificNames.Spec.ShuffleServer.UpgradeStrategy =
&uniffleapi.ShuffleServerUpgradeStrategy{
+ Type: uniffleapi.SpecificUpgrade,
}
rssWithoutUpgradeStrategyType := rssWithoutLogInServerMounts.DeepCopy()
rssWithoutUpgradeStrategyType.Spec.ShuffleServer.CommonConfig.RSSPodSpec.HostPathMounts["/data/logs"]
= "/data/logs"
- rssWithoutUpgradeStrategyType.Spec.ShuffleServer.UpgradeStrategy =
&unifflev1alpha1.ShuffleServerUpgradeStrategy{}
+ rssWithoutUpgradeStrategyType.Spec.ShuffleServer.UpgradeStrategy =
&uniffleapi.ShuffleServerUpgradeStrategy{}
for _, tt := range []struct {
name string
@@ -137,75 +143,138 @@ func TestValidateRSS(t *testing.T) {
}{
{
name: "try to modify a upgrading rss object",
- ar: buildTestAdmissionReview(admissionv1.Update,
wrapTestRssObj(&unifflev1alpha1.RemoteShuffleService{
- Status:
unifflev1alpha1.RemoteShuffleServiceStatus{
- Phase: unifflev1alpha1.RSSUpgrading,
- },
- }), nil),
+ ar: buildTestAdmissionReview(admissionv1.Update,
wrapRssObj(
+ func(rss *uniffleapi.RemoteShuffleService) {
+ rss.Status =
uniffleapi.RemoteShuffleServiceStatus{
+ Phase: uniffleapi.RSSUpgrading,
+ }
+ }), nil),
allowed: false,
},
{
name: "invalid rpc node port number in a rss object",
- ar: buildTestAdmissionReview(admissionv1.Update, nil,
wrapTestRssObj(&unifflev1alpha1.RemoteShuffleService{
- Spec: unifflev1alpha1.RemoteShuffleServiceSpec{
- Coordinator:
&unifflev1alpha1.CoordinatorConfig{
- Count: pointer.Int32(2),
- RPCNodePort: []int32{30001},
- },
- },
- })),
+ ar: buildTestAdmissionReview(admissionv1.Create, nil,
+ wrapRssObj(func(rss
*uniffleapi.RemoteShuffleService) {
+ rss.Spec.Coordinator.Count =
pointer.Int32(2)
+ rss.Spec.Coordinator.RPCNodePort =
[]int32{30001}
+ })),
allowed: false,
},
{
name: "invalid http node port number in a rss object",
- ar: buildTestAdmissionReview(admissionv1.Update, nil,
wrapTestRssObj(&unifflev1alpha1.RemoteShuffleService{
- Spec: unifflev1alpha1.RemoteShuffleServiceSpec{
- Coordinator:
&unifflev1alpha1.CoordinatorConfig{
- Count: pointer.Int32(2),
- RPCNodePort: []int32{30001,
30002},
- HTTPNodePort: []int32{30011},
- },
- },
- })),
+ ar: buildTestAdmissionReview(admissionv1.Create, nil,
+ wrapRssObj(func(rss
*uniffleapi.RemoteShuffleService) {
+ rss.Spec.Coordinator.Count =
pointer.Int32(2)
+ rss.Spec.Coordinator.RPCNodePort =
[]int32{30001, 30002}
+ rss.Spec.Coordinator.HTTPNodePort =
[]int32{30011}
+ })),
allowed: false,
},
{
name: "empty exclude nodes file path field in a rss
object",
- ar: buildTestAdmissionReview(admissionv1.Update,
nil, wrapTestRssObj(rssWithCooNodePort)),
+ ar: buildTestAdmissionReview(admissionv1.Create,
nil, rssWithCooNodePort),
allowed: false,
},
{
name: "can not find log host path in coordinators'
host path mounts field in a rss object",
- ar: buildTestAdmissionReview(admissionv1.Update,
nil, wrapTestRssObj(rssWithoutLogInCooMounts)),
+ ar: buildTestAdmissionReview(admissionv1.Create,
nil, rssWithoutLogInCooMounts),
allowed: false,
},
{
name: "can not find log host path in shuffle server'
host path mounts field in a rss object",
- ar: buildTestAdmissionReview(admissionv1.Update,
nil, wrapTestRssObj(rssWithoutLogInServerMounts)),
+ ar: buildTestAdmissionReview(admissionv1.Create,
nil, rssWithoutLogInServerMounts),
allowed: false,
},
{
name: "empty partition field when shuffler server of
a rss object need partition upgrade",
- ar: buildTestAdmissionReview(admissionv1.Update,
nil, wrapTestRssObj(rssWithoutPartition)),
+ ar: buildTestAdmissionReview(admissionv1.Create,
nil, rssWithoutPartition),
allowed: false,
},
{
name: "invalid partition field when shuffler server
of a rss object need partition upgrade",
- ar: buildTestAdmissionReview(admissionv1.Update,
nil, wrapTestRssObj(rssWithInvalidPartition)),
+ ar: buildTestAdmissionReview(admissionv1.Create,
nil, rssWithInvalidPartition),
allowed: false,
},
{
name: "empty specific names field when shuffler
server of a rss object need specific upgrade",
- ar: buildTestAdmissionReview(admissionv1.Update,
nil, wrapTestRssObj(rssWithoutSpecificNames)),
+ ar: buildTestAdmissionReview(admissionv1.Create,
nil, rssWithoutSpecificNames),
allowed: false,
},
{
name: "empty upgrade strategy type in shuffler
server of a rss object",
- ar: buildTestAdmissionReview(admissionv1.Update,
nil, wrapTestRssObj(rssWithoutUpgradeStrategyType)),
+ ar: buildTestAdmissionReview(admissionv1.Create,
nil, rssWithoutUpgradeStrategyType),
+ allowed: false,
+ },
+ } {
+ t.Run(tt.name, func(tc *testing.T) {
+ updatedAR := testInspector.validateRSS(tt.ar)
+ if !updatedAR.Response.Allowed {
+ tc.Logf("==> message in result: %+v",
updatedAR.Response.Result.Message)
+ }
+ if updatedAR.Response.Allowed != tt.allowed {
+ tc.Errorf("invalid 'allowed' field in response:
%v <> %v", updatedAR.Response.Allowed, tt.allowed)
+ }
+ })
+ }
+}
+
+func TestValidateRuntimeClassNames(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ runtimeClass *nodev1.RuntimeClass
+ ar *admissionv1.AdmissionReview
+ allowed bool
+ }{
+ {
+ name: "create rss object with existent runtime
class names",
+ runtimeClass: buildTestRuntimeClass(),
+ ar: buildTestAdmissionReview(admissionv1.Create, nil,
+ wrapRssObj(func(rss
*uniffleapi.RemoteShuffleService) {
+ rss.Spec.Coordinator.RuntimeClassName =
pointer.String(testRuntimeClassName)
+ rss.Spec.ShuffleServer.RuntimeClassName
= pointer.String(testRuntimeClassName)
+ })),
+ allowed: true,
+ },
+ {
+ name: "create rss object with empty runtime
class name used by shuffle server",
+ runtimeClass: buildTestRuntimeClass(),
+ ar: buildTestAdmissionReview(admissionv1.Create, nil,
+ wrapRssObj(func(rss
*uniffleapi.RemoteShuffleService) {
+ rss.Spec.Coordinator.RuntimeClassName =
pointer.String(testRuntimeClassName)
+ })),
+ allowed: true,
+ },
+ {
+ name: "create rss object with non existent runtime
class name used by coordinator",
+ ar: buildTestAdmissionReview(admissionv1.Create, nil,
+ wrapRssObj(func(rss
*uniffleapi.RemoteShuffleService) {
+ rss.Spec.Coordinator.RuntimeClassName =
pointer.String(testRuntimeClassName)
+ })),
+ allowed: false,
+ },
+ {
+ name: "create rss object with non existent runtime
class name used by shuffle server",
+ ar: buildTestAdmissionReview(admissionv1.Create, nil,
+ wrapRssObj(func(rss
*uniffleapi.RemoteShuffleService) {
+ rss.Spec.ShuffleServer.RuntimeClassName
= pointer.String(testRuntimeClassName)
+ })),
allowed: false,
},
} {
t.Run(tt.name, func(tc *testing.T) {
+ var kubeClient kubernetes.Interface
+ if tt.runtimeClass != nil {
+ kubeClient =
kubefake.NewSimpleClientset(tt.runtimeClass)
+ } else {
+ kubeClient = kubefake.NewSimpleClientset()
+ }
+
+ testInspector := newInspector(&config.Config{
+ GenericConfig: utils.GenericConfig{
+ KubeClient: kubeClient,
+ },
+ }, nil)
+
updatedAR := testInspector.validateRSS(tt.ar)
if !updatedAR.Response.Allowed {
tc.Logf("==> message in result: %+v",
updatedAR.Response.Result.Message)
@@ -216,3 +285,12 @@ func TestValidateRSS(t *testing.T) {
})
}
}
+
+func buildTestRuntimeClass() *nodev1.RuntimeClass {
+ return &nodev1.RuntimeClass{
+ ObjectMeta: metav1.ObjectMeta{
+ Name: testRuntimeClassName,
+ },
+ Handler: "/etc/runtime/bin",
+ }
+}