This is an automated email from the ASF dual-hosted git repository.

cdmikechen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git


The following commit(s) were added to refs/heads/master by this push:
     new 3d2a3e67 SUBMARINE-1283. copy data for experiment before it running via distcp to minio
3d2a3e67 is described below

commit 3d2a3e67530d2836ae9be8e5f9549bb30dcc53ca
Author: FatalLin <fatal...@gmail.com>
AuthorDate: Tue Sep 20 22:16:50 2022 +0800

    SUBMARINE-1283. copy data for experiment before it running via distcp to minio
    
    ### What is this PR for?
    This is a prototype of the experiment prehandler. Once the required arguments have been set, Submarine adds an init container to the main pod. The init container copies the source data to MinIO under s3a://submarine/experiment/${experimentId} (see the example spec after this note).
    Note: the init container is added under:
    TFJob: ps
    PyTorchJob: master
    XGBoostJob: master
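
    For reference, the handler spec block added to the test resources in this patch looks like the following (the host, port, and source values are the ones used in the tests and are illustrative only):

        "experimentHandlerSpec": {
          "FILE_SYSTEM_TYPE": "HDFS",
          "HDFS_HOST": "127.0.0.1",
          "HDFS_PORT": "9000",
          "HDFS_SOURCE": "/tmp",
          "ENABLE_KERBEROS": "false"
        }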
    
    ### What type of PR is it?
    Feature
    
    ### Todos
    Add a housekeeping container to clean up the copied data.
    
    ### What is the Jira issue?
    https://issues.apache.org/jira/browse/SUBMARINE-1283
    ### How should this be tested?
    Another test should be added for it.
    ### Screenshots (if appropriate)
    
    ### Questions:
    * Do the license files need updating? No
    * Are there breaking changes for older versions? No
    * Does this need new documentation? Yes
    
    Author: FatalLin <fatal...@gmail.com>
    
    Signed-off-by: Xiang Chen <cdmikec...@apache.org>
    
    Closes #989 from FatalLin/SUBMARINE-1283 and squashes the following commits:
    
    ef804903 [FatalLin] fix conflict
    d816b880 [FatalLin] modify test cases
    171da696 [FatalLin] code polish
    36297f31 [FatalLin] polish code
    2c62fb7c [FatalLin] fix
    17135f2a [FatalLin] fix script
    314a3866 [FatalLin] fix build script
    028eff7d [FatalLin] fix script
    79d91c88 [FatalLin] prototype of experiment prehandler
    eaa25e1d [FatalLin] Merge branch 'master' of https://github.com/apache/submarine into SUBMARINE-1283
    d81841b1 [FatalLin] Merge branch 'master' of https://github.com/apache/submarine into SUBMARINE-1283
    44473203 [FatalLin] for debugging
---
 .github/scripts/build-image-locally-v3.sh          |  4 +-
 .github/scripts/build-image-locally.sh             |  4 +-
 .../docker-images/experiment-prehandler/Dockerfile |  3 +-
 .../docker-images/experiment-prehandler/build.sh   | 10 +++
 .../fs_prehandler/hdfs_prehandler.py               | 42 +++++++----
 .../submarine/server/api/spec/ExperimentSpec.java  |  9 +++
 .../server/submitter/k8s/model/mljob/MLJob.java    | 86 ++++++++++++++++++++++
 .../submitter/k8s/model/pytorchjob/PyTorchJob.java |  7 ++
 .../server/submitter/k8s/model/tfjob/TFJob.java    | 12 +++
 .../submitter/k8s/model/xgboostjob/XGBoostJob.java |  8 +-
 .../submitter/k8s/ExperimentSpecParserTest.java    | 44 ++++++++++-
 .../src/test/resources/pytorch_job_req.json        |  7 ++
 .../src/test/resources/tf_mnist_req.json           |  7 ++
 .../src/test/resources/xgboost_job_req.json        |  7 ++
 14 files changed, 231 insertions(+), 19 deletions(-)

diff --git a/.github/scripts/build-image-locally-v3.sh b/.github/scripts/build-image-locally-v3.sh
index 81203af7..00fdb770 100755
--- a/.github/scripts/build-image-locally-v3.sh
+++ b/.github/scripts/build-image-locally-v3.sh
@@ -17,12 +17,14 @@
 #
 
 SUBMARINE_VERSION="0.8.0-SNAPSHOT"
-FOLDER_LIST=("database" "mlflow" "submarine" "operator-v3")
+FOLDER_LIST=("database" "mlflow" "submarine" "operator-v3" "agent" "experiment-prehandler")
 IMAGE_LIST=(
   "apache/submarine:database-${SUBMARINE_VERSION}"
   "apache/submarine:mlflow-${SUBMARINE_VERSION}"
   "apache/submarine:server-${SUBMARINE_VERSION}"
   "apache/submarine:operator-${SUBMARINE_VERSION}"
+  "apache/submarine:agent-${SUBMARINE_VERSION}"
+  "apache/submarine:experiment-prehandler-${SUBMARINE_VERSION}"
 )
 
 for i in "${!IMAGE_LIST[@]}"
diff --git a/.github/scripts/build-image-locally.sh b/.github/scripts/build-image-locally.sh
index a53de5b3..4d35d690 100755
--- a/.github/scripts/build-image-locally.sh
+++ b/.github/scripts/build-image-locally.sh
@@ -17,12 +17,14 @@
 #
 
 SUBMARINE_VERSION="0.8.0-SNAPSHOT"
-FOLDER_LIST=("database" "mlflow" "submarine" "operator")
+FOLDER_LIST=("database" "mlflow" "submarine" "operator" "agent" "experiment-prehandler")
 IMAGE_LIST=(
   "apache/submarine:database-${SUBMARINE_VERSION}"
   "apache/submarine:mlflow-${SUBMARINE_VERSION}"
   "apache/submarine:server-${SUBMARINE_VERSION}"
   "apache/submarine:operator-${SUBMARINE_VERSION}"
+  "apache/submarine:agent-${SUBMARINE_VERSION}"
+  "apache/submarine:experiment-prehandler-${SUBMARINE_VERSION}"
 )
 
 for i in "${!IMAGE_LIST[@]}"
diff --git a/dev-support/docker-images/experiment-prehandler/Dockerfile b/dev-support/docker-images/experiment-prehandler/Dockerfile
index 87307d07..7a6c7e69 100644
--- a/dev-support/docker-images/experiment-prehandler/Dockerfile
+++ b/dev-support/docker-images/experiment-prehandler/Dockerfile
@@ -21,7 +21,8 @@ RUN apt-get -y install python3 python3-pip bash tini
 
 ADD ./tmp/hadoop-3.3.3.tar.gz /opt/
 ADD ./tmp/submarine-experiment-prehandler /opt/submarine-experiment-prehandler
-
+ADD ./tmp/hadoop-aws-3.3.3.jar /opt/hadoop-3.3.3/share/hadoop/hdfs
+ADD ./tmp/aws-java-sdk-bundle-1.12.267.jar /opt/hadoop-3.3.3/share/hadoop/hdfs
 
 ENV HADOOP_HOME=/opt/hadoop-3.3.3
 ENV ARROW_LIBHDFS_DIR=/opt/hadoop-3.3.3/lib/native
diff --git a/dev-support/docker-images/experiment-prehandler/build.sh b/dev-support/docker-images/experiment-prehandler/build.sh
index fcdedb05..c1e94a37 100755
--- a/dev-support/docker-images/experiment-prehandler/build.sh
+++ b/dev-support/docker-images/experiment-prehandler/build.sh
@@ -19,6 +19,12 @@ set -euxo pipefail
 SUBMARINE_VERSION=0.8.0-SNAPSHOT
 
 SUBMARINE_IMAGE_NAME="apache/submarine:experiment-prehandler-${SUBMARINE_VERSION}"
 
+if [ -L ${BASH_SOURCE-$0} ]; then
+  PWD=$(dirname $(readlink "${BASH_SOURCE-$0}"))
+else
+  PWD=$(dirname ${BASH_SOURCE-$0})
+fi
+
 export CURRENT_PATH=$(cd "${PWD}">/dev/null; pwd)
 export SUBMARINE_HOME=${CURRENT_PATH}/../../..
 
@@ -33,7 +39,11 @@ trap "test -f $tmpfile && rm $tmpfile" RETURN
 curl -L -o $tmpfile ${HADOOP_TAR_URL}
 mv $tmpfile ${CURRENT_PATH}/tmp/hadoop-3.3.3.tar.gz
 
+curl -L -o ${CURRENT_PATH}/tmp/hadoop-aws-3.3.3.jar https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.3.3/hadoop-aws-3.3.3.jar
 
+curl -L -o ${CURRENT_PATH}/tmp/aws-java-sdk-bundle-1.12.267.jar https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-bundle/1.12.267/aws-java-sdk-bundle-1.12.267.jar
+
 echo "Start building the ${SUBMARINE_IMAGE_NAME} docker image ..."
+cd ${CURRENT_PATH}
 docker build -t ${SUBMARINE_IMAGE_NAME} .
 
 # clean temp file
diff --git a/submarine-experiment-prehandler/fs_prehandler/hdfs_prehandler.py b/submarine-experiment-prehandler/fs_prehandler/hdfs_prehandler.py
index 1a13f1b2..c138b61b 100644
--- a/submarine-experiment-prehandler/fs_prehandler/hdfs_prehandler.py
+++ b/submarine-experiment-prehandler/fs_prehandler/hdfs_prehandler.py
@@ -15,6 +15,7 @@
 
 import logging
 import os
+import subprocess
 
 from fs_prehandler import FsPreHandler
 from fsspec.implementations.arrow import HadoopFileSystem
@@ -22,23 +23,36 @@ from fsspec.implementations.arrow import HadoopFileSystem
 
 class HDFSPreHandler(FsPreHandler):
     def __init__(self):
-        self.hdfs_host = os.environ['HDFS_HOST']
-        self.hdfs_port = int(os.environ['HDFS_PORT'])
-        self.hdfs_source = os.environ['HDFS_SOURCE']
-        self.dest_path = os.environ['DEST_PATH']
-        self.enable_kerberos = os.environ['ENABLE_KERBEROS']
+        self.hdfs_host = os.environ['HDFS_HOST']
+        self.hdfs_port = os.environ['HDFS_PORT']
+        self.hdfs_source = os.environ['HDFS_SOURCE']
+        self.enable_kerberos = os.environ['ENABLE_KERBEROS']
+        self.hadoop_home = os.environ['HADOOP_HOME']
+        self.dest_minio_host = os.environ['DEST_MINIO_HOST']
+        self.dest_minio_port = os.environ['DEST_MINIO_PORT']
+        self.minio_access_key = os.environ['MINIO_ACCESS_KEY']
+        self.minio_secret_key = os.environ['MINIO_SECRET_KEY']
+        self.experiment_id = os.environ['EXPERIMENT_ID']
 
         logging.info('HDFS_HOST:%s' % self.hdfs_host)
-        logging.info('HDFS_PORT:%d' % self.hdfs_port)
+        logging.info('HDFS_PORT:%s' % self.hdfs_port)
         logging.info('HDFS_SOURCE:%s' % self.hdfs_source)
-        logging.info('DEST_PATH:%s' % self.dest_path)
+        logging.info('DEST_MINIO_HOST:%s' % self.dest_minio_host)
+        logging.info('DEST_MINIO_PORT:%s' % self.dest_minio_port)
         logging.info('ENABLE_KERBEROS:%s' % self.enable_kerberos)
-
-        self.fs = HadoopFileSystem(host=self.hdfs_host, port=self.hdfs_port)
+        logging.info('EXPERIMENT_ID:%s' % self.experiment_id)
 
     def process(self):
-        self.fs.get(self.hdfs_source, self.dest_path, recursive=True)
-        logging.info(
-            'fetch data from hdfs://%s:%d/%s to %s complete'
-            % (self.hdfs_host, self.hdfs_port, self.hdfs_source, self.dest_path)
-        )
+        dest_path = 'submarine/experiment/' + self.experiment_id
+        p = subprocess.run([self.hadoop_home + '/bin/hadoop', 'distcp',
+            '-Dfs.s3a.endpoint=http://' + self.dest_minio_host + ':' + self.dest_minio_port + '/',
+            '-Dfs.s3a.access.key=' + self.minio_access_key,
+            '-Dfs.s3a.secret.key=' + self.minio_secret_key,
+            '-Dfs.s3a.path.style.access=true',
+            'hdfs://' + self.hdfs_host + ':' + self.hdfs_port + '/' + self.hdfs_source,
+            's3a://' + dest_path])
+
+        if p.returncode == 0:
+            logging.info('fetch data from hdfs://%s:%s/%s to %s complete' % (self.hdfs_host, self.hdfs_port, self.hdfs_source, dest_path))
+        else:
+            raise Exception('error occurred when fetching data from hdfs://%s:%s/%s to %s' % (self.hdfs_host, self.hdfs_port, self.hdfs_source, dest_path))
diff --git a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentSpec.java b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentSpec.java
index b0c283ab..8c3024d6 100644
--- a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentSpec.java
+++ b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentSpec.java
@@ -28,6 +28,7 @@ public class ExperimentSpec {
   private ExperimentMeta meta;
   private EnvironmentSpec environment;
   private Map<String, ExperimentTaskSpec> spec;
+  private Map<String, String> experimentHandlerSpec;
   private CodeSpec code;
 
   public ExperimentSpec() {}
@@ -63,6 +64,14 @@ public class ExperimentSpec {
   public void setCode(CodeSpec code) {
     this.code = code;
   }
+  
+  public Map<String, String> getExperimentHandlerSpec() {
+    return experimentHandlerSpec;
+  }
+
+  public void setExperimentHandlerSpec(Map<String, String> experimentHandlerSpec) {
+    this.experimentHandlerSpec = experimentHandlerSpec;
+  }
 
   @Override
   public String toString() {
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/mljob/MLJob.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/mljob/MLJob.java
index 45c174cb..f026bc3d 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/mljob/MLJob.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/mljob/MLJob.java
@@ -22,10 +22,17 @@ package org.apache.submarine.server.submitter.k8s.model.mljob;
 import com.google.gson.JsonSyntaxException;
 import com.google.gson.annotations.SerializedName;
 import io.kubernetes.client.common.KubernetesObject;
+import io.kubernetes.client.openapi.models.V1Container;
+import io.kubernetes.client.openapi.models.V1EnvVar;
 import io.kubernetes.client.openapi.models.V1JobStatus;
 import io.kubernetes.client.openapi.models.V1ObjectMeta;
 import io.kubernetes.client.openapi.models.V1ObjectMetaBuilder;
 import io.kubernetes.client.openapi.models.V1Status;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
 import org.apache.submarine.commons.utils.exception.SubmarineRuntimeException;
 import org.apache.submarine.server.api.common.CustomResourceType;
 import org.apache.submarine.server.api.experiment.Experiment;
@@ -205,6 +212,85 @@ public abstract class MLJob implements KubernetesObject, K8sResource<Experiment>
     this.experimentId = experimentId;
   }
 
+  public V1Container getExperimentHandlerContainer(ExperimentSpec spec) {
+    Map<String, String> handlerSpec = spec.getExperimentHandlerSpec();
+    
+    if (checkExperimentHandlerArg(handlerSpec)) {
+      V1Container initContainer = new V1Container();
+      initContainer.setImage("apache/submarine:experiment-prehandler-0.8.0-SNAPSHOT");
+      initContainer.setName("ExperimentHandler".toLowerCase());
+      List<V1EnvVar> envVar = new ArrayList<>();
+      
+      V1EnvVar hdfsHostVar = new V1EnvVar();
+      hdfsHostVar.setName("HDFS_HOST");
+      hdfsHostVar.setValue(handlerSpec.get("HDFS_HOST"));
+      envVar.add(hdfsHostVar);
+      
+      V1EnvVar hdfsPortVar = new V1EnvVar();
+      hdfsPortVar.setName("HDFS_PORT");
+      hdfsPortVar.setValue(handlerSpec.get("HDFS_PORT"));
+      envVar.add(hdfsPortVar);
+      
+      V1EnvVar hdfsSourceVar = new V1EnvVar();
+      hdfsSourceVar.setName("HDFS_SOURCE");
+      hdfsSourceVar.setValue(handlerSpec.get("HDFS_SOURCE"));
+      envVar.add(hdfsSourceVar);
+      
+      V1EnvVar hdfsEnableKerberosVar = new V1EnvVar();
+      hdfsEnableKerberosVar.setName("ENABLE_KERBEROS");
+      hdfsEnableKerberosVar.setValue(handlerSpec.get("ENABLE_KERBEROS"));
+      envVar.add(hdfsEnableKerberosVar);
+      
+      V1EnvVar destMinIOHostVar = new V1EnvVar();
+      destMinIOHostVar.setName("DEST_MINIO_HOST");
+      destMinIOHostVar.setValue("submarine-minio-service");
+      envVar.add(destMinIOHostVar);
+      
+      V1EnvVar destMinIOPortVar = new V1EnvVar();
+      destMinIOPortVar.setName("DEST_MINIO_PORT");
+      destMinIOPortVar.setValue("9000");
+      envVar.add(destMinIOPortVar);
+      
+      V1EnvVar minIOAccessKeyVar = new V1EnvVar();
+      minIOAccessKeyVar.setName("MINIO_ACCESS_KEY");
+      minIOAccessKeyVar.setValue("submarine_minio");
+      envVar.add(minIOAccessKeyVar);
+      
+      V1EnvVar minIOSecretKeyVar = new V1EnvVar();
+      minIOSecretKeyVar.setName("MINIO_SECRET_KEY");
+      minIOSecretKeyVar.setValue("submarine_minio");
+      envVar.add(minIOSecretKeyVar);
+      
+      V1EnvVar fileSystemTypeVar = new V1EnvVar();
+      fileSystemTypeVar.setName("FILE_SYSTEM_TYPE");
+      fileSystemTypeVar.setValue(handlerSpec.get("FILE_SYSTEM_TYPE"));
+      envVar.add(fileSystemTypeVar);
+      
+      V1EnvVar experimentIdVar = new V1EnvVar();
+      experimentIdVar.setName("EXPERIMENT_ID");
+      experimentIdVar.setValue(this.experimentId);
+      envVar.add(experimentIdVar);
+      
+      initContainer.setEnv(envVar);
+      return initContainer;
+    }
+    return null;
+  }
+  
+  private boolean checkExperimentHandlerArg(Map<String, String> handlerSpec) {
+    if (handlerSpec == null)
+      return false;
+    if (handlerSpec.get("FILE_SYSTEM_TYPE") == null)
+      return false;
+    else if ("HDFS".equals(handlerSpec.get("FILE_SYSTEM_TYPE"))) {
+      if ((handlerSpec.get("HDFS_HOST") == null) || (handlerSpec.get("HDFS_PORT") == null) ||
+          (handlerSpec.get("HDFS_SOURCE") == null) || (handlerSpec.get("ENABLE_KERBEROS") == null)) {
+        return false;
+      }
+    }
+    return true;
+  }
+  
   /**
    * Convert MLJob object to return Experiment object
    */
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJob.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJob.java
index 7da916c5..7d38ed59 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJob.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/pytorchjob/PyTorchJob.java
@@ -22,6 +22,7 @@ package org.apache.submarine.server.submitter.k8s.model.pytorchjob;
 import com.google.gson.annotations.SerializedName;
 import io.kubernetes.client.custom.V1Patch;
 import io.kubernetes.client.openapi.ApiException;
+import io.kubernetes.client.openapi.models.V1Container;
 import io.kubernetes.client.openapi.models.V1PodTemplateSpec;
 import io.kubernetes.client.openapi.models.V1Status;
 import io.kubernetes.client.util.generic.options.CreateOptions;
@@ -78,6 +79,7 @@ public class PyTorchJob extends MLJob {
           throws InvalidSpecException {
     PyTorchJobSpec pyTorchJobSpec = new PyTorchJobSpec();
 
+    V1Container initContainer = this.getExperimentHandlerContainer(experimentSpec);
     Map<PyTorchJobReplicaType, MLJobReplicaSpec> replicaSpecMap = new HashMap<>();
     for (Map.Entry<String, ExperimentTaskSpec> entry : experimentSpec.getSpec().entrySet()) {
       String replicaType = entry.getKey();
@@ -86,6 +88,11 @@ public class PyTorchJob extends MLJob {
         MLJobReplicaSpec replicaSpec = new MLJobReplicaSpec();
         replicaSpec.setReplicas(taskSpec.getReplicas());
         V1PodTemplateSpec podTemplateSpec = ExperimentSpecParser.parseTemplateSpec(taskSpec, experimentSpec);
+        
+        if (initContainer != null && replicaType.equals("Master")) {
+          podTemplateSpec.getSpec().addInitContainersItem(initContainer);  
+        }
+        
         replicaSpec.setTemplate(podTemplateSpec);
         replicaSpecMap.put(PyTorchJobReplicaType.valueOf(replicaType), replicaSpec);
       } else {
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java
index ee9d0d9f..1b9fe5a3 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/tfjob/TFJob.java
@@ -23,6 +23,7 @@ import com.google.gson.annotations.SerializedName;
 
 import io.kubernetes.client.custom.V1Patch;
 import io.kubernetes.client.openapi.ApiException;
+import io.kubernetes.client.openapi.models.V1Container;
 import io.kubernetes.client.openapi.models.V1PodTemplateSpec;
 import io.kubernetes.client.openapi.models.V1Status;
 import io.kubernetes.client.util.generic.options.CreateOptions;
@@ -68,6 +69,17 @@ public class TFJob extends MLJob {
     setGroup(CRD_TF_GROUP_V1);
     // set spec
     setSpec(parseTFJobSpec(experimentSpec));
+    
+    V1Container initContainer = this.getExperimentHandlerContainer(experimentSpec);
+    if (initContainer != null) {
+      Map<TFJobReplicaType, MLJobReplicaSpec> replicaSet = this.getSpec().getReplicaSpecs();
+      if (replicaSet.keySet().contains(TFJobReplicaType.Ps)) {
+        MLJobReplicaSpec psSpec = replicaSet.get(TFJobReplicaType.Ps);
+        psSpec.getTemplate().getSpec().addInitContainersItem(initContainer);
+      } else {
+        throw new InvalidSpecException("PreHandler only supports TFJob with PS for now");
+      }
+    }
   }
 
   @Override
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJob.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJob.java
index 0bacbfde..740b868c 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJob.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJob.java
@@ -22,6 +22,7 @@ package org.apache.submarine.server.submitter.k8s.model.xgboostjob;
 import com.google.gson.annotations.SerializedName;
 import io.kubernetes.client.custom.V1Patch;
 import io.kubernetes.client.openapi.ApiException;
+import io.kubernetes.client.openapi.models.V1Container;
 import io.kubernetes.client.openapi.models.V1PodTemplateSpec;
 import io.kubernetes.client.openapi.models.V1Status;
 import io.kubernetes.client.util.generic.options.CreateOptions;
@@ -75,11 +76,16 @@ public class XGBoostJob extends MLJob {
     for (Map.Entry<String, ExperimentTaskSpec> entry : experimentSpec.getSpec().entrySet()) {
       String replicaType = entry.getKey();
       ExperimentTaskSpec taskSpec = entry.getValue();
-
+      V1Container initContainer = this.getExperimentHandlerContainer(experimentSpec);
       if (XGBoostJobReplicaType.isSupportedReplicaType(replicaType)) {
         MLJobReplicaSpec replicaSpec = new MLJobReplicaSpec();
         replicaSpec.setReplicas(taskSpec.getReplicas());
         V1PodTemplateSpec podTemplateSpec = ExperimentSpecParser.parseTemplateSpec(taskSpec, experimentSpec);
+        
+        if (initContainer != null && replicaType.equals("Master")) {
+          podTemplateSpec.getSpec().addInitContainersItem(initContainer);  
+        }
+        
         replicaSpec.setTemplate(podTemplateSpec);
         replicaSpecMap.put(XGBoostJobReplicaType.valueOf(replicaType), replicaSpec);
       } else {
diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java b/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java
index e6f8b344..acbf48d5 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java
@@ -27,6 +27,8 @@ import java.sql.SQLException;
 import java.sql.Statement;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
 
 import io.kubernetes.client.openapi.models.V1ObjectMeta;
 import io.kubernetes.client.openapi.models.V1Volume;
@@ -98,6 +100,7 @@ public class ExperimentSpecParserTest extends SpecBuilder {
 
     validateReplicaSpec(experimentSpec, tfJob, TFJobReplicaType.Ps);
     validateReplicaSpec(experimentSpec, tfJob, TFJobReplicaType.Worker);
+    validateExperimentHandlerMetadata(experimentSpec, tfJob);
   }
 
   @Test
@@ -142,6 +145,7 @@ public class ExperimentSpecParserTest extends SpecBuilder {
 
     validateReplicaSpec(experimentSpec, pyTorchJob, PyTorchJobReplicaType.Master);
     validateReplicaSpec(experimentSpec, pyTorchJob, PyTorchJobReplicaType.Worker);
+    validateExperimentHandlerMetadata(experimentSpec, pyTorchJob);
   }
 
   @Test
@@ -183,6 +187,7 @@ public class ExperimentSpecParserTest extends SpecBuilder {
 
     validateReplicaSpec(experimentSpec, xgboostJob, XGBoostJobReplicaType.Master);
     validateReplicaSpec(experimentSpec, xgboostJob, XGBoostJobReplicaType.Worker);
+    validateExperimentHandlerMetadata(experimentSpec, xgboostJob);
   }
 
   @Test
@@ -218,7 +223,44 @@ public class ExperimentSpecParserTest extends SpecBuilder {
     Assert.assertEquals(K8sUtils.getNamespace(), actualMeta.getNamespace());
     Assert.assertEquals(expectedMeta.getFramework().toLowerCase(), actualFramework);
   }
-
+  
+  private void validateExperimentHandlerMetadata(ExperimentSpec experimentSpec,
+      MLJob mlJob) {
+    
+    if (experimentSpec.getExperimentHandlerSpec() == null || 
+        experimentSpec.getExperimentHandlerSpec().isEmpty()) {
+      return;
+    }
+      
+    V1Container initContainer = null;
+    
+    MLJobReplicaSpec mlJobReplicaSpec = null;
+    if (mlJob instanceof PyTorchJob) {
+      mlJobReplicaSpec = ((PyTorchJob) mlJob).getSpec()
+        .getReplicaSpecs().get(PyTorchJobReplicaType.Master);
+    } else if (mlJob instanceof TFJob) {
+      mlJobReplicaSpec = ((TFJob) mlJob).getSpec()
+        .getReplicaSpecs().get(TFJobReplicaType.Ps);
+    } else if (mlJob instanceof XGBoostJob) {
+      mlJobReplicaSpec = ((XGBoostJob) mlJob).getSpec()
+        .getReplicaSpecs().get(XGBoostJobReplicaType.Master);
+    }
+    initContainer = mlJobReplicaSpec.getTemplate().getSpec().getInitContainers().get(0);
+    Map<String, String> varMap = initContainer.getEnv().stream()
+        .collect(Collectors.toMap(V1EnvVar::getName, V1EnvVar::getValue));
+    Assert.assertEquals(experimentSpec.getExperimentHandlerSpec().get("FILE_SYSTEM_TYPE")
+        , varMap.get("FILE_SYSTEM_TYPE"));
+    Assert.assertEquals(experimentSpec.getExperimentHandlerSpec().get("HDFS_HOST")
+        , varMap.get("HDFS_HOST"));
+    Assert.assertEquals(experimentSpec.getExperimentHandlerSpec().get("HDFS_PORT")
+        , varMap.get("HDFS_PORT"));
+    Assert.assertEquals(experimentSpec.getExperimentHandlerSpec().get("HDFS_SOURCE")
+        , varMap.get("HDFS_SOURCE"));
+    Assert.assertEquals(experimentSpec.getExperimentHandlerSpec().get("ENABLE_KERBEROS")
+        , varMap.get("ENABLE_KERBEROS"));
+    Assert.assertEquals(mlJob.getExperimentId(), varMap.get("EXPERIMENT_ID"));
+  }
+  
   private void validateReplicaSpec(ExperimentSpec experimentSpec,
       MLJob mlJob, MLJobReplicaType type) {
     MLJobReplicaSpec mlJobReplicaSpec = null;
diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/resources/pytorch_job_req.json b/submarine-server/server-submitter/submitter-k8s/src/test/resources/pytorch_job_req.json
index ed1828fa..69b1101d 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/test/resources/pytorch_job_req.json
+++ b/submarine-server/server-submitter/submitter-k8s/src/test/resources/pytorch_job_req.json
@@ -22,5 +22,12 @@
       "replicas": 2,
       "resources": "cpu=1,memory=1024M"
     }
+  },
+  "experimentHandlerSpec": {
+    "FILE_SYSTEM_TYPE": "HDFS",
+    "HDFS_HOST": "127.0.0.1",
+    "HDFS_PORT": "9000",
+    "HDFS_SOURCE": "/tmp",
+    "ENABLE_KERBEROS": "false"
   }
 }
diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/resources/tf_mnist_req.json b/submarine-server/server-submitter/submitter-k8s/src/test/resources/tf_mnist_req.json
index 2c806ddc..80ac9646 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/test/resources/tf_mnist_req.json
+++ b/submarine-server/server-submitter/submitter-k8s/src/test/resources/tf_mnist_req.json
@@ -20,5 +20,12 @@
       "replicas": 2,
       "resources": "cpu=2,memory=1024M,nvidia.com/gpu=1"
     }
+  },
+  "experimentHandlerSpec": {
+      "FILE_SYSTEM_TYPE": "HDFS",
+      "HDFS_HOST": "127.0.0.1",
+      "HDFS_PORT": "9000",
+      "HDFS_SOURCE": "/tmp",
+      "ENABLE_KERBEROS": "false"
   }
 }
diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/resources/xgboost_job_req.json b/submarine-server/server-submitter/submitter-k8s/src/test/resources/xgboost_job_req.json
index c4ba97f4..7498ba6c 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/test/resources/xgboost_job_req.json
+++ b/submarine-server/server-submitter/submitter-k8s/src/test/resources/xgboost_job_req.json
@@ -22,5 +22,12 @@
       "replicas": 2,
       "resources": "cpu=1,memory=1024M"
     }
+  },
+  "experimentHandlerSpec": {
+    "FILE_SYSTEM_TYPE": "HDFS",
+    "HDFS_HOST": "127.0.0.1",
+    "HDFS_PORT": "9000",
+    "HDFS_SOURCE": "/tmp",
+    "ENABLE_KERBEROS": "false"
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscr...@submarine.apache.org
For additional commands, e-mail: dev-h...@submarine.apache.org
