This is an automated email from the ASF dual-hosted git repository.

shaoxuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 1660c6b  [FLINK-12881][ml] Add more functionalities for ML Params and ParamInfo class
1660c6b is described below

commit 1660c6b777747af789fa5c9bf6a3ff77e868ca90
Author: xuyang1706 <[email protected]>
AuthorDate: Tue Jun 18 16:26:02 2019 +0800

    [FLINK-12881][ml] Add more functionalities for ML Params and ParamInfo class
    
    Add more functionality, including support for parameter aliases and the
    size/clear/isEmpty/contains/fromJson methods in Params
    
    This closes #8776
---
 .../apache/flink/ml/api/core/PipelineStage.java    |  12 +-
 .../apache/flink/ml/api/misc/param/ParamInfo.java  |  18 +-
 .../org/apache/flink/ml/api/misc/param/Params.java | 217 ++++++++++++++++-----
 .../org/apache/flink/ml/api/misc/ParamsTest.java   | 100 +++++++++-
 4 files changed, 289 insertions(+), 58 deletions(-)

diff --git 
a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java
 
b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java
index 86bf0d3..cda73db 100644
--- 
a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java
+++ 
b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/core/PipelineStage.java
@@ -18,14 +18,9 @@
 
 package org.apache.flink.ml.api.core;
 
-import org.apache.flink.ml.api.misc.param.ParamInfo;
 import org.apache.flink.ml.api.misc.param.WithParams;
-import org.apache.flink.ml.util.param.ExtractParamInfosUtil;
 
 import java.io.Serializable;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
 
 /**
  * Base class for a stage in a pipeline. The interface is only a concept, and 
does not have any
@@ -46,11 +41,6 @@ interface PipelineStage<T extends PipelineStage<T>> extends 
WithParams<T>, Seria
        }
 
        default void loadJson(String json) {
-               List<ParamInfo> paramInfos = 
ExtractParamInfosUtil.extractParamInfos(this);
-               Map<String, Class<?>> classMap = new HashMap<>();
-               for (ParamInfo i : paramInfos) {
-                       classMap.put(i.getName(), i.getValueClass());
-               }
-               getParams().loadJson(json, classMap);
+               getParams().loadJson(json);
        }
 }
diff --git 
a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamInfo.java
 
b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamInfo.java
index 994576f..3b01b4e 100644
--- 
a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamInfo.java
+++ 
b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/ParamInfo.java
@@ -19,11 +19,26 @@
 package org.apache.flink.ml.api.misc.param;
 
 import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.util.Preconditions;
 
 /**
  * Definition of a parameter, including name, type, default value, validator 
and so on.
  *
- * <p>This class is provided to unify the interaction with parameters.
+ * <p>A parameter can either be optional or non-optional.
+ * <ul>
+ *     <li>
+ *         A non-optional parameter should not have a default value. Instead, 
its value must be provided by the users.
+ *     </li>
+ *     <li>
+ *         An optional parameter may or may not have a default value.
+ *     </li>
+ * </ul>
+ *
+ * <p>Please see {@link Params#get(ParamInfo)} and {@link 
Params#contains(ParamInfo)} for more details about the behavior.
+ *
+ * <p>A parameter may have aliases in addition to the parameter name for 
convenience and compatibility purposes. One
+ * should not set values for both parameter name and an alias. One and only 
one value should be set either under
+ * the parameter name or one of the alias.
  *
  * @param <V> the type of the param value
  */
