This is an automated email from the ASF dual-hosted git repository. cdmikechen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/submarine.git
The following commit(s) were added to refs/heads/master by this push: new 3d2a3e67 SUBMARINE-1283. copy data for experiment before it running via distcp to minio 3d2a3e67 is described below commit 3d2a3e67530d2836ae9be8e5f9549bb30dcc53ca Author: FatalLin <fatal...@gmail.com> AuthorDate: Tue Sep 20 22:16:50 2022 +0800 SUBMARINE-1283. copy data for experiment before it running via distcp to minio ### What is this PR for? This is a prototype of experiment prehandler, once the required arguments have been set, submarine would put an init-container on the main pod. The init container would copy the source data to minio to the path /submarine/${experimentId}. note: the init container would be added under: TFJob: ps PytorchJob: master XGBoostJob: master ### What type of PR is it? Feature ### Todos add a housekeeping container to clean the copied data. ### What is the Jira issue? https://issues.apache.org/jira/browse/SUBMARINE-1283 ### How should this be tested? Should add another test for it. ### Screenshots (if appropriate) ### Questions: * Do the license files need updating? No * Are there breaking changes for older versions? No * Does this need new documentation? 
Yes Author: FatalLin <fatal...@gmail.com> Signed-off-by: Xiang Chen <cdmikec...@apache.org> Closes #989 from FatalLin/SUBMARINE-1283 and squashes the following commits: ef804903 [FatalLin] fix conflict d816b880 [FatalLin] modify test cases 171da696 [FatalLin] code polish 36297f31 [FatalLin] polish code 2c62fb7c [FatalLin] fix 17135f2a [FatalLin] fix script 314a3866 [FatalLin] fix build script 028eff7d [FatalLin] fix script 79d91c88 [FatalLin] prototype of experiment prehandler eaa25e1d [FatalLin] Merge branch 'master' of https://github.com/apache/submarine into SUBMARINE-1283 d81841b1 [FatalLin] Merge branch 'master' of https://github.com/apache/submarine into SUBMARINE-1283 44473203 [FatalLin] for debugging --- .github/scripts/build-image-locally-v3.sh | 4 +- .github/scripts/build-image-locally.sh | 4 +- .../docker-images/experiment-prehandler/Dockerfile | 3 +- .../docker-images/experiment-prehandler/build.sh | 10 +++ .../fs_prehandler/hdfs_prehandler.py | 42 +++++++---- .../submarine/server/api/spec/ExperimentSpec.java | 9 +++ .../server/submitter/k8s/model/mljob/MLJob.java | 86 ++++++++++++++++++++++ .../submitter/k8s/model/pytorchjob/PyTorchJob.java | 7 ++ .../server/submitter/k8s/model/tfjob/TFJob.java | 12 +++ .../submitter/k8s/model/xgboostjob/XGBoostJob.java | 8 +- .../submitter/k8s/ExperimentSpecParserTest.java | 44 ++++++++++- .../src/test/resources/pytorch_job_req.json | 7 ++ .../src/test/resources/tf_mnist_req.json | 7 ++ .../src/test/resources/xgboost_job_req.json | 7 ++ 14 files changed, 231 insertions(+), 19 deletions(-) diff --git a/.github/scripts/build-image-locally-v3.sh b/.github/scripts/build-image-locally-v3.sh index 81203af7..00fdb770 100755 --- a/.github/scripts/build-image-locally-v3.sh +++ b/.github/scripts/build-image-locally-v3.sh @@ -17,12 +17,14 @@ # SUBMARINE_VERSION="0.8.0-SNAPSHOT" -FOLDER_LIST=("database" "mlflow" "submarine" "operator-v3") +FOLDER_LIST=("database" "mlflow" "submarine" "operator-v3" "agent" "experiment-prehandler") 
IMAGE_LIST=( "apache/submarine:database-${SUBMARINE_VERSION}" "apache/submarine:mlflow-${SUBMARINE_VERSION}" "apache/submarine:server-${SUBMARINE_VERSION}" "apache/submarine:operator-${SUBMARINE_VERSION}" + "apache/submarine:agent-${SUBMARINE_VERSION}" + "apache/submarine:experiment-prehandler-${SUBMARINE_VERSION}" ) for i in "${!IMAGE_LIST[@]}" diff --git a/.github/scripts/build-image-locally.sh b/.github/scripts/build-image-locally.sh index a53de5b3..4d35d690 100755 --- a/.github/scripts/build-image-locally.sh +++ b/.github/scripts/build-image-locally.sh @@ -17,12 +17,14 @@ # SUBMARINE_VERSION="0.8.0-SNAPSHOT" -FOLDER_LIST=("database" "mlflow" "submarine" "operator") +FOLDER_LIST=("database" "mlflow" "submarine" "operator" "agent" "experiment-prehandler") IMAGE_LIST=( "apache/submarine:database-${SUBMARINE_VERSION}" "apache/submarine:mlflow-${SUBMARINE_VERSION}" "apache/submarine:server-${SUBMARINE_VERSION}" "apache/submarine:operator-${SUBMARINE_VERSION}" + "apache/submarine:agent-${SUBMARINE_VERSION}" + "apache/submarine:experiment-prehandler-${SUBMARINE_VERSION}" ) for i in "${!IMAGE_LIST[@]}" diff --git a/dev-support/docker-images/experiment-prehandler/Dockerfile b/dev-support/docker-images/experiment-prehandler/Dockerfile index 87307d07..7a6c7e69 100644 --- a/dev-support/docker-images/experiment-prehandler/Dockerfile +++ b/dev-support/docker-images/experiment-prehandler/Dockerfile @@ -21,7 +21,8 @@ RUN apt-get -y install python3 python3-pip bash tini ADD ./tmp/hadoop-3.3.3.tar.gz /opt/ ADD ./tmp/submarine-experiment-prehandler /opt/submarine-experiment-prehandler - +ADD ./tmp/hadoop-aws-3.3.3.jar /opt/hadoop-3.3.3/share/hadoop/hdfs +ADD ./tmp/aws-java-sdk-bundle-1.12.267.jar /opt/hadoop-3.3.3/share/hadoop/hdfs ENV HADOOP_HOME=/opt/hadoop-3.3.3 ENV ARROW_LIBHDFS_DIR=/opt/hadoop-3.3.3/lib/native diff --git a/dev-support/docker-images/experiment-prehandler/build.sh b/dev-support/docker-images/experiment-prehandler/build.sh index fcdedb05..c1e94a37 100755 --- 
a/dev-support/docker-images/experiment-prehandler/build.sh +++ b/dev-support/docker-images/experiment-prehandler/build.sh @@ -19,6 +19,12 @@ set -euxo pipefail SUBMARINE_VERSION=0.8.0-SNAPSHOT SUBMARINE_IMAGE_NAME="apache/submarine:experiment-prehandler-${SUBMARINE_VERSION}" +if [ -L ${BASH_SOURCE-$0} ]; then + PWD=$(dirname $(readlink "${BASH_SOURCE-$0}")) +else + PWD=$(dirname ${BASH_SOURCE-$0}) +fi + export CURRENT_PATH=$(cd "${PWD}">/dev/null; pwd) export SUBMARINE_HOME=${CURRENT_PATH}/../../.. @@ -33,7 +39,11 @@ trap "test -f $tmpfile && rm $tmpfile" RETURN curl -L -o $tmpfile ${HADOOP_TAR_URL} mv $tmpfile ${CURRENT_PATH}/tmp/hadoop-3.3.3.tar.gz +curl -L -o ${CURRENT_PATH}/tmp/hadoop-aws-3.3.3.jar https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.3.3/hadoop-aws-3.3.3.jar +curl -L -o ${CURRENT_PATH}/tmp/aws-java-sdk-bundle-1.12.267.jar https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-bundle/1.12.267/aws-java-sdk-bundle-1.12.267.jar + echo "Start building the ${SUBMARINE_IMAGE_NAME} docker image ..." +cd ${CURRENT_PATH} docker build -t ${SUBMARINE_IMAGE_NAME} . 
# clean temp file diff --git a/submarine-experiment-prehandler/fs_prehandler/hdfs_prehandler.py b/submarine-experiment-prehandler/fs_prehandler/hdfs_prehandler.py index 1a13f1b2..c138b61b 100644 --- a/submarine-experiment-prehandler/fs_prehandler/hdfs_prehandler.py +++ b/submarine-experiment-prehandler/fs_prehandler/hdfs_prehandler.py @@ -15,6 +15,7 @@ import logging import os +import subprocess from fs_prehandler import FsPreHandler from fsspec.implementations.arrow import HadoopFileSystem @@ -22,23 +23,36 @@ from fsspec.implementations.arrow import HadoopFileSystem class HDFSPreHandler(FsPreHandler): def __init__(self): - self.hdfs_host = os.environ['HDFS_HOST'] - self.hdfs_port = int(os.environ['HDFS_PORT']) - self.hdfs_source = os.environ['HDFS_SOURCE'] - self.dest_path = os.environ['DEST_PATH'] - self.enable_kerberos = os.environ['ENABLE_KERBEROS'] + self.hdfs_host=os.environ['HDFS_HOST'] + self.hdfs_port=os.environ['HDFS_PORT'] + self.hdfs_source=os.environ['HDFS_SOURCE'] + self.enable_kerberos=os.environ['ENABLE_KERBEROS'] + self.hadoop_home=os.environ['HADOOP_HOME'] + self.dest_minio_host=os.environ['DEST_MINIO_HOST'] + self.dest_minio_port=os.environ['DEST_MINIO_PORT'] + self.minio_access_key=os.environ['MINIO_ACCESS_KEY'] + self.minio_secert_key=os.environ['MINIO_SECRET_KEY'] + self.experiment_id=os.environ['EXPERIMENT_ID'] logging.info('HDFS_HOST:%s' % self.hdfs_host) - logging.info('HDFS_PORT:%d' % self.hdfs_port) + logging.info('HDFS_PORT:%s' % self.hdfs_port) logging.info('HDFS_SOURCE:%s' % self.hdfs_source) - logging.info('DEST_PATH:%s' % self.dest_path) + logging.info('MINIO_DEST_HOST:%s' % self.dest_minio_host) + logging.info('MINIO_DEST_PORT:%s' % self.dest_minio_port) logging.info('ENABLE_KERBEROS:%s' % self.enable_kerberos) - - self.fs = HadoopFileSystem(host=self.hdfs_host, port=self.hdfs_port) + logging.info('EXPERIMENT_ID:%s' % self.experiment_id) def process(self): - self.fs.get(self.hdfs_source, self.dest_path, recursive=True) - 
logging.info( - 'fetch data from hdfs://%s:%d/%s to %s complete' - % (self.hdfs_host, self.hdfs_port, self.hdfs_source, self.dest_path) - ) + dest_path = 'submarine/experiment/' + self.experiment_id + p = subprocess.run([self.hadoop_home+'/bin/hadoop', 'distcp' + , '-Dfs.s3a.endpoint=http://' + self.dest_minio_host + ':' + self.dest_minio_port + '/' + , '-Dfs.s3a.access.key=' + self.minio_access_key + , '-Dfs.s3a.secret.key=' + self.minio_secert_key + , '-Dfs.s3a.path.style.access=true' + , 'hdfs://'+self.hdfs_host + ':' + self.hdfs_port + '/' + self.hdfs_source + , 's3a://' + dest_path]) + + if p.returncode == 0: + logging.info('fetch data from hdfs://%s:%s/%s to %s complete' % (self.hdfs_host, self.hdfs_port, self.hdfs_source, dest_path)) + else: + raise Exception( 'error occured when fetching data from hdfs://%s:%s/%s to %s' % (self.hdfs_host, self.hdfs_port, self.hdfs_source, dest_path) ) diff --git a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentSpec.java b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentSpec.java index b0c283ab..8c3024d6 100644 --- a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentSpec.java +++ b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentSpec.java @@ -28,6 +28,7 @@ public class ExperimentSpec { private ExperimentMeta meta; private EnvironmentSpec environment; private Map<String, ExperimentTaskSpec> spec; + private Map<String, String> experimentHandlerSpec; private CodeSpec code; public ExperimentSpec() {} @@ -63,6 +64,14 @@ public class ExperimentSpec { public void setCode(CodeSpec code) { this.code = code; } + + public Map<String, String> getExperimentHandlerSpec() { + return experimentHandlerSpec; + } + + public void setExperimentHandlerSpec(Map<String, String> experimentHandlerSpec) { + this.experimentHandlerSpec = experimentHandlerSpec; + } @Override public String 
toString() { diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/mljob/MLJob.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/mljob/MLJob.java index 45c174cb..f026bc3d 100644 --- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/mljob/MLJob.java +++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/mljob/MLJob.java @@ -22,10 +22,17 @@ package org.apache.submarine.server.submitter.k8s.model.mljob; import com.google.gson.JsonSyntaxException; import com.google.gson.annotations.SerializedName; import io.kubernetes.client.common.KubernetesObject; +import io.kubernetes.client.openapi.models.V1Container; +import io.kubernetes.client.openapi.models.V1EnvVar; import io.kubernetes.client.openapi.models.V1JobStatus; import io.kubernetes.client.openapi.models.V1ObjectMeta; import io.kubernetes.client.openapi.models.V1ObjectMetaBuilder; import io.kubernetes.client.openapi.models.V1Status; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + import org.apache.submarine.commons.utils.exception.SubmarineRuntimeException; import org.apache.submarine.server.api.common.CustomResourceType; import org.apache.submarine.server.api.experiment.Experiment; @@ -205,6 +212,85 @@ public abstract class MLJob implements KubernetesObject, K8sResource<Experiment> this.experimentId = experimentId; } + public V1Container getExperimentHandlerContainer(ExperimentSpec spec) { + Map<String, String> handlerSpec = spec.getExperimentHandlerSpec(); + + if (checkExperimentHanderArg(handlerSpec)) { + V1Container initContainer = new V1Container(); + initContainer.setImage("apache/submarine:experiment-prehandler-0.8.0-SNAPSHOT"); + initContainer.setName("ExperimentHandler".toLowerCase()); + List<V1EnvVar> envVar = new 
ArrayList<>(); + + V1EnvVar hdfsHostVar = new V1EnvVar(); + hdfsHostVar.setName("HDFS_HOST"); + hdfsHostVar.setValue(handlerSpec.get("HDFS_HOST")); + envVar.add(hdfsHostVar); + + V1EnvVar hdfsPortVar = new V1EnvVar(); + hdfsPortVar.setName("HDFS_PORT"); + hdfsPortVar.setValue(handlerSpec.get("HDFS_PORT")); + envVar.add(hdfsPortVar); + + V1EnvVar hdfsSourceVar = new V1EnvVar(); + hdfsSourceVar.setName("HDFS_SOURCE"); + hdfsSourceVar.setValue(handlerSpec.get("HDFS_SOURCE")); + envVar.add(hdfsSourceVar); + + V1EnvVar hdfsEnableKerberosVar = new V1EnvVar(); + hdfsEnableKerberosVar.setName("ENABLE_KERBEROS"); + hdfsEnableKerberosVar.setValue(handlerSpec.get("ENABLE_KERBEROS")); + envVar.add(hdfsEnableKerberosVar); + + V1EnvVar destMinIOHostVar = new V1EnvVar(); + destMinIOHostVar.setName("DEST_MINIO_HOST"); + destMinIOHostVar.setValue("submarine-minio-service"); + envVar.add(destMinIOHostVar); + + V1EnvVar destMinIOPortVar = new V1EnvVar(); + destMinIOPortVar.setName("DEST_MINIO_PORT"); + destMinIOPortVar.setValue("9000"); + envVar.add(destMinIOPortVar); + + V1EnvVar minIOAccessKeyVar = new V1EnvVar(); + minIOAccessKeyVar.setName("MINIO_ACCESS_KEY"); + minIOAccessKeyVar.setValue("submarine_minio"); + envVar.add(minIOAccessKeyVar); + + V1EnvVar minIOSecretKeyVar = new V1EnvVar(); + minIOSecretKeyVar.setName("MINIO_SECRET_KEY"); + minIOSecretKeyVar.setValue("submarine_minio"); + envVar.add(minIOSecretKeyVar); + + V1EnvVar fileSystemTypeVar = new V1EnvVar(); + fileSystemTypeVar.setName("FILE_SYSTEM_TYPE"); + fileSystemTypeVar.setValue(handlerSpec.get("FILE_SYSTEM_TYPE")); + envVar.add(fileSystemTypeVar); + + V1EnvVar experimentIdVar = new V1EnvVar(); + experimentIdVar.setName("EXPERIMENT_ID"); + experimentIdVar.setValue(this.experimentId); + envVar.add(experimentIdVar); + + initContainer.setEnv(envVar); + return initContainer; + } + return null; + } + + private boolean checkExperimentHanderArg(Map<String, String> handlerSpec) { + if (handlerSpec == null) + return false; + 
if (handlerSpec.get("FILE_SYSTEM_TYPE") == null) + return false; + else if (handlerSpec.get("FILE_SYSTEM_TYPE") == "HDFS") { + if ((handlerSpec.get("HDFS_HOST") == null) || (handlerSpec.get("HDFS_PORT") == null) || + (handlerSpec.get("HDFS_SOURCE") == null) || (handlerSpec.get("ENABLE_KERBEROS") == null)) { + return false; + } + } + return true; + } + /** * Convert MLJob object to return Experiment object */ diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJob.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJob.java index 7da916c5..7d38ed59 100644 --- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJob.java +++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJob.java @@ -22,6 +22,7 @@ package org.apache.submarine.server.submitter.k8s.model.pytorchjob; import com.google.gson.annotations.SerializedName; import io.kubernetes.client.custom.V1Patch; import io.kubernetes.client.openapi.ApiException; +import io.kubernetes.client.openapi.models.V1Container; import io.kubernetes.client.openapi.models.V1PodTemplateSpec; import io.kubernetes.client.openapi.models.V1Status; import io.kubernetes.client.util.generic.options.CreateOptions; @@ -78,6 +79,7 @@ public class PyTorchJob extends MLJob { throws InvalidSpecException { PyTorchJobSpec pyTorchJobSpec = new PyTorchJobSpec(); + V1Container initContainer = this.getExperimentHandlerContainer(experimentSpec); Map<PyTorchJobReplicaType, MLJobReplicaSpec> replicaSpecMap = new HashMap<>(); for (Map.Entry<String, ExperimentTaskSpec> entry : experimentSpec.getSpec().entrySet()) { String replicaType = entry.getKey(); @@ -86,6 +88,11 @@ public class PyTorchJob extends MLJob { MLJobReplicaSpec 
replicaSpec = new MLJobReplicaSpec(); replicaSpec.setReplicas(taskSpec.getReplicas()); V1PodTemplateSpec podTemplateSpec = ExperimentSpecParser.parseTemplateSpec(taskSpec, experimentSpec); + + if (initContainer != null && replicaType.equals("Master")) { + podTemplateSpec.getSpec().addInitContainersItem(initContainer); + } + replicaSpec.setTemplate(podTemplateSpec); replicaSpecMap.put(PyTorchJobReplicaType.valueOf(replicaType), replicaSpec); } else { diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java index ee9d0d9f..1b9fe5a3 100644 --- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java +++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java @@ -23,6 +23,7 @@ import com.google.gson.annotations.SerializedName; import io.kubernetes.client.custom.V1Patch; import io.kubernetes.client.openapi.ApiException; +import io.kubernetes.client.openapi.models.V1Container; import io.kubernetes.client.openapi.models.V1PodTemplateSpec; import io.kubernetes.client.openapi.models.V1Status; import io.kubernetes.client.util.generic.options.CreateOptions; @@ -68,6 +69,17 @@ public class TFJob extends MLJob { setGroup(CRD_TF_GROUP_V1); // set spec setSpec(parseTFJobSpec(experimentSpec)); + + V1Container initContainer = this.getExperimentHandlerContainer(experimentSpec); + if (initContainer != null) { + Map<TFJobReplicaType, MLJobReplicaSpec> replicaSet = this.getSpec().getReplicaSpecs(); + if (replicaSet.keySet().contains(TFJobReplicaType.Ps)) { + MLJobReplicaSpec psSpec = replicaSet.get(TFJobReplicaType.Ps); + psSpec.getTemplate().getSpec().addInitContainersItem(initContainer); + } else { + throw new 
InvalidSpecException("PreHandler only support TFJob with PS for now"); + } + } } @Override diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJob.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJob.java index 0bacbfde..740b868c 100644 --- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJob.java +++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJob.java @@ -22,6 +22,7 @@ package org.apache.submarine.server.submitter.k8s.model.xgboostjob; import com.google.gson.annotations.SerializedName; import io.kubernetes.client.custom.V1Patch; import io.kubernetes.client.openapi.ApiException; +import io.kubernetes.client.openapi.models.V1Container; import io.kubernetes.client.openapi.models.V1PodTemplateSpec; import io.kubernetes.client.openapi.models.V1Status; import io.kubernetes.client.util.generic.options.CreateOptions; @@ -75,11 +76,16 @@ public class XGBoostJob extends MLJob { for (Map.Entry<String, ExperimentTaskSpec> entry : experimentSpec.getSpec().entrySet()) { String replicaType = entry.getKey(); ExperimentTaskSpec taskSpec = entry.getValue(); - + V1Container initContainer = this.getExperimentHandlerContainer(experimentSpec); if (XGBoostJobReplicaType.isSupportedReplicaType(replicaType)) { MLJobReplicaSpec replicaSpec = new MLJobReplicaSpec(); replicaSpec.setReplicas(taskSpec.getReplicas()); V1PodTemplateSpec podTemplateSpec = ExperimentSpecParser.parseTemplateSpec(taskSpec, experimentSpec); + + if (initContainer != null && replicaType.equals("Master")) { + podTemplateSpec.getSpec().addInitContainersItem(initContainer); + } + replicaSpec.setTemplate(podTemplateSpec); replicaSpecMap.put(XGBoostJobReplicaType.valueOf(replicaType), 
replicaSpec); } else { diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java b/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java index e6f8b344..acbf48d5 100644 --- a/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java +++ b/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java @@ -27,6 +27,8 @@ import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; import io.kubernetes.client.openapi.models.V1ObjectMeta; import io.kubernetes.client.openapi.models.V1Volume; @@ -98,6 +100,7 @@ public class ExperimentSpecParserTest extends SpecBuilder { validateReplicaSpec(experimentSpec, tfJob, TFJobReplicaType.Ps); validateReplicaSpec(experimentSpec, tfJob, TFJobReplicaType.Worker); + validateExperimentHandlerMetadata(experimentSpec, tfJob); } @Test @@ -142,6 +145,7 @@ public class ExperimentSpecParserTest extends SpecBuilder { validateReplicaSpec(experimentSpec, pyTorchJob, PyTorchJobReplicaType.Master); validateReplicaSpec(experimentSpec, pyTorchJob, PyTorchJobReplicaType.Worker); + validateExperimentHandlerMetadata(experimentSpec, pyTorchJob); } @Test @@ -183,6 +187,7 @@ public class ExperimentSpecParserTest extends SpecBuilder { validateReplicaSpec(experimentSpec, xgboostJob, XGBoostJobReplicaType.Master); validateReplicaSpec(experimentSpec, xgboostJob, XGBoostJobReplicaType.Worker); + validateExperimentHandlerMetadata(experimentSpec, xgboostJob); } @Test @@ -218,7 +223,44 @@ public class ExperimentSpecParserTest extends SpecBuilder { Assert.assertEquals(K8sUtils.getNamespace(), actualMeta.getNamespace()); 
Assert.assertEquals(expectedMeta.getFramework().toLowerCase(), actualFramework); } - + + private void validateExperimentHandlerMetadata(ExperimentSpec experimentSpec, + MLJob mlJob) { + + if (experimentSpec.getExperimentHandlerSpec() == null || + experimentSpec.getExperimentHandlerSpec().isEmpty()) { + return; + } + + V1Container initContainer = null; + + MLJobReplicaSpec mlJobReplicaSpec = null; + if (mlJob instanceof PyTorchJob) { + mlJobReplicaSpec = ((PyTorchJob) mlJob).getSpec() + .getReplicaSpecs().get(PyTorchJobReplicaType.Master); + } else if (mlJob instanceof TFJob) { + mlJobReplicaSpec = ((TFJob) mlJob).getSpec() + .getReplicaSpecs().get(TFJobReplicaType.Ps); + } else if (mlJob instanceof XGBoostJob) { + mlJobReplicaSpec = ((XGBoostJob) mlJob).getSpec() + .getReplicaSpecs().get(XGBoostJobReplicaType.Master); + } + initContainer = mlJobReplicaSpec.getTemplate().getSpec().getInitContainers().get(0); + Map<String, String> varMap = initContainer.getEnv().stream() + .collect(Collectors.toMap(V1EnvVar::getName, V1EnvVar::getValue)); + Assert.assertEquals(experimentSpec.getExperimentHandlerSpec().get("FILE_SYSTEM_TYPE") + , varMap.get("FILE_SYSTEM_TYPE")); + Assert.assertEquals(experimentSpec.getExperimentHandlerSpec().get("HDFS_HOST") + , varMap.get("HDFS_HOST")); + Assert.assertEquals(experimentSpec.getExperimentHandlerSpec().get("HDFS_PORT") + , varMap.get("HDFS_PORT")); + Assert.assertEquals(experimentSpec.getExperimentHandlerSpec().get("HDFS_SOURCE") + , varMap.get("HDFS_SOURCE")); + Assert.assertEquals(experimentSpec.getExperimentHandlerSpec().get("ENABLE_KERBEROS") + , varMap.get("ENABLE_KERBEROS")); + Assert.assertEquals(mlJob.getExperimentId(), varMap.get("EXPERIMENT_ID")); + } + private void validateReplicaSpec(ExperimentSpec experimentSpec, MLJob mlJob, MLJobReplicaType type) { MLJobReplicaSpec mlJobReplicaSpec = null; diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/resources/pytorch_job_req.json 
b/submarine-server/server-submitter/submitter-k8s/src/test/resources/pytorch_job_req.json index ed1828fa..69b1101d 100644 --- a/submarine-server/server-submitter/submitter-k8s/src/test/resources/pytorch_job_req.json +++ b/submarine-server/server-submitter/submitter-k8s/src/test/resources/pytorch_job_req.json @@ -22,5 +22,12 @@ "replicas": 2, "resources": "cpu=1,memory=1024M" } + }, + "experimentHandlerSpec": { + "FILE_SYSTEM_TYPE": "HDFS", + "HDFS_HOST": "127.0.0.1", + "HDFS_PORT": "9000", + "HDFS_SOURCE": "/tmp", + "ENABLE_KERBEROS": "false" } } diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/resources/tf_mnist_req.json b/submarine-server/server-submitter/submitter-k8s/src/test/resources/tf_mnist_req.json index 2c806ddc..80ac9646 100644 --- a/submarine-server/server-submitter/submitter-k8s/src/test/resources/tf_mnist_req.json +++ b/submarine-server/server-submitter/submitter-k8s/src/test/resources/tf_mnist_req.json @@ -20,5 +20,12 @@ "replicas": 2, "resources": "cpu=2,memory=1024M,nvidia.com/gpu=1" } + }, + "experimentHandlerSpec": { + "FILE_SYSTEM_TYPE": "HDFS", + "HDFS_HOST": "127.0.0.1", + "HDFS_PORT": "9000", + "HDFS_SOURCE": "/tmp", + "ENABLE_KERBEROS": "false" } } diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/resources/xgboost_job_req.json b/submarine-server/server-submitter/submitter-k8s/src/test/resources/xgboost_job_req.json index c4ba97f4..7498ba6c 100644 --- a/submarine-server/server-submitter/submitter-k8s/src/test/resources/xgboost_job_req.json +++ b/submarine-server/server-submitter/submitter-k8s/src/test/resources/xgboost_job_req.json @@ -22,5 +22,12 @@ "replicas": 2, "resources": "cpu=1,memory=1024M" } + }, + "experimentHandlerSpec": { + "FILE_SYSTEM_TYPE": "HDFS", + "HDFS_HOST": "127.0.0.1", + "HDFS_PORT": "9000", + "HDFS_SOURCE": "/tmp", + "ENABLE_KERBEROS": "false" } } --------------------------------------------------------------------- To unsubscribe, e-mail: dev-unsubscr...@submarine.apache.org For 
additional commands, e-mail: dev-h...@submarine.apache.org