This is an automated email from the ASF dual-hosted git repository. jqin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit 9c44eef25970338fe32dcf77ce45efac74c4324f Author: Dong Lin <[email protected]> AuthorDate: Sun Sep 26 21:40:59 2021 +0800 [FLINK-24354][FLIP-174] Improve the WithParams interface --- flink-ml-api/pom.xml | 15 + .../org/apache/flink/ml/api/core/Pipeline.java | 23 +- .../apache/flink/ml/api/core/PipelineModel.java | 23 +- .../java/org/apache/flink/ml/api/core/Stage.java | 2 +- .../org/apache/flink/ml/param/BooleanParam.java | 35 ++ .../apache/flink/ml/param/DoubleArrayParam.java | 35 ++ .../org/apache/flink/ml/param/DoubleParam.java | 35 ++ .../org/apache/flink/ml/param/FloatArrayParam.java | 35 ++ .../java/org/apache/flink/ml/param/FloatParam.java | 32 ++ .../org/apache/flink/ml/param/IntArrayParam.java | 35 ++ .../java/org/apache/flink/ml/param/IntParam.java | 35 ++ .../org/apache/flink/ml/param/LongArrayParam.java | 35 ++ .../java/org/apache/flink/ml/param/LongParam.java | 32 ++ .../main/java/org/apache/flink/ml/param/Param.java | 98 ++++++ .../org/apache/flink/ml/param/ParamValidator.java | 40 +++ .../org/apache/flink/ml/param/ParamValidators.java | 98 ++++++ .../apache/flink/ml/param/StringArrayParam.java | 35 ++ .../org/apache/flink/ml/param/StringParam.java | 35 ++ .../java/org/apache/flink/ml/param/WithParams.java | 135 ++++++++ .../java/org/apache/flink/ml/util/ParamUtils.java | 89 +++++ .../org/apache/flink/ml/util/ReadWriteUtils.java | 279 +++++++++++++++ .../apache/flink/ml/api/core/ExampleStages.java | 244 ++++++++++++++ .../org/apache/flink/ml/api/core/PipelineTest.java | 202 +++++------ .../org/apache/flink/ml/api/core/StageTest.java | 375 +++++++++++++++++++++ pom.xml | 2 - 25 files changed, 1863 insertions(+), 141 deletions(-) diff --git a/flink-ml-api/pom.xml b/flink-ml-api/pom.xml index 81fdcc7..ddfc659 100644 --- a/flink-ml-api/pom.xml +++ b/flink-ml-api/pom.xml @@ -38,6 +38,21 @@ under the License. 
<version>${flink.version}</version> <scope>provided</scope> </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-table-planner_${scala.binary.version}</artifactId> + <version>${flink.version}</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-test-utils_${scala.binary.version}</artifactId> + <version>${flink.version}</version> + <scope>test</scope> + </dependency> + <dependency> <groupId>org.apache.flink</groupId> <artifactId>flink-shaded-jackson</artifactId> diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java index a5fed01..f1e5d0c 100644 --- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Pipeline.java @@ -20,13 +20,17 @@ package org.apache.flink.ml.api.core; import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; import org.apache.flink.table.api.Table; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; /** * A Pipeline acts as an Estimator. 
It consists of an ordered list of stages, each of which could be @@ -36,10 +40,11 @@ import java.util.List; public final class Pipeline implements Estimator<Pipeline, PipelineModel> { private static final long serialVersionUID = 6384850154817512318L; private final List<Stage<?>> stages; - private final Params params = new Params(); + private final Map<Param<?>, Object> paramMap = new HashMap<>(); public Pipeline(List<Stage<?>> stages) { this.stages = stages; + ParamUtils.initializeMapWithDefaultValues(paramMap, this); } /** @@ -97,17 +102,17 @@ public final class Pipeline implements Estimator<Pipeline, PipelineModel> { } @Override - public void save(String path) throws IOException { - throw new UnsupportedOperationException(); + public Map<Param<?>, Object> getParamMap() { + return paramMap; } - public static Pipeline load(String path) throws IOException { - throw new UnsupportedOperationException(); + @Override + public void save(String path) throws IOException { + ReadWriteUtils.savePipeline(this, stages, path); } - @Override - public Params getParams() { - return params; + public static Pipeline load(String path) throws IOException { + return new Pipeline(ReadWriteUtils.loadPipeline(path, Pipeline.class.getName())); } /** diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java index 704fa8e..45bb757 100644 --- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineModel.java @@ -20,12 +20,16 @@ package org.apache.flink.ml.api.core; import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; import org.apache.flink.table.api.Table; import 
java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; /** * A PipelineModel acts as a Model. It consists of an ordered list of stages, each of which could be @@ -35,10 +39,11 @@ import java.util.List; public final class PipelineModel implements Model<PipelineModel> { private static final long serialVersionUID = 6184950154217411318L; private final List<Stage<?>> stages; - private final Params params = new Params(); + private final Map<Param<?>, Object> paramMap = new HashMap<>(); public PipelineModel(List<Stage<?>> stages) { this.stages = stages; + ParamUtils.initializeMapWithDefaultValues(paramMap, this); } /** @@ -58,17 +63,17 @@ public final class PipelineModel implements Model<PipelineModel> { } @Override - public void save(String path) throws IOException { - throw new UnsupportedOperationException(); + public Map<Param<?>, Object> getParamMap() { + return paramMap; } - public static PipelineModel load(String path) throws IOException { - throw new UnsupportedOperationException(); + @Override + public void save(String path) throws IOException { + ReadWriteUtils.savePipeline(this, stages, path); } - @Override - public Params getParams() { - return params; + public static PipelineModel load(String path) throws IOException { + return new PipelineModel(ReadWriteUtils.loadPipeline(path, PipelineModel.class.getName())); } /** diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java index 551c5e5..168599b 100644 --- a/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/Stage.java @@ -19,7 +19,7 @@ package org.apache.flink.ml.api.core; import org.apache.flink.annotation.PublicEvolving; -import org.apache.flink.ml.api.misc.param.WithParams; +import org.apache.flink.ml.param.WithParams; import java.io.IOException; import 
java.io.Serializable; diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/BooleanParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/BooleanParam.java new file mode 100644 index 0000000..dd96ebe --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/BooleanParam.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.param; + +/** Class for the boolean parameter. 
*/ +public class BooleanParam extends Param<Boolean> { + + public BooleanParam( + String name, + String description, + Boolean defaultValue, + ParamValidator<Boolean> validator) { + super(name, Boolean.class, description, defaultValue, validator); + } + + public BooleanParam(String name, String description, Boolean defaultValue) { + this(name, description, defaultValue, ParamValidators.alwaysTrue()); + } +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleArrayParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleArrayParam.java new file mode 100644 index 0000000..b86dd00 --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleArrayParam.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.param; + +/** Class for the double array parameter. 
*/ +public class DoubleArrayParam extends Param<Double[]> { + + public DoubleArrayParam( + String name, + String description, + Double[] defaultValue, + ParamValidator<Double[]> validator) { + super(name, Double[].class, description, defaultValue, validator); + } + + public DoubleArrayParam(String name, String description, Double[] defaultValue) { + this(name, description, defaultValue, ParamValidators.alwaysTrue()); + } +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleParam.java new file mode 100644 index 0000000..f6d4911 --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/DoubleParam.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.param; + +/** Class for the double parameter. 
*/ +public class DoubleParam extends Param<Double> { + + public DoubleParam( + String name, + String description, + Double defaultValue, + ParamValidator<Double> validator) { + super(name, Double.class, description, defaultValue, validator); + } + + public DoubleParam(String name, String description, Double defaultValue) { + this(name, description, defaultValue, ParamValidators.alwaysTrue()); + } +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatArrayParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatArrayParam.java new file mode 100644 index 0000000..4224557 --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatArrayParam.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.param; + +/** Class for the float array parameter. 
*/ +public class FloatArrayParam extends Param<Float[]> { + + public FloatArrayParam( + String name, + String description, + Float[] defaultValue, + ParamValidator<Float[]> validator) { + super(name, Float[].class, description, defaultValue, validator); + } + + public FloatArrayParam(String name, String description, Float[] defaultValue) { + this(name, description, defaultValue, ParamValidators.alwaysTrue()); + } +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatParam.java new file mode 100644 index 0000000..0de890c --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/FloatParam.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.param; + +/** Class for the float parameter. 
*/ +public class FloatParam extends Param<Float> { + + public FloatParam( + String name, String description, Float defaultValue, ParamValidator<Float> validator) { + super(name, Float.class, description, defaultValue, validator); + } + + public FloatParam(String name, String description, Float defaultValue) { + this(name, description, defaultValue, ParamValidators.alwaysTrue()); + } +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntArrayParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntArrayParam.java new file mode 100644 index 0000000..4f7c630 --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntArrayParam.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.param; + +/** Class for the integer array parameter. 
*/ +public class IntArrayParam extends Param<Integer[]> { + + public IntArrayParam( + String name, + String description, + Integer[] defaultValue, + ParamValidator<Integer[]> validator) { + super(name, Integer[].class, description, defaultValue, validator); + } + + public IntArrayParam(String name, String description, Integer[] defaultValue) { + this(name, description, defaultValue, ParamValidators.alwaysTrue()); + } +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntParam.java new file mode 100644 index 0000000..4178e22 --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/IntParam.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.param; + +/** Class for the integer parameter. 
*/ +public class IntParam extends Param<Integer> { + + public IntParam( + String name, + String description, + Integer defaultValue, + ParamValidator<Integer> validator) { + super(name, Integer.class, description, defaultValue, validator); + } + + public IntParam(String name, String description, Integer defaultValue) { + this(name, description, defaultValue, ParamValidators.alwaysTrue()); + } +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongArrayParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongArrayParam.java new file mode 100644 index 0000000..5e4fc47 --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongArrayParam.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.param; + +/** Class for the long array parameter. 
*/ +public class LongArrayParam extends Param<Long[]> { + + public LongArrayParam( + String name, + String description, + Long[] defaultValue, + ParamValidator<Long[]> validator) { + super(name, Long[].class, description, defaultValue, validator); + } + + public LongArrayParam(String name, String description, Long[] defaultValue) { + this(name, description, defaultValue, ParamValidators.alwaysTrue()); + } +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongParam.java new file mode 100644 index 0000000..3fd7dd8 --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/LongParam.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.param; + +/** Class for the long parameter. 
*/ +public class LongParam extends Param<Long> { + + public LongParam( + String name, String description, Long defaultValue, ParamValidator<Long> validator) { + super(name, Long.class, description, defaultValue, validator); + } + + public LongParam(String name, String description, Long defaultValue) { + this(name, description, defaultValue, ParamValidators.alwaysTrue()); + } +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java new file mode 100644 index 0000000..b7a1aef --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.param; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.ml.util.ReadWriteUtils; + +import java.io.IOException; +import java.io.Serializable; + +/** + * Definition of a parameter, including name, class, description, default value and the validator. + * + * @param <T> The class type of the parameter value. 
+ */ +@PublicEvolving +public class Param<T> implements Serializable { + private static final long serialVersionUID = 4396556083935765299L; + + public final String name; + public final Class<T> clazz; + public final String description; + public final T defaultValue; + public final ParamValidator<T> validator; + + public Param( + String name, + Class<T> clazz, + String description, + T defaultValue, + ParamValidator<T> validator) { + this.name = name; + this.clazz = clazz; + this.description = description; + this.defaultValue = defaultValue; + this.validator = validator; + + if (defaultValue != null && !validator.validate(defaultValue)) { + throw new IllegalArgumentException( + "Parameter " + name + " is given an invalid value " + defaultValue); + } + } + + /** + * Encodes the given object into a json-formatted string. + * + * @param value An object of class type T. + * @return A json-formatted string. + */ + public String jsonEncode(T value) throws IOException { + return ReadWriteUtils.OBJECT_MAPPER.writeValueAsString(value); + } + + /** + * Decodes the given string into an object of class type T. + * + * @param json A json-formatted string. + * @return An object of class type T. 
+ */ + @SuppressWarnings("unchecked") + public T jsonDecode(String json) throws IOException { + return ReadWriteUtils.OBJECT_MAPPER.readValue(json, clazz); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof Param)) { + return false; + } + return ((Param<?>) obj).name.equals(name); + } + + @Override + public int hashCode() { + return name.hashCode(); + } + + @Override + public String toString() { + return name; + } +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidator.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidator.java new file mode 100644 index 0000000..afdcd9a --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidator.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.param; + +import org.apache.flink.annotation.PublicEvolving; + +import java.io.Serializable; + +/** + * An interface to validate that a parameter value is valid. + * + * @param <T> The class type of the parameter value. + */ +@PublicEvolving +public interface ParamValidator<T> extends Serializable { + + /** + * Validate whether the parameter value is valid. 
+ * + * @param value The parameter value. + * @return A boolean indicating whether the parameter value is valid. + */ + boolean validate(T value); +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidators.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidators.java new file mode 100644 index 0000000..925ccb2 --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/ParamValidators.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.param; + +import org.apache.commons.lang3.ArrayUtils; + +/** Factory methods for common validation functions on numerical values. */ +public class ParamValidators { + + // Always return true. + public static <T> ParamValidator<T> alwaysTrue() { + return (value) -> true; + } + + // Check if the parameter value is greater than lowerBound. + public static <T> ParamValidator<T> gt(double lowerBound) { + return (value) -> value != null && ((Number) value).doubleValue() > lowerBound; + } + + // Check if the parameter value is greater than or equal to lowerBound. 
+ public static <T> ParamValidator<T> gtEq(double lowerBound) { + return (value) -> value != null && ((Number) value).doubleValue() >= lowerBound; + } + + // Check if the parameter value is less than upperBound. + public static <T> ParamValidator<T> lt(double upperBound) { + return (value) -> value != null && ((Number) value).doubleValue() < upperBound; + } + + // Check if the parameter value is less than or equal to upperBound. + public static <T> ParamValidator<T> ltEq(double upperBound) { + return (value) -> value != null && ((Number) value).doubleValue() <= upperBound; + } + + /** + * Check if the parameter value is in the range from lowerBound to upperBound. + * + * @param lowerInclusive if true, range includes value = lowerBound + * @param upperInclusive if true, range includes value = upperBound + */ + public static <T> ParamValidator<T> inRange( + double lowerBound, double upperBound, boolean lowerInclusive, boolean upperInclusive) { + return new ParamValidator<T>() { + @Override + public boolean validate(T obj) { + if (obj == null) { + return false; + } + double value = ((Number) obj).doubleValue(); + return (value >= lowerBound) + && (value <= upperBound) + && (lowerInclusive || value != lowerBound) + && (upperInclusive || value != upperBound); + } + }; + } + + // Check if the parameter value is in the range [lowerBound, upperBound]. + public static <T> ParamValidator<T> inRange(double lowerBound, double upperBound) { + return inRange(lowerBound, upperBound, true, true); + } + + // Check if the parameter value is in the array of allowed values. + public static <T> ParamValidator<T> inArray(T... allowed) { + return new ParamValidator<T>() { + @Override + public boolean validate(T value) { + return value != null && ArrayUtils.contains(allowed, value); + } + }; + } + + // Check if the parameter value is not null. 
+ public static <T> ParamValidator<T> notNull() { + return new ParamValidator<T>() { + @Override + public boolean validate(T value) { + return value != null; + } + }; + } +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringArrayParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringArrayParam.java new file mode 100644 index 0000000..5062463 --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringArrayParam.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.param; + +/** Class for the string array parameter. 
*/ +public class StringArrayParam extends Param<String[]> { + + public StringArrayParam( + String name, + String description, + String[] defaultValue, + ParamValidator<String[]> validator) { + super(name, String[].class, description, defaultValue, validator); + } + + public StringArrayParam(String name, String description, String[] defaultValue) { + this(name, description, defaultValue, ParamValidators.alwaysTrue()); + } +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringParam.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringParam.java new file mode 100644 index 0000000..1736354 --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/StringParam.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.param; + +/** Class for the string parameter. 
*/ +public class StringParam extends Param<String> { + + public StringParam( + String name, + String description, + String defaultValue, + ParamValidator<String> validator) { + super(name, String.class, description, defaultValue, validator); + } + + public StringParam(String name, String description, String defaultValue) { + this(name, description, defaultValue, ParamValidators.alwaysTrue()); + } +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/param/WithParams.java b/flink-ml-api/src/main/java/org/apache/flink/ml/param/WithParams.java new file mode 100644 index 0000000..f631c8e --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/param/WithParams.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.param; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.ml.util.ParamUtils; + +import java.util.Map; +import java.util.Optional; + +/** + * Interface for classes that take parameters. It provides APIs to set and get parameters. + * + * @param <T> The class type of WithParams implementation itself. + */ +@PublicEvolving +public interface WithParams<T> { + + /** + * Gets the parameter by its name. + * + * @param name The parameter name. + * @param <V> The class type of the parameter value. + * @return The parameter, or null if no parameter with the given name is found. + */ + default <V> Param<V> getParam(String name) { + Optional<Param<?>> result = + getParamMap().keySet().stream().filter(param -> param.name.equals(name)).findAny(); + return (Param<V>) result.orElse(null); + } + + /** + * Sets the value of the parameter. + * + * @param param The parameter. + * @param value The parameter value. + * @return The WithParams instance itself.
+ */ + @SuppressWarnings("unchecked") + default <V> T set(Param<V> param, V value) { + if (value != null && !param.clazz.isAssignableFrom(value.getClass())) { + throw new ClassCastException( + "Parameter " + + param.name + + " is given a value with incompatible class " + + value.getClass().getName()); + } + + if (!param.validator.validate(value)) { + if (value == null) { + throw new IllegalArgumentException( + "Parameter " + param.name + "'s value should not be null"); + } else { + throw new IllegalArgumentException( + "Parameter " + + param.name + + " is given an invalid value " + + value.toString()); + } + } + getParamMap().put(param, value); + return (T) this; + } + + /** + * Gets the value of the parameter. + * + * @param param The parameter. + * @param <V> The class type of the parameter value. + * @return The parameter value. + */ + @SuppressWarnings("unchecked") + default <V> V get(Param<V> param) { + Map<Param<?>, Object> paramMap = getParamMap(); + V value = (V) paramMap.get(param); + + if (value == null && !param.validator.validate(value)) { + throw new IllegalArgumentException( + "Parameter " + param.name + "'s value should not be null"); + } + + return value; + } + + /** + * Returns a map which should contain value for every parameter that meets one of the following + * conditions. + * + * <p>1) set(...) has been called to set value for this parameter. + * + * <p>2) The parameter is a public final field of this WithParams instance. This includes fields + * inherited from its interfaces and super-classes. + * + * <p>The subclass which implements this interface could meet this requirement by returning a + * member field of the given map type, after having initialized this member field using the + * {@link ParamUtils#initializeMapWithDefaultValues(Map, WithParams)} method. + * + * @return A map which maps parameter definition to parameter value. 
+ */ + Map<Param<?>, Object> getParamMap(); +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/util/ParamUtils.java b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ParamUtils.java new file mode 100644 index 0000000..cdbe63d --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ParamUtils.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.util; + +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.WithParams; + +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** Utility methods for collecting Param fields and initializing parameter maps. */ +public class ParamUtils { + /** + * Updates the paramMap with default values of all public final Param-typed fields of the given + * instance. A parameter's value will not be updated if this parameter is already found in the + * map. + * + * <p>Note: This method should be called after all public final Param-typed fields of the given + * instance have been defined. A good choice is to call this method in the constructor of the + * given instance.
+ */ + public static void initializeMapWithDefaultValues( + Map<Param<?>, Object> paramMap, WithParams<?> instance) { + List<Param<?>> defaultParams = getPublicFinalParamFields(instance); + for (Param<?> param : defaultParams) { + if (!paramMap.containsKey(param)) { + paramMap.put(param, param.defaultValue); + } + } + } + + /** + * Finds all public final fields of the Param class type of the given object, including those + * fields inherited from its interfaces and super-classes, and returns those Param instances as + * a list. + * + * @param object the object whose public final Param-typed fields will be returned. + * @return a list of Param instances. + */ + public static List<Param<?>> getPublicFinalParamFields(Object object) { + return getPublicFinalParamFields(object, object.getClass()); + } + + // A helper method that finds all public final fields of the Param class type of the given + // object and returns those Param instances as a list. The clazz specifies the object class. + private static List<Param<?>> getPublicFinalParamFields(Object object, Class<?> clazz) { + List<Param<?>> result = new ArrayList<>(); + for (Field field : clazz.getDeclaredFields()) { + field.setAccessible(true); + if (Param.class.isAssignableFrom(field.getType()) + && Modifier.isPublic(field.getModifiers()) + && Modifier.isFinal(field.getModifiers())) { + try { + result.add((Param<?>) field.get(object)); + } catch (IllegalAccessException e) { + throw new RuntimeException( + "Failed to extract param from field " + field.getName(), e); + } + } + } + + if (clazz.getSuperclass() != null) { + result.addAll(getPublicFinalParamFields(object, clazz.getSuperclass())); + } + for (Class<?> cls : clazz.getInterfaces()) { + result.addAll(getPublicFinalParamFields(object, cls)); + } + return result; + } +} diff --git a/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java new file mode 100644 index 
0000000..283c1e5 --- /dev/null +++ b/flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.util; + +import org.apache.flink.ml.api.core.Stage; +import org.apache.flink.ml.param.Param; +import org.apache.flink.util.InstantiationUtil; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Utility methods for reading and writing stages. */ +public class ReadWriteUtils { + public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + // A helper method that encodes the given parameter value to a json string. We can not + // call param.jsonEncode(value) directly because Param::jsonEncode(...) needs the actual type + // of the value.
+ private static <T> String jsonEncodeHelper(Param<T> param, Object value) throws IOException { + return param.jsonEncode((T) value); + } + + // Converts Map<Param<?>, Object> to Map<String, String> which maps the parameter name to the + // string-encoded parameter value. + private static Map<String, String> jsonEncode(Map<Param<?>, Object> paramMap) + throws IOException { + Map<String, String> result = new HashMap<>(paramMap.size()); + for (Map.Entry<Param<?>, Object> entry : paramMap.entrySet()) { + String json = jsonEncodeHelper(entry.getKey(), entry.getValue()); + result.put(entry.getKey().name, json); + } + return result; + } + + /** + * Saves the metadata of the given stage and the extra metadata to a file named `metadata` under + * the given path. The metadata of a stage includes the stage class name, parameter values etc. + * + * <p>Required: the metadata file under the given path should not exist. + * + * @param stage The stage instance. + * @param path The parent directory to save the stage metadata. + * @param extraMetadata The extra metadata to be saved. + */ + public static void saveMetadata(Stage<?> stage, String path, Map<String, ?> extraMetadata) + throws IOException { + // Creates parent directories if not already created. + new File(path).mkdirs(); + + Map<String, Object> metadata = new HashMap<>(extraMetadata); + metadata.put("className", stage.getClass().getName()); + metadata.put("timestamp", System.currentTimeMillis()); + metadata.put("paramMap", jsonEncode(stage.getParamMap())); + // TODO: add version in the metadata. 
+ String metadataStr = OBJECT_MAPPER.writeValueAsString(metadata); + + File metadataFile = new File(path, "metadata"); + if (!metadataFile.createNewFile()) { + throw new IOException("File " + metadataFile.toString() + " already exists."); + } + try (BufferedWriter writer = new BufferedWriter(new FileWriter(metadataFile))) { + writer.write(metadataStr); + } + } + + /** + * Saves the metadata of the given stage to a file named `metadata` under the given path. The + * metadata of a stage includes the stage class name, parameter values etc. + * + * <p>Required: the metadata file under the given path should not exist. + * + * @param stage The stage instance. + * @param path The parent directory to save the stage metadata. + */ + public static void saveMetadata(Stage<?> stage, String path) throws IOException { + saveMetadata(stage, path, new HashMap<>()); + } + + /** + * Loads the metadata from the metadata file under the given path. + * + * <p>The method throws RuntimeException if the expectedClassName is not empty AND it does not + * match the className of the previously saved stage. + * + * @param path The parent directory of the metadata file to read from. + * @param expectedClassName The expected class name of the stage. + * @return A map from metadata name to metadata value. 
+ */ + public static Map<String, ?> loadMetadata(String path, String expectedClassName) + throws IOException { + Path metadataPath = Paths.get(path, "metadata"); + StringBuilder buffer = new StringBuilder(); + try (BufferedReader br = new BufferedReader(new FileReader(metadataPath.toString()))) { + String line; + while ((line = br.readLine()) != null) { + if (!line.startsWith("#")) { + buffer.append(line); + } + } + } + + @SuppressWarnings("unchecked") + Map<String, ?> result = OBJECT_MAPPER.readValue(buffer.toString(), Map.class); + + String className = (String) result.get("className"); + if (!expectedClassName.isEmpty() && !expectedClassName.equals(className)) { + throw new RuntimeException( + "Class name " + + className + + " does not match the expected class name " + + expectedClassName + + "."); + } + + return result; + } + + // Returns a string with value {parentPath}/stages/{stageIdx}, where the stageIdx is prefixed + // with zero or more `0` to have the same length as numStages. The resulting string can be + // used as the directory to save a stage of the Pipeline or PipelineModel. + private static String getPathForPipelineStage(int stageIdx, int numStages, String parentPath) { + String format = String.format("%%0%dd", String.valueOf(numStages).length()); + String fileName = String.format(format, stageIdx); + return Paths.get(parentPath, "stages", fileName).toString(); + } + + /** + * Saves a Pipeline or PipelineModel with the given list of stages to the given path. + * + * @param pipeline A Pipeline or PipelineModel instance. + * @param stages A list of stages of the given pipeline. + * @param path The parent directory to save the pipeline metadata and its stages. + */ + public static void savePipeline(Stage<?> pipeline, List<Stage<?>> stages, String path) + throws IOException { + // Creates parent directories if not already created. 
+ new File(path).mkdirs(); + + Map<String, Object> extraMetadata = new HashMap<>(); + extraMetadata.put("numStages", stages.size()); + saveMetadata(pipeline, path, extraMetadata); + + int numStages = stages.size(); + for (int i = 0; i < numStages; i++) { + String stagePath = getPathForPipelineStage(i, numStages, path); + stages.get(i).save(stagePath); + } + } + + /** + * Loads the stages of a Pipeline or PipelineModel from the given path. + * + * <p>The method throws RuntimeException if the expectedClassName is not empty AND it does not + * match the className of the previously saved Pipeline or PipelineModel. + * + * @param path The parent directory to load the pipeline metadata and its stages. + * @param expectedClassName The expected class name of the pipeline. + * @return A list of stages. + */ + public static List<Stage<?>> loadPipeline(String path, String expectedClassName) + throws IOException { + Map<String, ?> metadata = loadMetadata(path, expectedClassName); + int numStages = (Integer) metadata.get("numStages"); + List<Stage<?>> stages = new ArrayList<>(numStages); + + for (int i = 0; i < numStages; i++) { + String stagePath = getPathForPipelineStage(i, numStages, path); + stages.add(loadStage(stagePath)); + } + return stages; + } + + // A helper method that sets stage's parameter value. We can not call stage.set(param, value) + // directly because stage::set(...) needs the actual type of the value. + public static <T> void setStageParam(Stage<?> stage, Param<T> param, Object value) { + stage.set(param, (T) value); + } + + /** + * Loads the stage with the saved parameters from the given path. This method reads the metadata + * file under the given path, instantiates the stage using its no-argument constructor, and + * loads the stage with the paramMap from the metadata file. + * + * <p>Note: This method does not attempt to read model data from the given path. Caller needs to + * read model data from the given path if the stage has model data. 
+ * + * <p>Required: the class with type T must have a no-argument constructor. + * + * @param path The parent directory of the stage metadata file. + * @param <T> The class type of the Stage subclass. + * @return An instance of class type T. + */ + @SuppressWarnings("unchecked") + public static <T extends Stage<T>> T loadStageParam(String path) throws IOException { + Map<String, ?> metadata = loadMetadata(path, ""); + String className = (String) metadata.get("className"); + Map<String, String> paramMap = (Map<String, String>) metadata.get("paramMap"); + + try { + Class<T> clazz = (Class<T>) Class.forName(className); + T instance = InstantiationUtil.instantiate(clazz); + + Map<String, Param<?>> nameToParam = new HashMap<>(); + for (Param<?> param : ParamUtils.getPublicFinalParamFields(instance)) { + nameToParam.put(param.name, param); + } + + for (Map.Entry<String, String> entry : paramMap.entrySet()) { + Param<?> param = nameToParam.get(entry.getKey()); + setStageParam(instance, param, param.jsonDecode(entry.getValue())); + } + return instance; + } catch (ClassNotFoundException e) { + throw new RuntimeException("Failed to load stage.", e); + } + } + + /** + * Loads the stage from the given path by invoking the static load() method of the stage. The + * stage class name is read from the metadata file under the given path. The load() method is + * expected to construct the stage instance with the saved parameters, model data and other + * metadata if exists. + * + * <p>Required: the stage class must have a static load() method. + * + * @param path The parent directory of the stage metadata file. + * @return An instance of Stage. 
+ */ + public static Stage<?> loadStage(String path) throws IOException { + Map<String, ?> metadata = loadMetadata(path, ""); + String className = (String) metadata.get("className"); + + try { + Class<?> clazz = Class.forName(className); + Method method = clazz.getMethod("load", String.class); + method.setAccessible(true); + return (Stage<?>) method.invoke(null, path); + } catch (NoSuchMethodException e) { + String methodName = String.format("%s::load(String)", className); + throw new RuntimeException( + "Failed to load stage because the static method " + + methodName + + " is not implemented.", + e); + } catch (ClassNotFoundException | IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException("Failed to load stage.", e); + } + } +} diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java new file mode 100644 index 0000000..2e4b4c2 --- /dev/null +++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/ExampleStages.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.api.core; + +import org.apache.flink.api.common.state.BroadcastState; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + +/** Defines a few Stage subclasses to be used in unit tests. */ +public class ExampleStages { + /** + * A Model subclass that increments every value in the input stream by `delta` and outputs the + * resulting values. 
+ */ + public static class SumModel implements Model<SumModel> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private DataStream<Integer> modelData; + + // This empty constructor is necessary in order for SumModel to be loaded by + // ReadWriteUtils.loadStageParam + public SumModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public Table[] transform(Table... inputs) { + Assert.assertEquals(1, inputs.length); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Integer> input = tEnv.toDataStream(inputs[0], Integer.class); + DataStream<Integer> output = + input.connect(modelData.broadcast()) + .transform( + "ApplyDeltaOperator", + BasicTypeInfo.INT_TYPE_INFO, + new ApplyDeltaOperator()); + + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public void setModelData(Table...
inputs) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + modelData = tEnv.toDataStream(inputs[0], Integer.class); + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + + File dataDir = new File(path, "data"); + if (!dataDir.mkdir()) { + throw new IOException("Directory " + dataDir.toString() + " already exists."); + } + + File dataFile = new File(dataDir, "delta"); + if (!dataFile.createNewFile()) { + throw new IOException("File " + dataFile.toString() + " already exists."); + } + + int delta; + try { + delta = (Integer) IteratorUtils.toList(modelData.executeAndCollect()).get(0); + } catch (Exception e) { + throw new RuntimeException(e); + } + + try (DataOutputStream outputStream = + new DataOutputStream(new FileOutputStream(dataFile))) { + outputStream.writeInt(delta); + } + } + + public static SumModel load(String path) throws IOException { + SumModel model = ReadWriteUtils.loadStageParam(path); + File dataFile = Paths.get(path, "data", "delta").toFile(); + + try (DataInputStream inputStream = new DataInputStream(new FileInputStream(dataFile))) { + StreamExecutionEnvironment env = + StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + Table modelData = tEnv.fromDataStream(env.fromElements(inputStream.readInt())); + model.setModelData(modelData); + return model; + } + } + } + + // Adds delta from the 2nd input to every element in the 1st input and returns the added values. 
+ private static class ApplyDeltaOperator extends AbstractStreamOperator<Integer> + implements TwoInputStreamOperator<Integer, Integer, Integer> { + private ListState<Integer> unProcessedValues; + private BroadcastState<String, Integer> broadcastState = null; + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + unProcessedValues = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<Integer>( + "unProcessedValues", Integer.class)); + broadcastState = + context.getOperatorStateStore() + .getBroadcastState( + new MapStateDescriptor<String, Integer>( + "broadcastState", String.class, Integer.class)); + } + + @Override + public void processElement1(StreamRecord<Integer> record) throws Exception { + if (broadcastState.get("delta") == null) { + unProcessedValues.add(record.getValue()); + } else { + output.collect(new StreamRecord<>(record.getValue() + broadcastState.get("delta"))); + } + } + + @Override + public void processElement2(StreamRecord<Integer> record) throws Exception { + if (broadcastState.get("delta") != null) { + throw new IllegalStateException("Model data should have exactly one value"); + } + broadcastState.put("delta", record.getValue()); + + for (Integer value : unProcessedValues.get()) { + output.collect(new StreamRecord<>(value + record.getValue())); + } + unProcessedValues.clear(); + } + } + + /** + * An Estimator subclass which calculates the sum of input values and instantiates a + * SumModel instance with delta=sum(inputs). + */ + public static class SumEstimator implements Estimator<SumEstimator, SumModel> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public SumEstimator() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public SumModel fit(Table...
inputs) { + Assert.assertEquals(1, inputs.length); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + DataStream<Integer> input = tEnv.toDataStream(inputs[0], Integer.class); + DataStream<Integer> modelData = + input.transform("SumOperator", BasicTypeInfo.INT_TYPE_INFO, new SumOperator()) + .setParallelism(1); + try { + SumModel model = new SumModel(); + model.setModelData(tEnv.fromDataStream(modelData)); + + return model; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static SumEstimator load(String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + } + + private static class SumOperator extends AbstractStreamOperator<Integer> + implements OneInputStreamOperator<Integer, Integer>, BoundedOneInput { + int sum = 0; + + @Override + public void endInput() throws Exception { + output.collect(new StreamRecord<>(sum)); + } + + @Override + public void processElement(StreamRecord<Integer> input) throws Exception { + sum += input.getValue(); + } + } +} diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java index 6d46430..74f9c65 100644 --- a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java +++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/PipelineTest.java @@ -18,141 +18,103 @@ package org.apache.flink.ml.api.core; -import org.apache.flink.ml.api.misc.param.ParamInfo; -import org.apache.flink.ml.api.misc.param.ParamInfoFactory; -import org.apache.flink.ml.api.misc.param.Params; +import org.apache.flink.ml.api.core.ExampleStages.SumEstimator; +import org.apache.flink.ml.api.core.ExampleStages.SumModel; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import 
org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; -import org.junit.Rule; +import org.apache.commons.collections.IteratorUtils; import org.junit.Test; -import org.junit.rules.ExpectedException; -import java.io.IOException; -import java.util.ArrayList; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Comparator; import java.util.List; -/** Tests the behavior of {@link Pipeline}. */ -public class PipelineTest { - @Rule public ExpectedException thrown = ExpectedException.none(); +/** Tests the behavior of Pipeline and PipelineModel. */ +public class PipelineTest extends AbstractTestBase { - @Test - public void testPipelineBehavior() { - List<Stage<?>> stages = new ArrayList<>(); - stages.add(new MockTransformer("a")); - stages.add(new MockEstimator("b")); - stages.add(new MockEstimator("c")); - stages.add(new MockTransformer("d")); - - Pipeline pipeline = new Pipeline(stages); - assert describePipeline(pipeline.getStages()).equals("a_b_c_d"); - - PipelineModel pipelineModel = pipeline.fit(null, null); - assert describePipeline(pipelineModel.getStages()).equals("a_mb_mc_d"); - } - - private static String describePipeline(List<Stage<?>> stages) { - StringBuilder res = new StringBuilder(); - for (Stage<?> s : stages) { - if (res.length() != 0) { - res.append("_"); - } - res.append(((SelfDescribe) s).describe()); - } - return res.toString(); - } - - /** Interface to describe a class with a string, only for pipeline test. */ - private interface SelfDescribe { - ParamInfo<String> DESCRIPTION = - ParamInfoFactory.createParamInfo("description", String.class).build(); - - String describe(); - } - - /** Mock estimator for pipeline test. 
*/ - public static class MockEstimator implements Estimator<MockEstimator, MockModel>, SelfDescribe { - private final Params params = new Params(); + // Executes the given stage and verifies that it produces the expected output. + private static void executeAndCheckOutput( + Stage<?> stage, List<Integer> input, List<Integer> expectedOutput) throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + env.setParallelism(4); - public MockEstimator() {} + Table inputTable = tEnv.fromDataStream(env.fromCollection(input)); - MockEstimator(String description) { - set(DESCRIPTION, description); - } - - @Override - public MockModel fit(Table... inputs) { - return new MockModel("m" + describe()); - } - - @Override - public Params getParams() { - return params; - } + Table outputTable; - @Override - public String describe() { - return get(DESCRIPTION); + if (stage instanceof AlgoOperator) { + outputTable = ((AlgoOperator<?>) stage).transform(inputTable)[0]; + } else { + Estimator<?, ?> estimator = (Estimator<?, ?>) stage; + Model<?> model = estimator.fit(inputTable); + outputTable = model.transform(inputTable)[0]; } - @Override - public void save(String path) throws IOException {} + List<Integer> output = + IteratorUtils.toList( + tEnv.toDataStream(outputTable, Integer.class).executeAndCollect()); + compareResultCollections(expectedOutput, output, Comparator.naturalOrder()); } - /** Mock transformer for pipeline test. */ - public static class MockTransformer implements Transformer<MockTransformer>, SelfDescribe { - private final Params params = new Params(); - - public MockTransformer() {} - - MockTransformer(String description) { - set(DESCRIPTION, description); - } - - @Override - public Table[] transform(Table... 
inputs) { - return inputs; - } - - @Override - public Params getParams() { - return params; - } - - @Override - public String describe() { - return get(DESCRIPTION); - } - - @Override - public void save(String path) throws IOException {} + @Test + public void testPipelineModel() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + // Builds a PipelineModel that increments input value by 60. This PipelineModel consists of + // three stages where each stage increments input value by 10, 20, and 30 respectively. + SumModel modelA = new SumModel(); + modelA.setModelData(tEnv.fromValues(10)); + SumModel modelB = new SumModel(); + modelB.setModelData(tEnv.fromValues(20)); + SumModel modelC = new SumModel(); + modelC.setModelData(tEnv.fromValues(30)); + + List<Stage<?>> stages = Arrays.asList(modelA, modelB, modelC); + Model<?> model = new PipelineModel(stages); + + // Executes the original PipelineModel and verifies that it produces the expected output. + executeAndCheckOutput(model, Arrays.asList(1, 2, 3), Arrays.asList(61, 62, 63)); + + // Saves and loads the PipelineModel. + Path tempDir = Files.createTempDirectory("PipelineTest"); + String path = Paths.get(tempDir.toString(), "testPipelineModelSaveLoad").toString(); + model.save(path); + Model<?> loadedModel = PipelineModel.load(path); + + // Executes the loaded PipelineModel and verifies that it produces the expected output. + executeAndCheckOutput(loadedModel, Arrays.asList(1, 2, 3), Arrays.asList(61, 62, 63)); } - /** Mock model for pipeline test. */ - public static class MockModel implements Model<MockModel>, SelfDescribe { - private final Params params = new Params(); - - public MockModel() {} - - MockModel(String description) { - set(DESCRIPTION, description); - } - - @Override - public Table[] transform(Table... 
inputs) { - return inputs; - } - - @Override - public Params getParams() { - return params; - } - - @Override - public String describe() { - return get(DESCRIPTION); - } - - @Override - public void save(String path) throws IOException {} + @Test + public void testPipeline() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + // Builds a Pipeline that consists of a Model, an Estimator, and a model. + SumModel modelA = new SumModel(); + modelA.setModelData(tEnv.fromValues(10)); + SumModel modelB = new SumModel(); + modelB.setModelData(tEnv.fromValues(30)); + + List<Stage<?>> stages = Arrays.asList(modelA, new SumEstimator(), modelB); + Estimator<?, ?> estimator = new Pipeline(stages); + + // Executes the original Pipeline and verifies that it produces the expected output. + executeAndCheckOutput(estimator, Arrays.asList(1, 2, 3), Arrays.asList(77, 78, 79)); + + // Saves and loads the Pipeline. + Path tempDir = Files.createTempDirectory("PipelineTest"); + String path = Paths.get(tempDir.toString(), "testPipeline").toString(); + estimator.save(path); + Estimator<?, ?> loadedEstimator = Pipeline.load(path); + + // Executes the loaded Pipeline and verifies that it produces the expected output. + executeAndCheckOutput(loadedEstimator, Arrays.asList(1, 2, 3), Arrays.asList(77, 78, 79)); } } diff --git a/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java new file mode 100644 index 0000000..9e03ddb --- /dev/null +++ b/flink-ml-api/src/test/java/org/apache/flink/ml/api/core/StageTest.java @@ -0,0 +1,375 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.api.core; + +import org.apache.flink.ml.param.BooleanParam; +import org.apache.flink.ml.param.DoubleArrayParam; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.FloatArrayParam; +import org.apache.flink.ml.param.FloatParam; +import org.apache.flink.ml.param.IntArrayParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.LongArrayParam; +import org.apache.flink.ml.param.LongParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidator; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringArrayParam; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** Tests the behavior of Stage and WithParams. */ +public class StageTest { + + // A WithParams subclass which has one parameter for each pre-defined parameter type. 
+ private interface MyParams<T> extends WithParams<T> { + Param<Boolean> BOOLEAN_PARAM = new BooleanParam("booleanParam", "Description", false); + + Param<Integer> INT_PARAM = + new IntParam("intParam", "Description", 1, ParamValidators.lt(100)); + + Param<Long> LONG_PARAM = + new LongParam("longParam", "Description", 2L, ParamValidators.lt(100)); + + Param<Float> FLOAT_PARAM = + new FloatParam("floatParam", "Description", 3.0f, ParamValidators.lt(100)); + + Param<Double> DOUBLE_PARAM = + new DoubleParam("doubleParam", "Description", 4.0, ParamValidators.lt(100)); + + Param<String> STRING_PARAM = new StringParam("stringParam", "Description", "5"); + + Param<Integer[]> INT_ARRAY_PARAM = + new IntArrayParam("intArrayParam", "Description", new Integer[] {6, 7}); + + Param<Long[]> LONG_ARRAY_PARAM = + new LongArrayParam( + "longArrayParam", + "Description", + new Long[] {8L, 9L}, + ParamValidators.alwaysTrue()); + + Param<Float[]> FLOAT_ARRAY_PARAM = + new FloatArrayParam("floatArrayParam", "Description", new Float[] {10.0f, 11.0f}); + + Param<Double[]> DOUBLE_ARRAY_PARAM = + new DoubleArrayParam( + "doubleArrayParam", + "Description", + new Double[] {12.0, 13.0}, + ParamValidators.alwaysTrue()); + + Param<String[]> STRING_ARRAY_PARAM = + new StringArrayParam("stringArrayParam", "Description", new String[] {"14", "15"}); + } + + /** + * A Stage subclass which inherits all parameters from MyParams and defines an extra parameter. 
+ */ + public static class MyStage implements Stage<MyStage>, MyParams<MyStage> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public final Param<Integer> extraIntParam = + new IntParam("extraIntParam", "Description", 20, ParamValidators.alwaysTrue()); + + public final Param<Integer> paramWithNullDefault = + new IntParam( + "paramWithNullDefault", + "Must be explicitly set with a non-null value", + null, + ParamValidators.notNull()); + + public MyStage() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static MyStage load(String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + } + + /** A Stage subclass without the static load() method. */ + public static class MyStageWithoutLoad implements Stage<MyStage>, MyParams<MyStage> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public MyStageWithoutLoad() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + } + + // Asserts that m1 and m2 are equivalent. 
+ private static void assertParamMapEquals(Map<Param<?>, Object> m1, Map<Param<?>, Object> m2) { + Assert.assertTrue(m1 != null && m2 != null); + Assert.assertEquals(m1.size(), m2.size()); + + for (Map.Entry<Param<?>, Object> entry : m1.entrySet()) { + Assert.assertTrue(m2.containsKey(entry.getKey())); + Object v1 = entry.getValue(); + Object v2 = m2.get(entry.getKey()); + if (v1 == null || v2 == null) { + Assert.assertTrue(v1 == null && v2 == null); + } else if (v1.getClass().isArray() && v2.getClass().isArray()) { + Assert.assertArrayEquals((Object[]) v1, (Object[]) v2); + } else { + Assert.assertEquals(v1, v2); + } + } + } + + // Saves and loads the given stage. And verifies that the loaded stage has same parameter values + // as the original stage. + private static Stage<?> validateStageSaveLoad( + Stage<?> stage, Map<String, Object> paramOverrides) throws IOException { + for (Map.Entry<String, Object> entry : paramOverrides.entrySet()) { + Param<?> param = stage.getParam(entry.getKey()); + ReadWriteUtils.setStageParam(stage, param, entry.getValue()); + } + + String tempDir = Files.createTempDirectory("").toString(); + String path = Paths.get(tempDir, "test").toString(); + stage.save(path); + try { + stage.save(path); + Assert.fail("Expected IOException"); + } catch (IOException e) { + // This is expected. 
+ } + + Stage<?> loadedStage = ReadWriteUtils.loadStage(path); + for (Map.Entry<String, Object> entry : paramOverrides.entrySet()) { + Param<?> param = loadedStage.getParam(entry.getKey()); + Assert.assertEquals(entry.getValue(), loadedStage.get(param)); + } + assertParamMapEquals(stage.getParamMap(), loadedStage.getParamMap()); + return loadedStage; + } + + @Test + public void testParamSetValueWithName() { + MyStage stage = new MyStage(); + + Param<Integer> paramA = MyParams.INT_PARAM; + stage.set(paramA, 2); + Assert.assertEquals(2, (int) stage.get(paramA)); + + Param<Integer> paramB = stage.getParam("intParam"); + stage.set(paramB, 3); + Assert.assertEquals(3, (int) stage.get(paramB)); + + Param<Integer> paramC = stage.getParam("extraIntParam"); + stage.set(paramC, 50); + Assert.assertEquals(50, (int) stage.get(paramC)); + } + + @Test + public void testParamWithNullDefault() { + MyStage stage = new MyStage(); + try { + stage.get(stage.paramWithNullDefault); + Assert.fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException e) { + Assert.assertTrue(e.getMessage().contains("should not be null")); + } + + stage.set(stage.paramWithNullDefault, 3); + Assert.assertEquals(3, (int) stage.get(stage.paramWithNullDefault)); + } + + private static <T> void assertInvalidValue(Stage<?> stage, Param<T> param, T value) { + try { + stage.set(param, value); + Assert.fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException e) { + Assert.assertTrue(e.getMessage().contains("invalid value")); + } + } + + private static <T> void assertInvalidClass(Stage<?> stage, Param<T> param, Object value) { + try { + stage.set(param, (T) value); + Assert.fail("Expected ClassCastException"); + } catch (ClassCastException e) { + Assert.assertTrue(e.getMessage().contains("incompatible class")); + } + } + + @Test + public void testParamSetInvalidValue() { + MyStage stage = new MyStage(); + assertInvalidValue(stage, MyParams.INT_PARAM, 100); + 
assertInvalidValue(stage, MyParams.LONG_PARAM, 100L); + assertInvalidValue(stage, MyParams.FLOAT_PARAM, 100.0f); + assertInvalidValue(stage, MyParams.DOUBLE_PARAM, 100.0); + assertInvalidClass(stage, MyParams.INT_PARAM, "100"); + assertInvalidClass(stage, MyParams.STRING_PARAM, 100); + + Param<Integer> param = stage.getParam("stringParam"); + assertInvalidClass(stage, param, 50); + } + + @Test + public void testParamSetValidValue() { + MyStage stage = new MyStage(); + + stage.set(MyParams.BOOLEAN_PARAM, true); + Assert.assertEquals(true, stage.get(MyParams.BOOLEAN_PARAM)); + + stage.set(MyParams.INT_PARAM, 50); + Assert.assertEquals(50, (int) stage.get(MyParams.INT_PARAM)); + + stage.set(MyParams.LONG_PARAM, 50L); + Assert.assertEquals(50L, (long) stage.get(MyParams.LONG_PARAM)); + + stage.set(MyParams.FLOAT_PARAM, 50f); + Assert.assertEquals(50f, (float) stage.get(MyParams.FLOAT_PARAM), 0.0001); + + stage.set(MyParams.DOUBLE_PARAM, 50.0); + Assert.assertEquals(50, (double) stage.get(MyParams.DOUBLE_PARAM), 0.0001); + + stage.set(MyParams.STRING_PARAM, "50"); + Assert.assertEquals("50", stage.get(MyParams.STRING_PARAM)); + + stage.set(MyParams.INT_ARRAY_PARAM, new Integer[] {50, 51}); + Assert.assertArrayEquals(new Integer[] {50, 51}, stage.get(MyParams.INT_ARRAY_PARAM)); + + stage.set(MyParams.LONG_ARRAY_PARAM, new Long[] {50L, 51L}); + Assert.assertArrayEquals(new Long[] {50L, 51L}, stage.get(MyParams.LONG_ARRAY_PARAM)); + + stage.set(MyParams.FLOAT_ARRAY_PARAM, new Float[] {50.0f, 51.0f}); + Assert.assertArrayEquals(new Float[] {50.0f, 51.0f}, stage.get(MyParams.FLOAT_ARRAY_PARAM)); + + stage.set(MyParams.DOUBLE_ARRAY_PARAM, new Double[] {50.0, 51.0}); + Assert.assertArrayEquals(new Double[] {50.0, 51.0}, stage.get(MyParams.DOUBLE_ARRAY_PARAM)); + + stage.set(MyParams.STRING_ARRAY_PARAM, new String[] {"50", "51"}); + Assert.assertArrayEquals(new String[] {"50", "51"}, stage.get(MyParams.STRING_ARRAY_PARAM)); + } + + @Test + public void testStageSaveLoad() throws 
IOException { + MyStage stage = new MyStage(); + stage.set(stage.paramWithNullDefault, 1); + Stage<?> loadedStage = validateStageSaveLoad(stage, Collections.emptyMap()); + Assert.assertEquals(1, (int) loadedStage.get(MyParams.INT_PARAM)); + } + + @Test + public void testStageSaveLoadWithParamOverrides() throws IOException { + MyStage stage = new MyStage(); + stage.set(stage.paramWithNullDefault, 1); + Stage<?> loadedStage = + validateStageSaveLoad(stage, Collections.singletonMap("intParam", 10)); + Assert.assertEquals(10, (int) loadedStage.get(MyParams.INT_PARAM)); + } + + @Test + public void testStageLoadWithoutLoadMethod() throws IOException { + MyStageWithoutLoad stage = new MyStageWithoutLoad(); + try { + validateStageSaveLoad(stage, Collections.emptyMap()); + Assert.fail("Expected RuntimeException"); + } catch (RuntimeException e) { + Assert.assertTrue(e.getMessage().contains("not implemented")); + } + } + + @Test + public void testValidators() { + ParamValidator<Integer> gt = ParamValidators.gt(10); + Assert.assertFalse(gt.validate(null)); + Assert.assertFalse(gt.validate(5)); + Assert.assertFalse(gt.validate(10)); + Assert.assertTrue(gt.validate(15)); + + ParamValidator<Integer> gtEq = ParamValidators.gtEq(10); + Assert.assertFalse(gtEq.validate(null)); + Assert.assertFalse(gtEq.validate(5)); + Assert.assertTrue(gtEq.validate(10)); + Assert.assertTrue(gtEq.validate(15)); + + ParamValidator<Integer> lt = ParamValidators.lt(10); + Assert.assertFalse(lt.validate(null)); + Assert.assertTrue(lt.validate(5)); + Assert.assertFalse(lt.validate(10)); + Assert.assertFalse(lt.validate(15)); + + ParamValidator<Integer> ltEq = ParamValidators.ltEq(10); + Assert.assertFalse(ltEq.validate(null)); + Assert.assertTrue(ltEq.validate(5)); + Assert.assertTrue(ltEq.validate(10)); + Assert.assertFalse(ltEq.validate(15)); + + ParamValidator<Integer> inRangeInclusive = ParamValidators.inRange(5, 15); + Assert.assertFalse(inRangeInclusive.validate(null)); + 
Assert.assertFalse(inRangeInclusive.validate(0)); + Assert.assertTrue(inRangeInclusive.validate(5)); + Assert.assertTrue(inRangeInclusive.validate(10)); + Assert.assertTrue(inRangeInclusive.validate(15)); + Assert.assertFalse(inRangeInclusive.validate(20)); + + ParamValidator<Integer> inRangeExclusive = ParamValidators.inRange(5, 15, false, false); + Assert.assertFalse(inRangeExclusive.validate(null)); + Assert.assertFalse(inRangeExclusive.validate(0)); + Assert.assertFalse(inRangeExclusive.validate(5)); + Assert.assertTrue(inRangeExclusive.validate(10)); + Assert.assertFalse(inRangeExclusive.validate(15)); + Assert.assertFalse(inRangeExclusive.validate(20)); + + ParamValidator<Integer> inArray = ParamValidators.inArray(1, 2, 3); + Assert.assertFalse(inArray.validate(null)); + Assert.assertTrue(inArray.validate(1)); + Assert.assertFalse(inArray.validate(0)); + + ParamValidator<Integer> notNull = ParamValidators.notNull(); + Assert.assertTrue(notNull.validate(5)); + Assert.assertFalse(notNull.validate(null)); + } +} diff --git a/pom.xml b/pom.xml index 5eb1805..66530a0 100644 --- a/pom.xml +++ b/pom.xml @@ -53,8 +53,6 @@ under the License. <modules> <module>flink-ml-api</module> - <module>flink-ml-lib</module> - <module>flink-ml-uber</module> <module>flink-ml-iteration</module> <module>flink-ml-tests</module> </modules>
