This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 9dd3b70  [FLINK-26404] Support non-local file systems
9dd3b70 is described below

commit 9dd3b70e29e936ace755ff607364f21fb332b9be
Author: Mr-Mu <[email protected]>
AuthorDate: Wed Mar 9 23:05:20 2022 -0600

    [FLINK-26404] Support non-local file systems
    
    This closes #68.
---
 flink-ml-core/pom.xml                              | 62 +++++++++++++++
 .../org/apache/flink/ml/util/ReadWriteUtils.java   | 60 +++++++--------
 .../apache/flink/ml/util/ReadWriteUtilsTest.java   | 90 ++++++++++++++++++++++
 3 files changed, 182 insertions(+), 30 deletions(-)

diff --git a/flink-ml-core/pom.xml b/flink-ml-core/pom.xml
index 73bc204..c351270 100644
--- a/flink-ml-core/pom.xml
+++ b/flink-ml-core/pom.xml
@@ -31,6 +31,10 @@ under the License.
   <artifactId>flink-ml-core_${scala.binary.version}</artifactId>
   <name>Flink ML : Core</name>
 
+  <properties>
+    <hadoop.version>2.4.1</hadoop.version>
+  </properties>
+
   <dependencies>
     <dependency>
       <groupId>org.apache.flink</groupId>
@@ -108,5 +112,63 @@ under the License.
       <type>jar</type>
       <scope>test</scope>
     </dependency>
+
+    <!-- hdfs is required for the data cache test -->
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-hdfs</artifactId>
+      <scope>test</scope>
+      <version>${hadoop.version}</version>
+      <exclusions>
+        <exclusion>
+          <groupId>log4j</groupId>
+          <artifactId>log4j</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-common</artifactId>
+      <scope>test</scope>
+      <version>${hadoop.version}</version>
+      <exclusions>
+        <exclusion>
+          <groupId>log4j</groupId>
+          <artifactId>log4j</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.slf4j</groupId>
+          <artifactId>slf4j-log4j12</artifactId>
+        </exclusion>
+        <exclusion>
+          <!-- This dependency is no longer shipped with the JDK since Java 9.-->
+          <groupId>jdk.tools</groupId>
+          <artifactId>jdk.tools</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-minicluster</artifactId>
+      <scope>test</scope>
+      <version>${hadoop.version}</version>
+      <exclusions>
+        <exclusion>
+          <groupId>log4j</groupId>
+          <artifactId>log4j</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.slf4j</groupId>
+          <artifactId>slf4j-log4j12</artifactId>
+        </exclusion>
+        <exclusion>
+          <!-- This dependency is no longer shipped with the JDK since Java 9.-->
+          <groupId>jdk.tools</groupId>
+          <artifactId>jdk.tools</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
   </dependencies>
 </project>
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
index eedc74d..674440b 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
@@ -24,6 +24,8 @@ import org.apache.flink.api.connector.source.Source;
 import org.apache.flink.connector.file.sink.FileSink;
 import org.apache.flink.connector.file.src.FileSource;
 import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
 import org.apache.flink.ml.api.Stage;
 import org.apache.flink.ml.builder.Graph;
 import org.apache.flink.ml.builder.GraphData;
@@ -42,13 +44,11 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMap
 import java.io.BufferedReader;
 import java.io.BufferedWriter;
 import java.io.File;
-import java.io.FileReader;
-import java.io.FileWriter;
 import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.OutputStreamWriter;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
-import java.nio.file.Path;
-import java.nio.file.Paths;
 import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.HashMap;
