This is an automated email from the ASF dual-hosted git repository. myui pushed a commit to branch HIVEMALL-253-2 in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
commit 6d708abbb46bc740b52bab09ba4eda943dadaf85
Author: Makoto Yui <[email protected]>
AuthorDate: Mon Jun 10 15:29:26 2019 +0900

    merged from https://github.com/Solodye/incubator-hivemall.git master ignoring pom.xml updates
---
 .../java/hivemall/tools/map/MapRouletteUDF.java    | 192 +++++++++++++++++++++
 .../hivemall/tools/map/MapRouletteUDFTest.java     | 148 ++++++++++++++++
 docs/gitbook/misc/generic_funcs.md                 |  38 +++-
 resources/ddl/define-all.hive                      |   3 +
 4 files changed, 380 insertions(+), 1 deletion(-)

diff --git a/core/src/main/java/hivemall/tools/map/MapRouletteUDF.java b/core/src/main/java/hivemall/tools/map/MapRouletteUDF.java
new file mode 100644
index 0000000..e69dd53
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/map/MapRouletteUDF.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.map;
+
+import hivemall.utils.hadoop.HiveUtils;
+import org.apache.hadoop.hive.ql.exec.*;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import java.util.*;
+import static hivemall.HivemallConstants.*;
+
+/**
+ * {@code map_roulette(map)} performs a roulette-wheel selection over the entries of a map: every
+ * key is returned with a probability proportional to its weight (the map value).
+ *
+ * <ul>
+ * <li>Weights must be non-negative; a negative weight raises a {@link UDFArgumentException}.</li>
+ * <li>A {@code null} weight is replaced by the sum of the non-null weights divided by the total
+ * number of entries (null entries included in the divisor).</li>
+ * <li>{@code null} is returned for a null map, an empty map, and a map in which no entry ends up
+ * with a positive weight (e.g. all weights are zero or all are {@code null}).</li>
+ * </ul>
+ *
+ * @author Wang, Yizheng
+ */
+@Description(name = "map_roulette", value = "_FUNC_(Map<K, number> map)"
+        + " - Returns the key K which determine to its weight , the bigger weight is ,the more probability K will return. "
+        + "Number is a probability value or a positive weight")
+@UDFType(deterministic = false, stateful = false) // it is false because it return value base on probability
+public class MapRouletteUDF extends GenericUDF {
+
+    private transient MapObjectInspector mapOI;
+    private transient PrimitiveObjectInspector valueOI;
+
+    /**
+     * Validates that the single argument is a map whose values are numbers (or strings that are
+     * parsed as numbers at evaluation time).
+     *
+     * @param arguments inspectors of the call arguments; exactly one map inspector is expected
+     * @return the inspector of the map keys, because the function returns one of the keys
+     * @throws UDFArgumentException if the argument count or the argument type is not supported
+     */
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
+        if (arguments.length != 1) {
+            throw new UDFArgumentLengthException(
+                "Expected one arguments for map_roulette: " + arguments.length);
+        }
+        if (arguments[0].getCategory() != ObjectInspector.Category.MAP) {
+            throw new UDFArgumentTypeException(0,
+                "Only map type arguments are accepted for the key but " + arguments[0].getTypeName()
+                        + " was passed as parameter 1.");
+        }
+        mapOI = HiveUtils.asMapOI(arguments[0]);
+
+        // Reject list/map/struct values up front instead of failing with a ClassCastException.
+        ObjectInspector weightOI = mapOI.getMapValueObjectInspector();
+        if (weightOI.getCategory() != ObjectInspector.Category.PRIMITIVE) {
+            throw new UDFArgumentTypeException(0,
+                "Expected a number but get: " + weightOI.getTypeName());
+        }
+        valueOI = (PrimitiveObjectInspector) weightOI;
+
+        // Dispatch on the primitive category rather than on the type name so that parameterized
+        // types such as decimal(10,2) are recognized as numbers too.
+        switch (valueOI.getPrimitiveCategory()) {
+            case BYTE:
+            case SHORT:
+            case INT:
+            case LONG:
+            case FLOAT:
+            case DOUBLE:
+            case DECIMAL:
+            case STRING:
+                // Pass an empty map or a map full of {null, null} will get string type
+                // An number in string format like "3.5" also support
+                break;
+            default:
+                throw new UDFArgumentException(
+                    "Expected a number but get: " + valueOI.getTypeName());
+        }
+        return mapOI.getMapKeyObjectInspector();
+    }
+
+    /**
+     * Picks one key of the given map at random, proportionally to the weights.
+     *
+     * @param arguments the single map argument
+     * @return the selected key, or {@code null} if the map is null, empty or carries no positive
+     *         weight
+     * @throws HiveException if a weight is negative or cannot be converted to a double
+     */
+    @Override
+    public Object evaluate(DeferredObject[] arguments) throws HiveException {
+        Map<Object, Double> weights = readWeights(arguments[0]);
+        if (weights == null || weights.isEmpty()) {
+            return null;
+        }
+        return drawKey(weights);
+    }
+
+    /**
+     * Converts the raw Hive map into a key-to-weight map without {@code null} weights.
+     *
+     * @param argument the deferred map argument
+     * @return the weights keyed by the original map keys, or {@code null} if the map is null or
+     *         empty; a {@code null} weight is replaced by the sum of the non-null weights divided
+     *         by the number of entries
+     * @throws HiveException if a weight is negative, or if the conversion of a non-numeric value
+     *         (e.g. the string "Zhang") to a double fails
+     */
+    private Map<Object, Double> readWeights(DeferredObject argument) throws HiveException {
+        Map<?, ?> raw = mapOI.getMap(argument.get());
+        if (raw == null || raw.isEmpty()) {
+            return null;
+        }
+
+        Map<Object, Double> weights = new HashMap<>();
+        double sum = 0.d;
+        for (Map.Entry<?, ?> entry : raw.entrySet()) {
+            Double weight = null;
+            if (entry.getValue() != null) {
+                final double v = PrimitiveObjectInspectorUtils.convertPrimitiveToDouble(
+                    entry.getValue(), valueOI);
+                if (v < 0) {
+                    throw new UDFArgumentException(entry.getValue() + " < 0");
+                }
+                sum += v;
+                weight = Double.valueOf(v);
+            }
+            weights.put(entry.getKey(), weight);
+        }
+
+        // Missing weights get the "average": the divisor deliberately counts every entry,
+        // including the ones whose weight is missing.
+        final Double average = Double.valueOf(sum / raw.size());
+        for (Map.Entry<Object, Double> entry : weights.entrySet()) {
+            if (entry.getValue() == null) {
+                entry.setValue(average);
+            }
+        }
+        return weights;
+    }
+
+    /**
+     * Draws one key with a probability proportional to its weight.
+     *
+     * @param weights keys with their non-null, non-negative weights
+     * @return the drawn key, or {@code null} if the weights do not sum up to a positive finite
+     *         number
+     */
+    private static Object drawKey(Map<Object, Double> weights) {
+        double total = 0.d;
+        for (Double w : weights.values()) {
+            total += w.doubleValue();
+        }
+        // A zero, NaN or infinite total cannot define a probability distribution.
+        if (!(total > 0.d) || Double.isInfinite(total)) {
+            return null;
+        }
+
+        // Walk the cumulative weights until the cursor, drawn from [0, total), is passed.
+        final double cursor = Math.random() * total;
+        double cumulative = 0.d;
+        Object lastCandidate = null;
+        for (Map.Entry<Object, Double> entry : weights.entrySet()) {
+            final double w = entry.getValue().doubleValue();
+            if (!(w > 0.d)) {
+                continue; // a zero weight is never selected
+            }
+            cumulative += w;
+            if (cursor < cumulative) {
+                return entry.getKey();
+            }
+            lastCandidate = entry.getKey();
+        }
+        // Only reachable when floating-point rounding leaves the cumulative sum slightly below
+        // the total; the last positive-weight key owns that sliver of the wheel.
+        return lastCandidate;
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        return "map_roulette(" + Arrays.toString(children) + ")";
+    }
+
+}
diff --git a/core/src/test/java/hivemall/tools/map/MapRouletteUDFTest.java b/core/src/test/java/hivemall/tools/map/MapRouletteUDFTest.java
new file mode 100644
index 0000000..a7497d8
--- /dev/null
+++ b/core/src/test/java/hivemall/tools/map/MapRouletteUDFTest.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
+ * See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.map;
+
+import hivemall.TestUtils;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * Unit test for {@link hivemall.tools.map.MapRouletteUDF}
+ *
+ * @author Wang, Yizheng
+ */
+public class MapRouletteUDFTest {
+
+    /**
+     * Tom, Jerry, Amy, Wong and Zhao join a roulette. Zhao has the highest weight (0.5) and Jerry
+     * the second highest (0.2). Tom's weight is missing, so the UDF replaces it with the sum of
+     * the known weights divided by the number of entries, (0.2 + 0.1 + 0.1 + 0.5) / 5 = 0.18,
+     * which is still below Jerry's weight. After 1000000 draws Zhao must have won most often and
+     * Jerry second most often.
+     *
+     * @throws HiveException if initializing or evaluating the UDF fails
+     */
+    @Test
+    public void testRoulette() throws HiveException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        udf.initialize(new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)});
+
+        Map<Object, Integer> wins = new HashMap<>();
+        wins.put("Tom", 0);
+        wins.put("Jerry", 0);
+        wins.put("Amy", 0);
+        wins.put("Wong", 0);
+        wins.put("Zhao", 0);
+
+        // The input is the same for every draw, so it is built only once.
+        Map<Object, Double> m = new HashMap<>();
+        m.put("Tom", null);
+        m.put("Jerry", 0.2);
+        m.put("Amy", 0.1);
+        m.put("Wong", 0.1);
+        m.put("Zhao", 0.5);
+        GenericUDF.DeferredObject[] arguments =
+                new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+
+        for (int draw = 0; draw < 1000000; draw++) {
+            Object key = udf.evaluate(arguments);
+            wins.put(key, wins.get(key) + 1);
+        }
+
+        // Sort ascending by number of wins: the winner is last, the runner-up second to last.
+        List<Map.Entry<Object, Integer>> ranking = new ArrayList<>(wins.entrySet());
+        Collections.sort(ranking, new KvComparator());
+        Object winner = ranking.get(ranking.size() - 1).getKey();
+        Assert.assertEquals("Zhao", winner.toString());
+        Object runnerUp = ranking.get(ranking.size() - 2).getKey();
+        Assert.assertEquals("Jerry", runnerUp.toString());
+    }
+
+    /** Orders map entries by ascending value (number of wins). */
+    private static class KvComparator implements Comparator<Map.Entry<Object, Integer>> {
+
+        @Override
+        public int compare(Map.Entry<Object, Integer> o1, Map.Entry<Object, Integer> o2) {
+            return o1.getValue().compareTo(o2.getValue());
+        }
+    }
+
+    /** Runs the UDF through the shared serialization check, with one missing weight in the map. */
+    @Test
+    public void testSerialization() throws HiveException, IOException {
+        Map<Object, Double> m = new HashMap<>();
+        m.put("Tom", 0.1);
+        m.put("Jerry", 0.2);
+        m.put("Amy", 0.1);
+        m.put("Wong", 0.1);
+        m.put("Zhao", null);
+
+        TestUtils.testGenericUDFSerialization(MapRouletteUDF.class,
+            new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)},
+            new Object[] {m});
+        byte[] serialized = TestUtils.serializeObjectByKryo(new MapRouletteUDFTest());
+        TestUtils.deserializeObjectByKryo(serialized, MapRouletteUDFTest.class);
+    }
+
+    /** An empty map and a map holding only a null key with a null weight both yield null. */
+    @Test
+    public void testEmptyMapAndAllNullMap() throws HiveException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        Map<Object, Double> m = new HashMap<>();
+        udf.initialize(new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)});
+        GenericUDF.DeferredObject[] arguments =
+                new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+        Assert.assertNull(udf.evaluate(arguments));
+        m.put(null, null);
+        arguments = new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+        Assert.assertNull(udf.evaluate(arguments));
+    }
+
+    /** A single entry with a positive weight is always selected. */
+    @Test
+    public void testOnlyOne() throws HiveException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        Map<Object, Double> m = new HashMap<>();
+        udf.initialize(new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)});
+        m.put("One", 324.6);
+        GenericUDF.DeferredObject[] arguments =
+                new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+        Assert.assertEquals("One", udf.evaluate(arguments));
+    }
+
+    /** A weight given as a numeric string such as "0.7" is accepted. */
+    @Test
+    public void testString() throws HiveException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        Map<Object, String> m = new HashMap<>();
+        udf.initialize(new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector)});
+        m.put("One", "0.7");
+        GenericUDF.DeferredObject[] arguments =
+                new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+        Assert.assertEquals("One", udf.evaluate(arguments));
+    }
+}
diff --git a/docs/gitbook/misc/generic_funcs.md b/docs/gitbook/misc/generic_funcs.md
index 4f53f4d..328969b 100644
--- a/docs/gitbook/misc/generic_funcs.md
+++ b/docs/gitbook/misc/generic_funcs.md
@@ -539,7 +539,43 @@ This page describes a list of useful Hivemall generic functions. See also a [lis
            to_ordered_map(key, value, -100) -- {3:"banana",4:"candy",10:"apple"} (tail-100)
     from t
     ```
-
+
+- `map_roulette(Map<key, number> map)` - Returns the `key` which determine to its `number` weight, the bigger weight is ,the more probability K will return.`Number` is a probability value or a positive weight
+
+    We can use `map_roulette()` on a `Map<key, number>` that was secured from data.
+    ```sql
+    select map_roulette(to_map(a, b)) -- 25% Tom, 21% Zhang, 54% Wang
+    from(
+        select 'Wang' as a, 54 as b
+        union
+        select 'Zhang' as a, 21 as b
+        union
+        select 'Tom' as a, 25 as b
+    )tmp;
+    ```
+    We can pass an `empty map` or a map full of `null` value. Then we will get `null`.
+    ```sql
+    select map_roulette(map(null, null, null, null)); -- NULL
+    select map_roulette(map()); -- NULL
+    ```
+    An occasional `null` weight will be treated as average weight.
+    ```sql
+    select map_roulette(map(1, 0.5, 'Wang', null)); -- 50% Wang, 50% 1
+    select map_roulette(map(1, 0.5, 'Wang', null, 'Zhang', null)); -- 1/3 Wang, 1/3 1, 1/3 Zhang
+    ```
+    All the weight is zero will return `null`.
+    ```sql
+    select map_roulette(map(1, 0)); -- NULL
+    select map_roulette(map(1, 0, '5', 0)); -- NULL
+    ```
+    This udf isn't support non-numeric weight or negative weight.
+    ```sql
+    select map_roulette(map('Wong', 'A string', 'Zhao', 2));
+    -- Failed with exception java.io.IOException:org.apache.hadoop.hive.ql.metadata.HiveException: Error evaluating map_roulette([map('Wong':'A string','Zhao':2)])
+    select map_roulette(map('Wong', 3, 'Zhao', -2));
+    -- Failed with exception java.io.IOException:org.apache.hadoop.hive.ql.exec.UDFArgumentException: -2 < 0
+    ```
+
 # MapReduce
 
 - `distcache_gets(filepath, key, default_value [, parseKey])` - Returns map<key_type, value_type>|value_type
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index e6f7c0b..4faaeed 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -507,6 +507,9 @@ create temporary function map_get as 'hivemall.tools.map.MapGetUDF';
 drop temporary function if exists map_key_values;
 create temporary function map_key_values as 'hivemall.tools.map.MapKeyValuesUDF';
 
+drop temporary function if exists map_roulette;
+create temporary function map_roulette as 'hivemall.tools.map.MapRouletteUDF';
+
 ---------------------
 -- list functions --
 ---------------------
