This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 9dd3b70 [FLINK-26404] Support non-local file systems
9dd3b70 is described below
commit 9dd3b70e29e936ace755ff607364f21fb332b9be
Author: Mr-Mu <[email protected]>
AuthorDate: Wed Mar 9 23:05:20 2022 -0600
[FLINK-26404] Support non-local file systems
This closes #68.
---
flink-ml-core/pom.xml | 62 +++++++++++++++
.../org/apache/flink/ml/util/ReadWriteUtils.java | 60 +++++++--------
.../apache/flink/ml/util/ReadWriteUtilsTest.java | 90 ++++++++++++++++++++++
3 files changed, 182 insertions(+), 30 deletions(-)
diff --git a/flink-ml-core/pom.xml b/flink-ml-core/pom.xml
index 73bc204..c351270 100644
--- a/flink-ml-core/pom.xml
+++ b/flink-ml-core/pom.xml
@@ -31,6 +31,10 @@ under the License.
<artifactId>flink-ml-core_${scala.binary.version}</artifactId>
<name>Flink ML : Core</name>
+ <properties>
+ <hadoop.version>2.4.1</hadoop.version>
+ </properties>
+
<dependencies>
<dependency>
<groupId>org.apache.flink</groupId>
@@ -108,5 +112,63 @@ under the License.
<type>jar</type>
<scope>test</scope>
</dependency>
+
+ <!-- hdfs is required for the data cache test -->
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-hdfs</artifactId>
+ <scope>test</scope>
+ <version>${hadoop.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-common</artifactId>
+ <scope>test</scope>
+ <version>${hadoop.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-log4j12</artifactId>
+ </exclusion>
+ <exclusion>
+ <!-- This dependency is no longer shipped with the JDK since Java 9.-->
+ <groupId>jdk.tools</groupId>
+ <artifactId>jdk.tools</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-minicluster</artifactId>
+ <scope>test</scope>
+ <version>${hadoop.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-log4j12</artifactId>
+ </exclusion>
+ <exclusion>
+ <!-- This dependency is no longer shipped with the JDK since Java 9.-->
+ <groupId>jdk.tools</groupId>
+ <artifactId>jdk.tools</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
</dependencies>
</project>
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
index eedc74d..674440b 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
@@ -24,6 +24,8 @@ import org.apache.flink.api.connector.source.Source;
import org.apache.flink.connector.file.sink.FileSink;
import org.apache.flink.connector.file.src.FileSource;
import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.builder.Graph;
import org.apache.flink.ml.builder.GraphData;
@@ -42,13 +44,11 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMap
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
-import java.io.FileReader;
-import java.io.FileWriter;
import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.OutputStreamWriter;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
-import java.nio.file.Path;
-import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
@@ -92,7 +92,7 @@ public class ReadWriteUtils {
public static void saveMetadata(Stage<?> stage, String path, Map<String, ?> extraMetadata)
throws IOException {
// Creates parent directories if not already created.
- new File(path).mkdirs();
+ FileSystem fs = mkdirs(path);
Map<String, Object> metadata = new HashMap<>(extraMetadata);
metadata.put("className", stage.getClass().getName());
@@ -101,11 +101,14 @@ public class ReadWriteUtils {
// TODO: add version in the metadata.
String metadataStr = OBJECT_MAPPER.writeValueAsString(metadata);
- File metadataFile = new File(path, "metadata");
- if (!metadataFile.createNewFile()) {
- throw new IOException("File " + metadataFile.toString() + " already exists.");
+ Path metadataPath = new Path(path, "metadata");
+ if (fs.exists(metadataPath)) {
+ throw new IOException("File " + metadataPath + " already exists.");
}
- try (BufferedWriter writer = new BufferedWriter(new FileWriter(metadataFile))) {
+ try (BufferedWriter writer =
+ new BufferedWriter(
+ new OutputStreamWriter(
+ fs.create(metadataPath, FileSystem.WriteMode.NO_OVERWRITE)))) {
writer.write(metadataStr);
}
}
@@ -125,20 +128,7 @@ public class ReadWriteUtils {
/** Returns a subdirectory of the given path for saving/loading model data. */
private static String getDataPath(String path) {
- return Paths.get(path, "data").toString();
- }
-
- /** Returns all data files under the given path as a list of paths. */
- private static org.apache.flink.core.fs.Path[] getDataPaths(String path) {
- String dataPath = getDataPath(path);
- File[] files = new File(dataPath).listFiles();
-
- org.apache.flink.core.fs.Path[] paths = new org.apache.flink.core.fs.Path[files.length];
- for (int i = 0; i < paths.length; i++) {
- paths[i] = org.apache.flink.core.fs.Path.fromLocalFile(files[i]);
- }
-
- return paths;
+ return new Path(path, "data").toString();
}
/**
@@ -153,9 +143,11 @@ public class ReadWriteUtils {
*/
public static Map<String, ?> loadMetadata(String path, String expectedClassName)
throws IOException {
- Path metadataPath = Paths.get(path, "metadata");
+ Path metadataPath = new Path(path, "metadata");
+ FileSystem fs = metadataPath.getFileSystem();
+
StringBuilder buffer = new StringBuilder();
- try (BufferedReader br = new BufferedReader(new FileReader(metadataPath.toString()))) {
+ try (BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(metadataPath)))) {
String line;
while ((line = br.readLine()) != null) {
if (!line.startsWith("#")) {
@@ -184,9 +176,10 @@ public class ReadWriteUtils {
// with zero or more `0` to have the same length as numStages. The resulting string can be
// used as the directory to save a stage of the Pipeline or PipelineModel.
private static String getPathForPipelineStage(int stageIdx, int numStages, String parentPath) {
- String format = String.format("%%0%dd", String.valueOf(numStages).length());
+ String format =
+ String.format("stages%s%%0%dd", File.separator, String.valueOf(numStages).length());
String fileName = String.format(format, stageIdx);
- return Paths.get(parentPath, "stages", fileName).toString();
+ return new Path(parentPath, fileName).toString();
}
/**
@@ -199,7 +192,7 @@ public class ReadWriteUtils {
public static void savePipeline(Stage<?> pipeline, List<Stage<?>> stages, String path)
throws IOException {
// Creates parent directories if not already created.
- new File(path).mkdirs();
+ mkdirs(path);
Map<String, Object> extraMetadata = new HashMap<>();
extraMetadata.put("numStages", stages.size());
@@ -237,6 +230,13 @@ public class ReadWriteUtils {
return stages;
}
+ private static FileSystem mkdirs(String path) throws IOException {
+ Path temp = new Path(path);
+ FileSystem fs = temp.getFileSystem();
+ fs.mkdirs(temp);
+ return fs;
+ }
+
/**
* Saves a Graph or GraphModel with the given GraphData to the given path.
*
@@ -247,7 +247,7 @@ public class ReadWriteUtils {
public static void saveGraph(Stage<?> graph, GraphData graphData, String path)
throws IOException {
// Creates parent directories if not already created.
- new File(path).mkdirs();
+ mkdirs(path);
Map<String, Object> extraMetadata = new HashMap<>();
extraMetadata.put("graphData", graphData.toMap());
@@ -432,7 +432,7 @@ public class ReadWriteUtils {
public static <T> DataStream<T> loadModelData(
StreamExecutionEnvironment env, String path, SimpleStreamFormat<T> modelDecoder) {
Source<T, ?, ?> source =
- FileSource.forRecordStreamFormat(modelDecoder, getDataPaths(path)).build();
+ FileSource.forRecordStreamFormat(modelDecoder, new Path(getDataPath(path))).build();
return env.fromSource(source, WatermarkStrategy.noWatermarks(), "modelData");
}
}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/util/ReadWriteUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/util/ReadWriteUtilsTest.java
new file mode 100644
index 0000000..e050647
--- /dev/null
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/util/ReadWriteUtilsTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.ml.api.ExampleStages;
+import org.apache.flink.ml.api.TestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hdfs.MiniDFSCluster;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+/** Tests {@link ReadWriteUtils}. */
+public class ReadWriteUtilsTest extends AbstractTestBase {
+
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private MiniDFSCluster hdfsCluster;
+
+ @Before
+ public void before() throws IOException {
+
+ org.apache.flink.configuration.Configuration config =
+ new org.apache.flink.configuration.Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+
+ Configuration conf = new Configuration();
conf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, tempFolder.newFolder().getAbsolutePath());
+ MiniDFSCluster.Builder builder = new MiniDFSCluster.Builder(conf);
+ hdfsCluster = builder.build();
+ }
+
+ @After
+ public void after() {
+ hdfsCluster.shutdown();
+ }
+
+ @Test
+ public void testModelSaveLoad() throws Exception {
+ // Builds a SumModel that increments input value by 10.
+ ExampleStages.SumModel model =
+ new ExampleStages.SumModel().setModelData(tEnv.fromValues(10));
+ List<List<Integer>> inputs = Collections.singletonList(Collections.singletonList(1));
+ List<Integer> output = Collections.singletonList(11);
+
+ // Save and load the model.
+ String path = "hdfs://localhost:" + hdfsCluster.getNameNodePort() + "/sumModel";
+ model.save(path);
+ env.execute();
+
+ ExampleStages.SumModel loadedModel = ExampleStages.SumModel.load(env, path);
+ // Executes the loaded SumModel and verifies that it produces the expected output.
+ TestUtils.executeAndCheckOutput(env, loadedModel, inputs, output, null, null);
+ }
+}
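
For readers skimming the diff: the core change is that ReadWriteUtils now goes through Flink's org.apache.flink.core.fs.FileSystem and Path instead of java.io.File and java.nio.file, so a save/load location is resolved by its URI scheme rather than assumed to be local. Below is a minimal sketch of the resulting usage, condensed from the ReadWriteUtilsTest added above; the MiniDFSCluster setup and SumModel come straight from that test, and the snippet is an illustration, not code from the commit:

    // Stand up a throwaway HDFS, as the test does (org.apache.hadoop.*).
    Configuration conf = new Configuration();
    MiniDFSCluster hdfsCluster = new MiniDFSCluster.Builder(conf).build();

    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

    // A SumModel whose model data adds 10 to each input value.
    ExampleStages.SumModel model =
            new ExampleStages.SumModel().setModelData(tEnv.fromValues(10));

    // Any scheme Flink's FileSystem can resolve now works here, not just local paths.
    String path = "hdfs://localhost:" + hdfsCluster.getNameNodePort() + "/sumModel";
    model.save(path);
    env.execute();

    ExampleStages.SumModel loadedModel = ExampleStages.SumModel.load(env, path);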