@@ -92,7 +92,7 @@ public class ReadWriteUtils {
     public static void saveMetadata(Stage<?> stage, String path, Map<String, ?> extraMetadata)
             throws IOException {
         // Creates parent directories if not already created.
-        new File(path).mkdirs();
+        FileSystem fs = mkdirs(path);
 
         Map<String, Object> metadata = new HashMap<>(extraMetadata);
         metadata.put("className", stage.getClass().getName());
@@ -101,11 +101,14 @@ public class ReadWriteUtils {
         // TODO: add version in the metadata.
         String metadataStr = OBJECT_MAPPER.writeValueAsString(metadata);
 
-        File metadataFile = new File(path, "metadata");
-        if (!metadataFile.createNewFile()) {
-            throw new IOException("File " + metadataFile.toString() + " already exists.");
+        Path metadataPath = new Path(path, "metadata");
+        if (fs.exists(metadataPath)) {
+            throw new IOException("File " + metadataPath + " already exists.");
         }
-        try (BufferedWriter writer = new BufferedWriter(new FileWriter(metadataFile))) {
+        try (BufferedWriter writer =
+                new BufferedWriter(
+                        new OutputStreamWriter(
+                                fs.create(metadataPath, FileSystem.WriteMode.NO_OVERWRITE)))) {
             writer.write(metadataStr);
         }
     }
@@ -125,20 +128,7 @@ public class ReadWriteUtils {
 
     /** Returns a subdirectory of the given path for saving/loading model data. */
     private static String getDataPath(String path) {
-        return Paths.get(path, "data").toString();
-    }
-
-    /** Returns all data files under the given path as a list of paths. */
-    private static org.apache.flink.core.fs.Path[] getDataPaths(String path) {
-        String dataPath = getDataPath(path);
-        File[] files = new File(dataPath).listFiles();
-
-        org.apache.flink.core.fs.Path[] paths = new org.apache.flink.core.fs.Path[files.length];
-        for (int i = 0; i < paths.length; i++) {
-            paths[i] = org.apache.flink.core.fs.Path.fromLocalFile(files[i]);
-        }
-
-        return paths;
+        return new Path(path, "data").toString();
     }
 
     /**
@@ -153,9 +143,11 @@ public class ReadWriteUtils {
      */
     public static Map<String, ?> loadMetadata(String path, String expectedClassName)
             throws IOException {
-        Path metadataPath = Paths.get(path, "metadata");
+        Path metadataPath = new Path(path, "metadata");
+        FileSystem fs = metadataPath.getFileSystem();
+
         StringBuilder buffer = new StringBuilder();
-        try (BufferedReader br = new BufferedReader(new FileReader(metadataPath.toString()))) {
+        try (BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(metadataPath)))) {
             String line;
             while ((line = br.readLine()) != null) {
                 if (!line.startsWith("#")) {
@@ -184,9 +176,10 @@ public class ReadWriteUtils {
     // with zero or more `0` to have the same length as numStages. The resulting string can be
     // used as the directory to save a stage of the Pipeline or PipelineModel.
     private static String getPathForPipelineStage(int stageIdx, int numStages, String parentPath) {
-        String format = String.format("%%0%dd", String.valueOf(numStages).length());
+        String format =
+                String.format("stages%s%%0%dd", File.separator, String.valueOf(numStages).length());
         String fileName = String.format(format, stageIdx);
-        return Paths.get(parentPath, "stages", fileName).toString();
+        return new Path(parentPath, fileName).toString();
     }
 
     /**
@@ -199,7 +192,7 @@ public class ReadWriteUtils {
     public static void savePipeline(Stage<?> pipeline, List<Stage<?>> stages, String path)
             throws IOException {
         // Creates parent directories if not already created.
-        new File(path).mkdirs();
+        mkdirs(path);
 
         Map<String, Object> extraMetadata = new HashMap<>();
         extraMetadata.put("numStages", stages.size());
@@ -237,6 +230,13 @@ public class ReadWriteUtils {
         return stages;
     }
 
+    private static FileSystem mkdirs(String path) throws IOException {
+        Path temp = new Path(path);
+        FileSystem fs = temp.getFileSystem();
+        fs.mkdirs(temp);
+        return fs;
+    }
+
     /**
      * Saves a Graph or GraphModel with the given GraphData to the given path.
      *
@@ -247,7 +247,7 @@ public class ReadWriteUtils {
     public static void saveGraph(Stage<?> graph, GraphData graphData, String path)
             throws IOException {
         // Creates parent directories if not already created.
-        new File(path).mkdirs();
+        mkdirs(path);
 
         Map<String, Object> extraMetadata = new HashMap<>();
         extraMetadata.put("graphData", graphData.toMap());
@@ -432,7 +432,7 @@ public class ReadWriteUtils {
     public static <T> DataStream<T> loadModelData(
             StreamExecutionEnvironment env, String path, SimpleStreamFormat<T> modelDecoder) {
         Source<T, ?, ?> source =
-                FileSource.forRecordStreamFormat(modelDecoder, getDataPaths(path)).build();
+                FileSource.forRecordStreamFormat(modelDecoder, new Path(getDataPath(path))).build();
         return env.fromSource(source, WatermarkStrategy.noWatermarks(), "modelData");
     }
 }
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/util/ReadWriteUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/util/ReadWriteUtilsTest.java
new file mode 100644
index 0000000..e050647
--- /dev/null
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/util/ReadWriteUtilsTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.ml.api.ExampleStages;
+import org.apache.flink.ml.api.TestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hdfs.MiniDFSCluster;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+/** Tests {@link ReadWriteUtils}. */
+public class ReadWriteUtilsTest extends AbstractTestBase {
+
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private MiniDFSCluster hdfsCluster;
+
+    @Before
+    public void before() throws IOException {
+
+        org.apache.flink.configuration.Configuration config =
+                new org.apache.flink.configuration.Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Configuration conf = new Configuration();
+        conf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, tempFolder.newFolder().getAbsolutePath());
+        MiniDFSCluster.Builder builder = new MiniDFSCluster.Builder(conf);
+        hdfsCluster = builder.build();
+    }
+
+    @After
+    public void after() {
+        hdfsCluster.shutdown();
+    }
+
+    @Test
+    public void testModelSaveLoad() throws Exception {
+        // Builds a SumModel that increments input value by 10.
+        ExampleStages.SumModel model =
+                new ExampleStages.SumModel().setModelData(tEnv.fromValues(10));
+        List<List<Integer>> inputs = Collections.singletonList(Collections.singletonList(1));
+        List<Integer> output = Collections.singletonList(11);
+
+        // Save and load the model.
+        String path = "hdfs://localhost:" + hdfsCluster.getNameNodePort() + "/sumModel";
+        model.save(path);
+        env.execute();
+
+        ExampleStages.SumModel loadedModel = ExampleStages.SumModel.load(env, path);
+        // Executes the loaded SumModel and verifies that it produces the expected output.
+        TestUtils.executeAndCheckOutput(env, loadedModel, inputs, output, null, null);
+    }
+}

Reply via email to