@@ -67,6 +82,7 @@ public class ParamInfo<V> {
         * @return the aliases of the parameter
         */
        public String[] getAlias() {
+               Preconditions.checkNotNull(alias);
                return alias;
        }
 
diff --git 
a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/Params.java
 
b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/Params.java
index 0c1e0d8..49da8ad 100644
--- 
a/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/Params.java
+++ 
b/flink-ml-parent/flink-ml-api/src/main/java/org/apache/flink/ml/api/misc/param/Params.java
@@ -25,7 +25,10 @@ import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMap
 
 import java.io.IOException;
 import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 /**
@@ -33,28 +36,93 @@ import java.util.Map;
  * parameters.
  */
 @PublicEvolving
-public class Params implements Serializable {
-       private final Map<String, Object> paramMap = new HashMap<>();
+public class Params implements Serializable, Cloneable {
+       private static final long serialVersionUID = 1L;
+
+       /**
+        * A mapping from param name to its value.
+        *
+        * <p>The value is stored in map using json format.
+        */
+       private final Map<String, String> params;
+
+       private transient ObjectMapper mapper;
+
+       public Params() {
+               this.params = new HashMap<>();
+       }
+
+       /**
+        * Return the number of params.
+        *
+        * @return Return the number of params.
+        */
+       public int size() {
+               return params.size();
+       }
+
+       /**
+        * Removes all of the params.
+        * The params will be empty after this call returns.
+        */
+       public void clear() {
+               params.clear();
+       }
+
+       /**
+        * Returns <tt>true</tt> if this params contains no mappings.
+        *
+        * @return <tt>true</tt> if this map contains no mappings
+        */
+       public boolean isEmpty() {
+               return params.isEmpty();
+       }
 
        /**
         * Returns the value of the specific parameter, or default value 
defined in the {@code info} if
-        * this Params doesn't contain the param.
+        * this Params doesn't have a value set for the parameter. An exception 
will be thrown in the
+        * following cases because no value could be found for the specified 
parameter.
+        * <ul>
+        *     <li>
+        *         Non-optional parameter: no value is defined in this params 
for a non-optional parameter.
+        *     </li>
+        *     <li>
+        *         Optional parameter: no value is defined in this params and 
no default value is defined.
+        *     </li>
+        * </ul>
         *
         * @param info the info of the specific parameter, usually with default 
value
         * @param <V>  the type of the specific parameter
         * @return the value of the specific parameter, or default value 
defined in the {@code info} if
         * this Params doesn't contain the parameter
-        * @throws RuntimeException if the Params doesn't contains the specific 
parameter, while the
-        *                          param is not optional but has no default 
value in the {@code info}
+        * @throws IllegalArgumentException if no value can be found for 
specified parameter
         */
-       @SuppressWarnings("unchecked")
        public <V> V get(ParamInfo<V> info) {
-               V value = (V) paramMap.getOrDefault(info.getName(), 
info.getDefaultValue());
-               if (value == null && !info.isOptional() && 
!info.hasDefaultValue()) {
-                       throw new RuntimeException(info.getName() +
-                               " not exist which is not optional and don't 
have a default value");
+               String value = null;
+               String usedParamName = null;
+               for (String nameOrAlias : getParamNameAndAlias(info)) {
+                       if (params.containsKey(nameOrAlias)) {
+                               if (usedParamName != null) {
+                                       throw new 
IllegalArgumentException(String.format("Duplicate parameters of %s and %s",
+                                               usedParamName, nameOrAlias));
+                               }
+                               usedParamName = nameOrAlias;
+                               value = params.get(nameOrAlias);
+                       }
+               }
+
+               if (usedParamName != null) {
+                       // The param value was set by the user.
+                       return valueFromJson(value, info.getValueClass());
+               } else {
+                       // The param value was not set by the user.
+                       if (!info.isOptional()) {
+                               throw new IllegalArgumentException("Missing 
non-optional parameter " + info.getName());
+                       } else if (!info.hasDefaultValue()) {
+                               throw new IllegalArgumentException("Cannot find 
default value for optional parameter " + info.getName());
+                       }
+                       return info.getDefaultValue();
                }
-               return value;
        }
 
        /**
@@ -69,20 +137,11 @@ public class Params implements Serializable {
         *                          evaluated as illegal by the validator
         */
        public <V> Params set(ParamInfo<V> info, V value) {
-               if (!info.isOptional() && value == null) {
-                       throw new RuntimeException(
-                               "Setting " + info.getName() + " as null while 
it's not a optional param");
-               }
-               if (value == null) {
-                       remove(info);
-                       return this;
-               }
-
                if (info.getValidator() != null && 
!info.getValidator().validate(value)) {
                        throw new RuntimeException(
                                "Setting " + info.getName() + " as a invalid 
value:" + value);
                }
-               paramMap.put(info.getName(), value);
+               params.put(info.getName(), valueToJson(value));
                return this;
        }
 
@@ -93,18 +152,20 @@ public class Params implements Serializable {
         * @param <V>  the type of the specific parameter
         */
        public <V> void remove(ParamInfo<V> info) {
-               paramMap.remove(info.getName());
+               params.remove(info.getName());
+               for (String a : info.getAlias()) {
+                       params.remove(a);
+               }
        }
 
        /**
-        * Creates and returns a deep clone of this Params.
+        * Check whether this params has a value set for the given {@code info}.
         *
-        * @return a deep clone of this Params
+        * @return <tt>true</tt> if this params has a value set for the 
specified {@code info}, false otherwise.
         */
-       public Params clone() {
-               Params newParams = new Params();
-               newParams.paramMap.putAll(this.paramMap);
-               return newParams;
+       public <V> boolean contains(ParamInfo<V> info) {
+               return params.containsKey(info.getName()) ||
+                       
Arrays.stream(info.getAlias()).anyMatch(params::containsKey);
        }
 
        /**
@@ -114,13 +175,9 @@ public class Params implements Serializable {
         * @return a json containing all parameters in this Params
         */
        public String toJson() {
-               ObjectMapper mapper = new ObjectMapper();
-               Map<String, String> stringMap = new HashMap<>();
+               assertMapperInited();
                try {
-                       for (Map.Entry<String, Object> e : paramMap.entrySet()) 
{
-                               stringMap.put(e.getKey(), 
mapper.writeValueAsString(e.getValue()));
-                       }
-                       return mapper.writeValueAsString(stringMap);
+                       return mapper.writeValueAsString(params);
                } catch (JsonProcessingException e) {
                        throw new RuntimeException("Failed to serialize params 
to json", e);
                }
@@ -128,24 +185,94 @@ public class Params implements Serializable {
 
        /**
         * Restores the parameters from the given json. The parameters should 
be exactly the same with
-        * the one who was serialized to the input json after the restoration. 
The class mapping of the
-        * parameters in the json is required because it is hard to directly 
restore a param of a user
-        * defined type. Params will be treated as String if it doesn't exist 
in the {@code classMap}.
+        * the one who was serialized to the input json after the restoration.
         *
-        * @param json     the json String to restore from
-        * @param classMap the classes of the parameters contained in the json
+        * @param json the json String to restore from
         */
        @SuppressWarnings("unchecked")
-       public void loadJson(String json, Map<String, Class<?>> classMap) {
-               ObjectMapper mapper = new ObjectMapper();
+       public void loadJson(String json) {
+               assertMapperInited();
+               Map<String, String> params;
+               try {
+                       params = mapper.readValue(json, Map.class);
+               } catch (IOException e) {
+                       throw new RuntimeException("Failed to deserialize 
json:" + json, e);
+               }
+               this.params.putAll(params);
+       }
+
+       /**
+        * Factory method for constructing params.
+        *
+        * @param json the json string to load
+        * @return the {@code Params} loaded from the json string.
+        */
+       public static Params fromJson(String json) {
+               Params params = new Params();
+               params.loadJson(json);
+               return params;
+       }
+
+       /**
+        * Merge other params into this.
+        *
+        * @param otherParams other params
+        * @return this
+        */
+       public Params merge(Params otherParams) {
+               if (otherParams != null) {
+                       this.params.putAll(otherParams.params);
+               }
+               return this;
+       }
+
+       /**
+        * Creates and returns a deep clone of this Params.
+        *
+        * @return a deep clone of this Params
+        */
+       @Override
+       public Params clone() {
+               Params newParams = new Params();
+               newParams.params.putAll(this.params);
+               return newParams;
+       }
+
+       private void assertMapperInited() {
+               if (mapper == null) {
+                       mapper = new ObjectMapper();
+               }
+       }
+
+       private String valueToJson(Object value) {
+               assertMapperInited();
                try {
-                       Map<String, String> m = mapper.readValue(json, 
Map.class);
-                       for (Map.Entry<String, String> e : m.entrySet()) {
-                               Class<?> valueClass = 
classMap.getOrDefault(e.getKey(), String.class);
-                               paramMap.put(e.getKey(), 
mapper.readValue(e.getValue(), valueClass));
+                       if (value == null) {
+                               return null;
                        }
+                       return mapper.writeValueAsString(value);
+               } catch (JsonProcessingException e) {
+                       throw new RuntimeException("Failed to serialize to 
json:" + value, e);
+               }
+       }
+
+       private <T> T valueFromJson(String json, Class<T> clazz) {
+               assertMapperInited();
+               try {
+                       if (json == null) {
+                               return null;
+                       }
+                       return mapper.readValue(json, clazz);
                } catch (IOException e) {
                        throw new RuntimeException("Failed to deserialize 
json:" + json, e);
                }
        }
+
+       private <V> List<String> getParamNameAndAlias(
+               ParamInfo <V> info) {
+               List<String> paramNames = new 
ArrayList<>(info.getAlias().length + 1);
+               paramNames.add(info.getName());
+               paramNames.addAll(Arrays.asList(info.getAlias()));
+               return paramNames;
+       }
 }
diff --git 
a/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/api/misc/ParamsTest.java
 
b/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/api/misc/ParamsTest.java
index 8bdf95b..7d40847 100644
--- 
a/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/api/misc/ParamsTest.java
+++ 
b/flink-ml-parent/flink-ml-api/src/test/java/org/apache/flink/ml/api/misc/ParamsTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.ml.api.misc.param.ParamInfo;
 import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
 import org.apache.flink.ml.api.misc.param.Params;
 
+import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -39,7 +40,11 @@ public class ParamsTest {
 
                ParamInfo<String> optionalWithoutDefault =
                        ParamInfoFactory.createParamInfo("a", 
String.class).build();
-               assert params.get(optionalWithoutDefault) == null;
+
+               // It should call params.contain to check when get the param in 
this case.
+               thrown.expect(RuntimeException.class);
+               thrown.expectMessage("Cannot find default value for optional 
parameter a");
+               params.get(optionalWithoutDefault);
 
                ParamInfo<String> optionalWithDefault =
                        ParamInfoFactory.createParamInfo("a", 
String.class).setHasDefaultValue("def").build();
@@ -69,4 +74,97 @@ public class ParamsTest {
                thrown.expectMessage("Setting a as a invalid value:0");
                params.set(intParam, 0);
        }
+
+       @Test
+       public void getOptionalParam() {
+               ParamInfo <String> key = ParamInfoFactory
+                       .createParamInfo("key", String.class)
+                       .setHasDefaultValue(null)
+                       .setDescription("")
+                       .build();
+
+               Params params = new Params();
+               Assert.assertNull(params.get(key));
+
+               String val = "3";
+               params.set(key, val);
+               Assert.assertEquals(params.get(key), val);
+
+               params.set(key, null);
+               Assert.assertNull(params.get(key));
+       }
+
+       @Test
+       public void getOptionalWithoutDefaultParam() {
+               ParamInfo <String> key = ParamInfoFactory
+                       .createParamInfo("key", String.class)
+                       .setOptional()
+                       .setDescription("")
+                       .build();
+               Params params = new Params();
+
+               try {
+                       String val = params.get(key);
+                       Assert.fail("Should throw exception.");
+               } catch (IllegalArgumentException ex) {
+                       Assert.assertTrue(ex.getMessage().startsWith("Cannot 
find default value for optional parameter"));
+               }
+
+               Assert.assertFalse(params.contains(key));
+
+               String val = "3";
+               params.set(key, val);
+               Assert.assertEquals(params.get(key), val);
+
+               Assert.assertTrue(params.contains(key));
+
+               params.set(key, null);
+               Assert.assertNull(params.get(key));
+       }
+
+       @Test
+       public void getRequiredParam() {
+               ParamInfo <String> labelWithRequired = ParamInfoFactory
+                       .createParamInfo("label", String.class)
+                       .setDescription("")
+                       .setRequired()
+                       .build();
+               Params params = new Params();
+               try {
+                       params.get(labelWithRequired);
+                       Assert.fail("failure");
+               } catch (IllegalArgumentException ex) {
+                       Assert.assertTrue(ex.getMessage().startsWith("Missing 
non-optional parameter"));
+               }
+
+               params.set(labelWithRequired, null);
+               Assert.assertNull(params.get(labelWithRequired));
+
+               String val = "3";
+               params.set(labelWithRequired, val);
+               Assert.assertEquals(params.get(labelWithRequired), val);
+       }
+
+       @Test
+       public void testGetAliasParam() {
+               ParamInfo <String> predResultColName = ParamInfoFactory
+                       .createParamInfo("predResultColName", String.class)
+                       .setDescription("Column name of predicted result.")
+                       .setRequired()
+                       .setAlias(new String[] {"predColName", "outputColName"})
+                       .build();
+
+               Params params = 
Params.fromJson("{\"predResultColName\":\"\\\"f0\\\"\"}");
+
+               Assert.assertEquals("f0", params.get(predResultColName));
+
+               params = 
Params.fromJson("{\"predResultColName\":\"\\\"f0\\\"\", 
\"predColName\":\"\\\"f0\\\"\"}");
+
+               try {
+                       params.get(predResultColName);
+                       Assert.fail("failure");
+               } catch (IllegalArgumentException ex) {
+                       Assert.assertTrue(ex.getMessage().startsWith("Duplicate 
parameters of predResultColName and predColName"));
+               }
+       }
 }

Reply via email to