This is an automated email from the ASF dual-hosted git repository.
shaoxuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 1660c6b [FLINK-12881][ml] Add more functionalities for ML Params and
ParamInfo class
1660c6b is described below
commit 1660c6b777747af789fa5c9bf6a3ff77e868ca90
Author: xuyang1706 <[email protected]>
AuthorDate: Tue Jun 18 16:26:02 2019 +0800
[FLINK-12881][ml] Add more functionalities for ML Params and ParamInfo class
Add more functionality, including support for aliases and the
size/clear/isEmpty/contains/fromJson methods in Params
This closes #8776
---
.../apache/flink/ml/api/core/PipelineStage.java | 12 +-
.../apache/flink/ml/api/misc/param/ParamInfo.java | 18 +-
.../org/apache/flink/ml/api/misc/param/Params.java | 217 ++++++++++++++++-----
.../org/apache/flink/ml/api/misc/ParamsTest.java | 100 +++++++++-
4 files changed, 289 insertions(+), 58 deletions(-)
diff --git
a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java
b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java
index 86bf0d3..cda73db 100644
---
a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java
+++
b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java
@@ -18,14 +18,9 @@
package org.apache.flink.ml.api.core;
-import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.WithParams;
-import org.apache.flink.ml.util.param.ExtractParamInfosUtil;
import java.io.Serializable;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
/**
* Base class for a stage in a pipeline. The interface is only a concept, and
does not have any
@@ -46,11 +41,6 @@ interface PipelineStage<T extends PipelineStage<T>> extends
WithParams<T>, Seria
}
default void loadJson(String json) {
- List<ParamInfo> paramInfos =
ExtractParamInfosUtil.extractParamInfos(this);
- Map<String, Class<?>> classMap = new HashMap<>();
- for (ParamInfo i : paramInfos) {
- classMap.put(i.getName(), i.getValueClass());
- }
- getParams().loadJson(json, classMap);
+ getParams().loadJson(json);
}
}
diff --git
a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamInfo.java
b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamInfo.java
index 994576f..3b01b4e 100644
---
a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamInfo.java
+++
b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamInfo.java
@@ -19,11 +19,26 @@
package org.apache.flink.ml.api.misc.param;
import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.util.Preconditions;
/**
* Definition of a parameter, including name, type, default value, validator
and so on.
*
- * <p>This class is provided to unify the interaction with parameters.
+ * <p>A parameter can either be optional or non-optional.
+ * <ul>
+ * <li>
+ * A non-optional parameter should not have a default value. Instead,
its value must be provided by the users.
+ * </li>
+ * <li>
+ * An optional parameter may or may not have a default value.
+ * </li>
+ * </ul>
+ *
+ * <p>Please see {@link Params#get(ParamInfo)} and {@link
Params#contains(ParamInfo)} for more details about the behavior.
+ *
+ * <p>A parameter may have aliases in addition to the parameter name for
convenience and compatibility purposes. One
+ * should not set values for both parameter name and an alias. One and only
one value should be set either under
+ * the parameter name or one of the alias.
*
* @param <V> the type of the param value
*/
@@ -67,6 +82,7 @@ public class ParamInfo<V> {
* @return the aliases of the parameter
*/
public String[] getAlias() {
+ Preconditions.checkNotNull(alias);
return alias;
}
diff --git
a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/Params.java
b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/Params.java
index 0c1e0d8..49da8ad 100644
---
a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/Params.java
+++
b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/Params.java
@@ -25,7 +25,10 @@ import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMap
import java.io.IOException;
import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
/**
@@ -33,28 +36,93 @@ import java.util.Map;
* parameters.
*/
@PublicEvolving
-public class Params implements Serializable {
- private final Map<String, Object> paramMap = new HashMap<>();
+public class Params implements Serializable, Cloneable {
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * A mapping from param name to its value.
+ *
+ * <p>The value is stored in map using json format.
+ */
+ private final Map<String, String> params;
+
+ private transient ObjectMapper mapper;
+
+ public Params() {
+ this.params = new HashMap<>();
+ }
+
+ /**
+ * Return the number of params.
+ *
+ * @return Return the number of params.
+ */
+ public int size() {
+ return params.size();
+ }
+
+ /**
+ * Removes all of the params.
+ * The params will be empty after this call returns.
+ */
+ public void clear() {
+ params.clear();
+ }
+
+ /**
+ * Returns <tt>true</tt> if this params contains no mappings.
+ *
+ * @return <tt>true</tt> if this map contains no mappings
+ */
+ public boolean isEmpty() {
+ return params.isEmpty();
+ }
/**
* Returns the value of the specific parameter, or default value
defined in the {@code info} if
- * this Params doesn't contain the param.
+ * this Params doesn't have a value set for the parameter. An exception
will be thrown in the
+ * following cases because no value could be found for the specified
parameter.
+ * <ul>
+ * <li>
+ * Non-optional parameter: no value is defined in this params
for a non-optional parameter.
+ * </li>
+ * <li>
+ * Optional parameter: no value is defined in this params and
no default value is defined.
+ * </li>
+ * </ul>
*
* @param info the info of the specific parameter, usually with default
value
* @param <V> the type of the specific parameter
* @return the value of the specific parameter, or default value
defined in the {@code info} if
* this Params doesn't contain the parameter
- * @throws RuntimeException if the Params doesn't contains the specific
parameter, while the
- * param is not optional but has no default
value in the {@code info}
+ * @throws IllegalArgumentException if no value can be found for
specified parameter
*/
- @SuppressWarnings("unchecked")
public <V> V get(ParamInfo<V> info) {
- V value = (V) paramMap.getOrDefault(info.getName(),
info.getDefaultValue());
- if (value == null && !info.isOptional() &&
!info.hasDefaultValue()) {
- throw new RuntimeException(info.getName() +
- " not exist which is not optional and don't
have a default value");
+ String value = null;
+ String usedParamName = null;
+ for (String nameOrAlias : getParamNameAndAlias(info)) {
+ if (params.containsKey(nameOrAlias)) {
+ if (usedParamName != null) {
+ throw new
IllegalArgumentException(String.format("Duplicate parameters of %s and %s",
+ usedParamName, nameOrAlias));
+ }
+ usedParamName = nameOrAlias;
+ value = params.get(nameOrAlias);
+ }
+ }
+
+ if (usedParamName != null) {
+ // The param value was set by the user.
+ return valueFromJson(value, info.getValueClass());
+ } else {
+ // The param value was not set by the user.
+ if (!info.isOptional()) {
+ throw new IllegalArgumentException("Missing
non-optional parameter " + info.getName());
+ } else if (!info.hasDefaultValue()) {
+ throw new IllegalArgumentException("Cannot find
default value for optional parameter " + info.getName());
+ }
+ return info.getDefaultValue();
}
- return value;
}
/**
@@ -69,20 +137,11 @@ public class Params implements Serializable {
* evaluated as illegal by the validator
*/
public <V> Params set(ParamInfo<V> info, V value) {
- if (!info.isOptional() && value == null) {
- throw new RuntimeException(
- "Setting " + info.getName() + " as null while
it's not a optional param");
- }
- if (value == null) {
- remove(info);
- return this;
- }
-
if (info.getValidator() != null &&
!info.getValidator().validate(value)) {
throw new RuntimeException(
"Setting " + info.getName() + " as a invalid
value:" + value);
}
- paramMap.put(info.getName(), value);
+ params.put(info.getName(), valueToJson(value));
return this;
}
@@ -93,18 +152,20 @@ public class Params implements Serializable {
* @param <V> the type of the specific parameter
*/
public <V> void remove(ParamInfo<V> info) {
- paramMap.remove(info.getName());
+ params.remove(info.getName());
+ for (String a : info.getAlias()) {
+ params.remove(a);
+ }
}
/**
- * Creates and returns a deep clone of this Params.
+ * Check whether this params has a value set for the given {@code info}.
*
- * @return a deep clone of this Params
+ * @return <tt>true</tt> if this params has a value set for the
specified {@code info}, false otherwise.
*/
- public Params clone() {
- Params newParams = new Params();
- newParams.paramMap.putAll(this.paramMap);
- return newParams;
+ public <V> boolean contains(ParamInfo<V> info) {
+ return params.containsKey(info.getName()) ||
+
Arrays.stream(info.getAlias()).anyMatch(params::containsKey);
}
/**
@@ -114,13 +175,9 @@ public class Params implements Serializable {
* @return a json containing all parameters in this Params
*/
public String toJson() {
- ObjectMapper mapper = new ObjectMapper();
- Map<String, String> stringMap = new HashMap<>();
+ assertMapperInited();
try {
- for (Map.Entry<String, Object> e : paramMap.entrySet())
{
- stringMap.put(e.getKey(),
mapper.writeValueAsString(e.getValue()));
- }
- return mapper.writeValueAsString(stringMap);
+ return mapper.writeValueAsString(params);
} catch (JsonProcessingException e) {
throw new RuntimeException("Failed to serialize params
to json", e);
}
@@ -128,24 +185,94 @@ public class Params implements Serializable {
/**
* Restores the parameters from the given json. The parameters should
be exactly the same with
- * the one who was serialized to the input json after the restoration.
The class mapping of the
- * parameters in the json is required because it is hard to directly
restore a param of a user
- * defined type. Params will be treated as String if it doesn't exist
in the {@code classMap}.
+ * the one who was serialized to the input json after the restoration.
*
- * @param json the json String to restore from
- * @param classMap the classes of the parameters contained in the json
+ * @param json the json String to restore from
*/
@SuppressWarnings("unchecked")
- public void loadJson(String json, Map<String, Class<?>> classMap) {
- ObjectMapper mapper = new ObjectMapper();
+ public void loadJson(String json) {
+ assertMapperInited();
+ Map<String, String> params;
+ try {
+ params = mapper.readValue(json, Map.class);
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to deserialize
json:" + json, e);
+ }
+ this.params.putAll(params);
+ }
+
+ /**
+ * Factory method for constructing params.
+ *
+ * @param json the json string to load
+ * @return the {@code Params} loaded from the json string.
+ */
+ public static Params fromJson(String json) {
+ Params params = new Params();
+ params.loadJson(json);
+ return params;
+ }
+
+ /**
+ * Merge other params into this.
+ *
+ * @param otherParams other params
+ * @return this
+ */
+ public Params merge(Params otherParams) {
+ if (otherParams != null) {
+ this.params.putAll(otherParams.params);
+ }
+ return this;
+ }
+
+ /**
+ * Creates and returns a deep clone of this Params.
+ *
+ * @return a deep clone of this Params
+ */
+ @Override
+ public Params clone() {
+ Params newParams = new Params();
+ newParams.params.putAll(this.params);
+ return newParams;
+ }
+
+ private void assertMapperInited() {
+ if (mapper == null) {
+ mapper = new ObjectMapper();
+ }
+ }
+
+ private String valueToJson(Object value) {
+ assertMapperInited();
try {
- Map<String, String> m = mapper.readValue(json,
Map.class);
- for (Map.Entry<String, String> e : m.entrySet()) {
- Class<?> valueClass =
classMap.getOrDefault(e.getKey(), String.class);
- paramMap.put(e.getKey(),
mapper.readValue(e.getValue(), valueClass));
+ if (value == null) {
+ return null;
}
+ return mapper.writeValueAsString(value);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException("Failed to serialize to
json:" + value, e);
+ }
+ }
+
+ private <T> T valueFromJson(String json, Class<T> clazz) {
+ assertMapperInited();
+ try {
+ if (json == null) {
+ return null;
+ }
+ return mapper.readValue(json, clazz);
} catch (IOException e) {
throw new RuntimeException("Failed to deserialize
json:" + json, e);
}
}
+
+ private <V> List<String> getParamNameAndAlias(
+ ParamInfo <V> info) {
+ List<String> paramNames = new
ArrayList<>(info.getAlias().length + 1);
+ paramNames.add(info.getName());
+ paramNames.addAll(Arrays.asList(info.getAlias()));
+ return paramNames;
+ }
}
diff --git
a/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/api/misc/ParamsTest.java
b/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/api/misc/ParamsTest.java
index 8bdf95b..7d40847 100644
---
a/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/api/misc/ParamsTest.java
+++
b/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/api/misc/ParamsTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.ml.api.misc.param.ParamInfo;
import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
import org.apache.flink.ml.api.misc.param.Params;
+import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@@ -39,7 +40,11 @@ public class ParamsTest {
ParamInfo<String> optionalWithoutDefault =
ParamInfoFactory.createParamInfo("a",
String.class).build();
- assert params.get(optionalWithoutDefault) == null;
+
+ // It should call params.contain to check when get the param in
this case.
+ thrown.expect(RuntimeException.class);
+ thrown.expectMessage("Cannot find default value for optional
parameter a");
+ params.get(optionalWithoutDefault);
ParamInfo<String> optionalWithDefault =
ParamInfoFactory.createParamInfo("a",
String.class).setHasDefaultValue("def").build();
@@ -69,4 +74,97 @@ public class ParamsTest {
thrown.expectMessage("Setting a as a invalid value:0");
params.set(intParam, 0);
}
+
+ @Test
+ public void getOptionalParam() {
+ ParamInfo <String> key = ParamInfoFactory
+ .createParamInfo("key", String.class)
+ .setHasDefaultValue(null)
+ .setDescription("")
+ .build();
+
+ Params params = new Params();
+ Assert.assertNull(params.get(key));
+
+ String val = "3";
+ params.set(key, val);
+ Assert.assertEquals(params.get(key), val);
+
+ params.set(key, null);
+ Assert.assertNull(params.get(key));
+ }
+
+ @Test
+ public void getOptionalWithoutDefaultParam() {
+ ParamInfo <String> key = ParamInfoFactory
+ .createParamInfo("key", String.class)
+ .setOptional()
+ .setDescription("")
+ .build();
+ Params params = new Params();
+
+ try {
+ String val = params.get(key);
+ Assert.fail("Should throw exception.");
+ } catch (IllegalArgumentException ex) {
+ Assert.assertTrue(ex.getMessage().startsWith("Cannot
find default value for optional parameter"));
+ }
+
+ Assert.assertFalse(params.contains(key));
+
+ String val = "3";
+ params.set(key, val);
+ Assert.assertEquals(params.get(key), val);
+
+ Assert.assertTrue(params.contains(key));
+
+ params.set(key, null);
+ Assert.assertNull(params.get(key));
+ }
+
+ @Test
+ public void getRequiredParam() {
+ ParamInfo <String> labelWithRequired = ParamInfoFactory
+ .createParamInfo("label", String.class)
+ .setDescription("")
+ .setRequired()
+ .build();
+ Params params = new Params();
+ try {
+ params.get(labelWithRequired);
+ Assert.fail("failure");
+ } catch (IllegalArgumentException ex) {
+ Assert.assertTrue(ex.getMessage().startsWith("Missing
non-optional parameter"));
+ }
+
+ params.set(labelWithRequired, null);
+ Assert.assertNull(params.get(labelWithRequired));
+
+ String val = "3";
+ params.set(labelWithRequired, val);
+ Assert.assertEquals(params.get(labelWithRequired), val);
+ }
+
+ @Test
+ public void testGetAliasParam() {
+ ParamInfo <String> predResultColName = ParamInfoFactory
+ .createParamInfo("predResultColName", String.class)
+ .setDescription("Column name of predicted result.")
+ .setRequired()
+ .setAlias(new String[] {"predColName", "outputColName"})
+ .build();
+
+ Params params =
Params.fromJson("{\"predResultColName\":\"\\\"f0\\\"\"}");
+
+ Assert.assertEquals("f0", params.get(predResultColName));
+
+ params =
Params.fromJson("{\"predResultColName\":\"\\\"f0\\\"\",
\"predColName\":\"\\\"f0\\\"\"}");
+
+ try {
+ params.get(predResultColName);
+ Assert.fail("failure");
+ } catch (IllegalArgumentException ex) {
+ Assert.assertTrue(ex.getMessage().startsWith("Duplicate
parameters of predResultColName and predColName"));
+ }
+ }
}