http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/json/ToJsonUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/json/ToJsonUDF.java b/core/src/main/java/hivemall/tools/json/ToJsonUDF.java index 70c62b9..95d1af9 100644 --- a/core/src/main/java/hivemall/tools/json/ToJsonUDF.java +++ b/core/src/main/java/hivemall/tools/json/ToJsonUDF.java @@ -37,8 +37,82 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.io.Text; +// @formatter:off @Description(name = "to_json", - value = "_FUNC_(ANY object [, const array<string>|const string columnNames]) - Returns Json string") + value = "_FUNC_(ANY object [, const array<string>|const string columnNames]) - Returns Json string", + extended = "SELECT \n" + + " NAMED_STRUCT(\"Name\", \"John\", \"age\", 31),\n" + + " to_json(\n" + + " NAMED_STRUCT(\"Name\", \"John\", \"age\", 31)\n" + + " ),\n" + + " to_json(\n" + + " NAMED_STRUCT(\"Name\", \"John\", \"age\", 31),\n" + + " array('Name', 'age')\n" + + " ),\n" + + " to_json(\n" + + " NAMED_STRUCT(\"Name\", \"John\", \"age\", 31),\n" + + " array('name', 'age')\n" + + " ),\n" + + " to_json(\n" + + " NAMED_STRUCT(\"Name\", \"John\", \"age\", 31),\n" + + " array('age')\n" + + " ),\n" + + " to_json(\n" + + " NAMED_STRUCT(\"Name\", \"John\", \"age\", 31),\n" + + " array()\n" + + " ),\n" + + " to_json(\n" + + " null,\n" + + " array()\n" + + " ),\n" + + " to_json(\n" + + " struct(\"123\", \"456\", 789, array(314,007)),\n" + + " array('ti','si','i','bi')\n" + + " ),\n" + + " to_json(\n" + + " struct(\"123\", \"456\", 789, array(314,007)),\n" + + " 'ti,si,i,bi'\n" + + " ),\n" + + " to_json(\n" + + " struct(\"123\", \"456\", 789, array(314,007))\n" + + " ),\n" + + " to_json(\n" + + " NAMED_STRUCT(\"country\", \"japan\", 
\"city\", \"tokyo\")\n" + + " ),\n" + + " to_json(\n" + + " NAMED_STRUCT(\"country\", \"japan\", \"city\", \"tokyo\"), \n" + + " array('city')\n" + + " ),\n" + + " to_json(\n" + + " ARRAY(\n" + + " NAMED_STRUCT(\"country\", \"japan\", \"city\", \"tokyo\"), \n" + + " NAMED_STRUCT(\"country\", \"japan\", \"city\", \"osaka\")\n" + + " )\n" + + " ),\n" + + " to_json(\n" + + " ARRAY(\n" + + " NAMED_STRUCT(\"country\", \"japan\", \"city\", \"tokyo\"), \n" + + " NAMED_STRUCT(\"country\", \"japan\", \"city\", \"osaka\")\n" + + " ),\n" + + " array('city')\n" + + " );\n" + + "```\n\n" + + "```\n" + + " {\"name\":\"John\",\"age\":31}\n" + + " {\"name\":\"John\",\"age\":31}\n" + + " {\"Name\":\"John\",\"age\":31}\n" + + " {\"name\":\"John\",\"age\":31}\n" + + " {\"age\":31}\n" + + " {}\n" + + " NULL\n" + + " {\"ti\":\"123\",\"si\":\"456\",\"i\":789,\"bi\":[314,7]}\n" + + " {\"ti\":\"123\",\"si\":\"456\",\"i\":789,\"bi\":[314,7]}\n" + + " {\"col1\":\"123\",\"col2\":\"456\",\"col3\":789,\"col4\":[314,7]}\n" + + " {\"country\":\"japan\",\"city\":\"tokyo\"}\n" + + " {\"city\":\"tokyo\"}\n" + + " [{\"country\":\"japan\",\"city\":\"tokyo\"},{\"country\":\"japan\",\"city\":\"osaka\"}]\n" + + " [{\"country\":\"japan\",\"city\":\"tokyo\"},{\"country\":\"japan\",\"city\":\"osaka\"}]") +// @formatter:on @UDFType(deterministic = true, stateful = false) public final class ToJsonUDF extends GenericUDF {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java b/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java index 5ef6ddb..83adb0f 100644 --- a/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java +++ b/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java @@ -66,15 +66,16 @@ import org.apache.hadoop.io.IntWritable; /** * Return list of values sorted by value itself or specific key. */ +//@formatter:off @Description(name = "to_ordered_list", value = "_FUNC_(PRIMITIVE value [, PRIMITIVE key, const string options])" + " - Return list of values sorted by value itself or specific key", - extended = "with t as (\n" + " select 5 as key, 'apple' as value\n" + " union all\n" - + " select 3 as key, 'banana' as value\n" + " union all\n" - + " select 4 as key, 'candy' as value\n" + " union all\n" - + " select 2 as key, 'donut' as value\n" + " union all\n" - + " select 3 as key, 'egg' as value\n" + ")\n" - + "select -- expected output\n" + extended = "WITH t as (\n" + " SELECT 5 as key, 'apple' as value\n" + " UNION ALL\n" + + " SELECT 3 as key, 'banana' as value\n" + " UNION ALL\n" + + " SELECT 4 as key, 'candy' as value\n" + " UNION ALL\n" + + " SELECT 2 as key, 'donut' as value\n" + " UNION ALL\n" + + " SELECT 3 as key, 'egg' as value\n" + ")\n" + + "SELECT -- expected output\n" + " to_ordered_list(value, key, '-reverse'), -- [apple, candy, (banana, egg | egg, banana), donut] (reverse order)\n" + " to_ordered_list(value, key, '-k 2'), -- [apple, candy] (top-k)\n" + " to_ordered_list(value, key, '-k 100'), -- [apple, candy, (banana, egg | egg, banana), dunut]\n" @@ -86,7 +87,8 @@ import org.apache.hadoop.io.IntWritable; + " to_ordered_list(value, '-k 2'), -- [egg, donut] (alphabetically)\n" + " to_ordered_list(key, '-k -2 
-reverse'), -- [5, 4] (top-2 keys)\n" + " to_ordered_list(key) -- [2, 3, 3, 4, 5] (natural ordered keys)\n" - + "from\n" + " t") + + "FROM\n" + " t") +//@formatter:on public final class UDAFToOrderedList extends AbstractGenericUDAFResolver { @Override http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/map/MapExcludeKeysUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/map/MapExcludeKeysUDF.java b/core/src/main/java/hivemall/tools/map/MapExcludeKeysUDF.java new file mode 100644 index 0000000..6a46210 --- /dev/null +++ b/core/src/main/java/hivemall/tools/map/MapExcludeKeysUDF.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.tools.map; + +import hivemall.utils.hadoop.HiveUtils; + +import java.util.List; +import java.util.Map; + +import org.apache.commons.lang.StringUtils; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; + +@Description(name = "map_exclude_keys", + value = "_FUNC_(Map<K,V> map, array<K> filteringKeys)" + + " - Returns the filtered entries of a map not having specified keys", + extended = "SELECT map_exclude_keys(map(1,'one',2,'two',3,'three'),array(2,3));\n" + + "{1:\"one\"}") +@UDFType(deterministic = true, stateful = false) +public final class MapExcludeKeysUDF extends GenericUDF { + + private MapObjectInspector mapOI; + private ListObjectInspector listOI; + + @Override + public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length != 2) { + throw new UDFArgumentLengthException( + "Expected two arguments for map_filter_keys: " + argOIs.length); + } + + this.mapOI = HiveUtils.asMapOI(argOIs[0]); + this.listOI = HiveUtils.asListOI(argOIs[1]); + + ObjectInspector mapKeyOI = mapOI.getMapKeyObjectInspector(); + ObjectInspector filterKeyOI = listOI.getListElementObjectInspector(); + + if (!ObjectInspectorUtils.compareTypes(mapKeyOI, filterKeyOI)) { + throw new UDFArgumentException("Element types does not match: mapKey " + + 
mapKeyOI.getTypeName() + ", filterKey" + filterKeyOI.getTypeName()); + } + + return ObjectInspectorUtils.getStandardObjectInspector(mapOI, + ObjectInspectorCopyOption.WRITABLE); + } + + @Override + public Map<?, ?> evaluate(DeferredObject[] arguments) throws HiveException { + Object arg0 = arguments[0].get(); + if (arg0 == null) { + return null; + } + final Map<?, ?> map = (Map<?, ?>) ObjectInspectorUtils.copyToStandardObject(arg0, mapOI, + ObjectInspectorCopyOption.WRITABLE); + + Object arg1 = arguments[1].get(); + if (arg1 == null) { + return map; + } + + final List<?> filterKeys = (List<?>) ObjectInspectorUtils.copyToStandardObject(arg1, listOI, + ObjectInspectorCopyOption.WRITABLE); + for (Object k : filterKeys) { + map.remove(k); + } + + return map; + } + + @Override + public String getDisplayString(String[] children) { + return "map_exclude_keys(" + StringUtils.join(children, ',') + ")"; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/map/MapIncludeKeysUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/map/MapIncludeKeysUDF.java b/core/src/main/java/hivemall/tools/map/MapIncludeKeysUDF.java new file mode 100644 index 0000000..902569f --- /dev/null +++ b/core/src/main/java/hivemall/tools/map/MapIncludeKeysUDF.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.tools.map; + +import hivemall.utils.hadoop.HiveUtils; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.commons.lang.StringUtils; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; + +@Description(name = "map_include_keys", + value = "_FUNC_(Map<K,V> map, array<K> filteringKeys)" + + " - Returns the filtered entries of a map having specified keys", + extended = "SELECT map_include_keys(map(1,'one',2,'two',3,'three'),array(2,3));\n" + + "{2:\"two\",3:\"three\"}") +@UDFType(deterministic = true, stateful = false) +public final class MapIncludeKeysUDF extends GenericUDF { + + private MapObjectInspector mapOI; + private ListObjectInspector listOI; + + @Override + public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length != 2) { + throw new 
UDFArgumentLengthException( + "Expected two arguments for map_filter_keys: " + argOIs.length); + } + + this.mapOI = HiveUtils.asMapOI(argOIs[0]); + this.listOI = HiveUtils.asListOI(argOIs[1]); + + ObjectInspector mapKeyOI = mapOI.getMapKeyObjectInspector(); + ObjectInspector filterKeyOI = listOI.getListElementObjectInspector(); + if (!ObjectInspectorUtils.compareTypes(mapKeyOI, filterKeyOI)) { + throw new UDFArgumentException("Element types does not match: mapKey " + + mapKeyOI.getTypeName() + ", filterKey" + filterKeyOI.getTypeName()); + } + + return ObjectInspectorUtils.getStandardObjectInspector(mapOI, + ObjectInspectorCopyOption.WRITABLE); + } + + @Override + public Map<?, ?> evaluate(DeferredObject[] arguments) throws HiveException { + Object arg0 = arguments[0].get(); + if (arg0 == null) { + return null; + } + final Map<?, ?> map = (Map<?, ?>) ObjectInspectorUtils.copyToStandardObject(arg0, mapOI, + ObjectInspectorCopyOption.WRITABLE); + + Object arg1 = arguments[1].get(); + if (arg1 == null) { + return null; + } + final List<?> filterKeys = (List<?>) ObjectInspectorUtils.copyToStandardObject(arg1, listOI, + ObjectInspectorCopyOption.WRITABLE); + + final Map<Object, Object> result = new HashMap<>(); + for (Object k : filterKeys) { + Object v = map.get(k); + if (v != null) { + result.put(k, v); + } + } + return result; + } + + @Override + public String getDisplayString(String[] children) { + return "map_include_keys(" + StringUtils.join(children, ',') + ")"; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/map/MapIndexUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/map/MapIndexUDF.java b/core/src/main/java/hivemall/tools/map/MapIndexUDF.java new file mode 100644 index 0000000..73ffa36 --- /dev/null +++ b/core/src/main/java/hivemall/tools/map/MapIndexUDF.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.tools.map; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; + +//@formatter:off +@Description(name = "map_index", + value = "_FUNC_(a, n) - Returns the n-th element of the given array", + extended = "WITH tmp as (\n" + + " SELECT \"one\" as key\n" + + " UNION ALL\n" + + " SELECT \"two\" as key\n" + + ")\n" + + "SELECT map_index(map(\"one\",1,\"two\",2),key)\n" + + "FROM tmp;\n\n" + + "1\n" + + "2") +//@formatter:on +@UDFType(deterministic 
= true, stateful = false) +public final class MapIndexUDF extends GenericUDF { + + private transient MapObjectInspector mapOI; + private transient Converter converter; + + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + if (arguments.length != 2) { + throw new UDFArgumentLengthException("The function INDEX accepts exactly 2 arguments."); + } + + if (arguments[0] instanceof MapObjectInspector) { + this.mapOI = (MapObjectInspector) arguments[0]; + } else { + throw new UDFArgumentTypeException(0, "\"map\" is expected at function INDEX, but \"" + + arguments[0].getTypeName() + "\" is found"); + } + + // index has to be a primitive + if (!(arguments[1] instanceof PrimitiveObjectInspector)) { + throw new UDFArgumentTypeException(1, + "Primitive Type is expected but " + arguments[1].getTypeName() + "\" is found"); + } + + PrimitiveObjectInspector inputOI = (PrimitiveObjectInspector) arguments[1]; + ObjectInspector indexOI = + ObjectInspectorConverters.getConvertedOI(inputOI, mapOI.getMapKeyObjectInspector()); + this.converter = ObjectInspectorConverters.getConverter(inputOI, indexOI); + + return mapOI.getMapValueObjectInspector(); + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + assert (arguments.length == 2); + Object index = arguments[1].get(); + + Object indexObject = converter.convert(index); + if (indexObject == null) { + return null; + } + + Object arg0 = arguments[0].get(); + if (arg0 == null) { + return null; + } + + return mapOI.getMapValueElement(arg0, indexObject); + } + + @Override + public String getDisplayString(String[] children) { + assert (children.length == 2); + return children[0] + "[" + children[1] + "]"; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/map/MapKeyValuesUDF.java ---------------------------------------------------------------------- diff --git 
a/core/src/main/java/hivemall/tools/map/MapKeyValuesUDF.java b/core/src/main/java/hivemall/tools/map/MapKeyValuesUDF.java new file mode 100644 index 0000000..3992f9e --- /dev/null +++ b/core/src/main/java/hivemall/tools/map/MapKeyValuesUDF.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.tools.map; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import javax.annotation.Nullable; + +import org.apache.commons.lang.StringUtils; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; + +@Description(name = "map_key_values", + value = "_FUNC_(map) - " + "Returns a array of key-value pairs.", + extended = "SELECT map_key_values(map(\"one\",1,\"two\",2));\n\n" + + "[{\"key\":\"one\",\"value\":1},{\"key\":\"two\",\"value\":2}]") +@UDFType(deterministic = true, stateful = false) +public final class MapKeyValuesUDF extends GenericUDF { + + private final ArrayList<Object[]> retArray = new ArrayList<Object[]>(); + + private MapObjectInspector mapOI; + + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + if (arguments.length != 1) { + throw new UDFArgumentLengthException( + "The function MAP_KEYS only accepts one argument."); + } else if (!(arguments[0] instanceof MapObjectInspector)) { + throw new UDFArgumentTypeException(0, + "\"" + Category.MAP.toString().toLowerCase() + + "\" is expected at function MAP_KEYS, " + "but \"" + + arguments[0].getTypeName() + "\" is found"); + } + + this.mapOI = (MapObjectInspector) arguments[0]; + + List<String> structFieldNames = new ArrayList<String>(); + 
List<ObjectInspector> structFieldObjectInspectors = new ArrayList<ObjectInspector>(); + structFieldNames.add("key"); + structFieldObjectInspectors.add(mapOI.getMapKeyObjectInspector()); + structFieldNames.add("value"); + structFieldObjectInspectors.add(mapOI.getMapValueObjectInspector()); + + return ObjectInspectorFactory.getStandardListObjectInspector( + ObjectInspectorFactory.getStandardStructObjectInspector(structFieldNames, + structFieldObjectInspectors)); + } + + @Override + @Nullable + public List<Object[]> evaluate(DeferredObject[] arguments) throws HiveException { + Object mapObj = arguments[0].get(); + if (mapObj == null) { + return null; + } + retArray.clear(); + final Map<?, ?> map = mapOI.getMap(mapObj); + for (Map.Entry<?, ?> e : map.entrySet()) { + retArray.add(new Object[] {e.getKey(), e.getValue()}); + } + return retArray; + } + + @Override + public String getDisplayString(String[] children) { + return "map_key_values(" + StringUtils.join(children, ',') + ')'; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/map/MergeMapsUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/map/MergeMapsUDAF.java b/core/src/main/java/hivemall/tools/map/MergeMapsUDAF.java new file mode 100644 index 0000000..e4a2516 --- /dev/null +++ b/core/src/main/java/hivemall/tools/map/MergeMapsUDAF.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.tools.map; + +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Preconditions; + +import java.util.HashMap; +import java.util.Map; + +import javax.annotation.Nonnull; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; + +//@formatter:off +@Description(name = "merge_maps", + value = "_FUNC_(x) - Returns a map which contains the union of an aggregation of maps." 
+ + " Note that an existing value of a key can be replaced with the other duplicate key entry.", + extended = "SELECT \n" + + " merge_maps(m) \n" + + "FROM (\n" + + " SELECT map('A',10,'B',20,'C',30) \n" + + " UNION ALL \n" + + " SELECT map('A',10,'B',20,'C',30)\n" + + ") t") +//@formatter:on +public final class MergeMapsUDAF extends AbstractGenericUDAFResolver { + + @Override + public MergeMapsEvaluator getEvaluator(TypeInfo[] types) throws SemanticException { + if (types.length != 1) { + throw new UDFArgumentTypeException(types.length - 1, + "One argument is expected but got " + types.length); + } + TypeInfo paramType = types[0]; + if (paramType.getCategory() != Category.MAP) { + throw new UDFArgumentTypeException(0, "Only maps supported for now "); + } + return new MergeMapsEvaluator(); + } + + public static final class MergeMapsEvaluator extends GenericUDAFEvaluator { + + private transient MapObjectInspector inputMapOI, mergeMapOI; + private transient ObjectInspector inputKeyOI, inputValOI; + + @AggregationType(estimable = false) + static final class MapAggBuffer extends AbstractAggregationBuffer { + @Nonnull + final Map<Object, Object> collectMap = new HashMap<Object, Object>(); + } + + public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException { + Preconditions.checkArgument(parameters.length == 1); + super.init(mode, parameters); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + this.inputMapOI = HiveUtils.asMapOI(parameters[0]); + this.inputKeyOI = inputMapOI.getMapKeyObjectInspector(); + this.inputValOI = inputMapOI.getMapValueObjectInspector(); + } else {// from partial aggregation + this.mergeMapOI = HiveUtils.asMapOI(parameters[0]); + this.inputKeyOI = mergeMapOI.getMapKeyObjectInspector(); + this.inputValOI = mergeMapOI.getMapValueObjectInspector(); + } + + return ObjectInspectorFactory.getStandardMapObjectInspector( + 
ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI), + ObjectInspectorUtils.getStandardObjectInspector(inputValOI)); + } + + @Override + public MapAggBuffer getNewAggregationBuffer() throws HiveException { + MapAggBuffer buff = new MapAggBuffer(); + reset(buff); + return buff; + } + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer buff) + throws HiveException { + MapAggBuffer aggrBuf = (MapAggBuffer) buff; + aggrBuf.collectMap.clear(); + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + Object[] parameters) throws HiveException { + Preconditions.checkArgument(parameters.length == 1); + + Object param0 = parameters[0]; + if (param0 == null) { + return; + } + + Map<?, ?> m = inputMapOI.getMap(param0); + MapAggBuffer myagg = (MapAggBuffer) agg; + putIntoSet(m, myagg.collectMap, inputMapOI); + } + + @Override + public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) + throws HiveException { + if (partial == null) { + return; + } + + MapAggBuffer myagg = (MapAggBuffer) agg; + Map<?, ?> m = mergeMapOI.getMap(partial); + putIntoSet(m, myagg.collectMap, mergeMapOI); + } + + private static void putIntoSet(@Nonnull final Map<?, ?> m, + @Nonnull final Map<Object, Object> dst, @Nonnull final MapObjectInspector mapOI) { + final ObjectInspector keyOI = mapOI.getMapKeyObjectInspector(); + final ObjectInspector valueOI = mapOI.getMapValueObjectInspector(); + + for (Map.Entry<?, ?> e : m.entrySet()) { + Object k = e.getKey(); + Object v = e.getValue(); + Object keyCopy = ObjectInspectorUtils.copyToStandardObject(k, keyOI); + Object valCopy = ObjectInspectorUtils.copyToStandardObject(v, valueOI); + dst.put(keyCopy, valCopy); + } + } + + @Override + @Nonnull + public Map<Object, Object> terminatePartial( + @SuppressWarnings("deprecation") AggregationBuffer agg) throws HiveException { + MapAggBuffer myagg = (MapAggBuffer) agg; + return myagg.collectMap; + } + + 
@Override + public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + MapAggBuffer myagg = (MapAggBuffer) agg; + return myagg.collectMap; + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/mapred/RowNumberUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/mapred/RowNumberUDF.java b/core/src/main/java/hivemall/tools/mapred/RowNumberUDF.java index 3c92e30..95c97dc 100644 --- a/core/src/main/java/hivemall/tools/mapred/RowNumberUDF.java +++ b/core/src/main/java/hivemall/tools/mapred/RowNumberUDF.java @@ -28,8 +28,8 @@ import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; import org.apache.hadoop.io.LongWritable; -@Description(name = "rownum", value = "_FUNC_() - Returns a generated row number in long", - extended = "returns sprintf(`%d%04d`,sequence,taskId) as long") +@Description(name = "rownum", value = "_FUNC_() - Returns a generated row number `sprintf(`%d%04d`,sequence,taskId)` in long", + extended = "SELECT rownum() as rownum, xxx from ...") @UDFType(deterministic = false, stateful = true) public final class RowNumberUDF extends UDF { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/sanity/AssertUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/sanity/AssertUDF.java b/core/src/main/java/hivemall/tools/sanity/AssertUDF.java index d34cd20..c74aeb8 100644 --- a/core/src/main/java/hivemall/tools/sanity/AssertUDF.java +++ b/core/src/main/java/hivemall/tools/sanity/AssertUDF.java @@ -25,8 +25,10 @@ import org.apache.hadoop.hive.ql.udf.UDFType; @Description(name = "assert", value = "_FUNC_(boolean condition) or _FUNC_(boolean condition, string errMsg)" - + "- Throws HiveException if condition is not 
met") -@UDFType(deterministic = true, stateful = false) + + "- Throws HiveException if condition is not met", + extended = "SELECT count(1) FROM stock_price WHERE assert(price > 0.0);\n" + + "SELECT count(1) FROM stock_price WHERE assert(price > 0.0, 'price MUST be more than 0.0')") +@UDFType(deterministic = false, stateful = false) public final class AssertUDF extends UDF { public boolean evaluate(boolean condition) throws HiveException { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/sanity/RaiseErrorUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/sanity/RaiseErrorUDF.java b/core/src/main/java/hivemall/tools/sanity/RaiseErrorUDF.java index 194085c..fb6b8eb 100644 --- a/core/src/main/java/hivemall/tools/sanity/RaiseErrorUDF.java +++ b/core/src/main/java/hivemall/tools/sanity/RaiseErrorUDF.java @@ -18,21 +18,48 @@ */ package hivemall.tools.sanity; +import org.apache.commons.lang.StringUtils; import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.UDF; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -@Description(name = "raise_error", value = "_FUNC_() or _FUNC_(string msg) - Throws an error") -@UDFType(deterministic = true, stateful = false) -public final class RaiseErrorUDF extends UDF { +@Description(name = "raise_error", value = "_FUNC_() or _FUNC_(string msg) - Throws an error", + extended = "SELECT product_id, price, raise_error('Found an invalid record') FROM xxx WHERE price < 0.0") 
+@UDFType(deterministic = false, stateful = false) +public class RaiseErrorUDF extends GenericUDF { - public boolean evaluate() throws HiveException { - throw new HiveException(); + @Override + public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length != 0 && argOIs.length != 1) { + throw new UDFArgumentLengthException( + "Expected one or two arguments for raise_error UDF: " + argOIs.length); + } + + return PrimitiveObjectInspectorFactory.writableBooleanObjectInspector; + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + if (arguments.length == 1) { + Object arg0 = arguments[0].get(); + if (arg0 == null) { + throw new HiveException(); + } + String msg = arg0.toString(); + throw new HiveException(msg); + } else { + throw new HiveException(); + } } - public boolean evaluate(String errorMessage) throws HiveException { - throw new HiveException(errorMessage); + @Override + public String getDisplayString(String[] children) { + return "raise_error(" + StringUtils.join(children, ',') + ')'; } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/text/Base91UDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/text/Base91UDF.java b/core/src/main/java/hivemall/tools/text/Base91UDF.java index 6f52599..73f365b 100644 --- a/core/src/main/java/hivemall/tools/text/Base91UDF.java +++ b/core/src/main/java/hivemall/tools/text/Base91UDF.java @@ -40,7 +40,7 @@ import org.apache.hadoop.io.Text; @Description(name = "base91", value = "_FUNC_(BINARY bin) - Convert the argument from binary to a BASE91 string", - extended = "select base91(deflate('aaaaaaaaaaaaaaaabbbbccc'));\n" + "> AA+=kaIM|WTt!+wbGAA") + extended = "SELECT base91(deflate('aaaaaaaaaaaaaaaabbbbccc'));\n" + " AA+=kaIM|WTt!+wbGAA") @UDFType(deterministic = true, stateful = false) public final class 
Base91UDF extends GenericUDF { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/text/NormalizeUnicodeUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/text/NormalizeUnicodeUDF.java b/core/src/main/java/hivemall/tools/text/NormalizeUnicodeUDF.java index aefb4e2..9a7b547 100644 --- a/core/src/main/java/hivemall/tools/text/NormalizeUnicodeUDF.java +++ b/core/src/main/java/hivemall/tools/text/NormalizeUnicodeUDF.java @@ -29,8 +29,8 @@ import org.apache.hadoop.hive.ql.udf.UDFType; @Description(name = "normalize_unicode", value = "_FUNC_(string str [, string form]) - Transforms `str` with the specified normalization form. " + "The `form` takes one of NFC (default), NFD, NFKC, or NFKD", - extended = "select normalize_unicode('ï¾ï¾ï½¶ï½¸ï½¶ï¾ ','NFKC');\n" + "> ãã³ã«ã¯ã«ã\n" + "\n" - + "select normalize_unicode('ã±ã§ã¦â ¢','NFKC');\n" + "> (æ ª)ãã³ãã«III") + extended = "SELECT normalize_unicode('ï¾ï¾ï½¶ï½¸ï½¶ï¾ ','NFKC');\n" + " ãã³ã«ã¯ã«ã\n" + "\n" + + "SELECT normalize_unicode('ã±ã§ã¦â ¢','NFKC');\n" + " (æ ª)ãã³ãã«III") @UDFType(deterministic = true, stateful = false) public final class NormalizeUnicodeUDF extends UDF { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/text/SingularizeUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/text/SingularizeUDF.java b/core/src/main/java/hivemall/tools/text/SingularizeUDF.java index 73d2d63..3b16828 100644 --- a/core/src/main/java/hivemall/tools/text/SingularizeUDF.java +++ b/core/src/main/java/hivemall/tools/text/SingularizeUDF.java @@ -40,7 +40,7 @@ import org.apache.hadoop.hive.ql.udf.UDFType; // https://github.com/clips/pattern/blob/3eef00481a4555331cf9a099308910d977f6fc22/pattern/text/en/inflect.py#L445-L623 @Description(name = "singularize", value = 
"_FUNC_(string word) - Returns singular form of a given English word", - extended = "select singularize(lower(\"Apples\"));\n" + "\n" + "> \"apple\"") + extended = "SELECT singularize(lower(\"Apples\"));\n" + "\n" + " \"apple\"") @UDFType(deterministic = true, stateful = false) public final class SingularizeUDF extends UDF { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/text/SplitWordsUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/text/SplitWordsUDF.java b/core/src/main/java/hivemall/tools/text/SplitWordsUDF.java index 0b10c2f..31d155d 100644 --- a/core/src/main/java/hivemall/tools/text/SplitWordsUDF.java +++ b/core/src/main/java/hivemall/tools/text/SplitWordsUDF.java @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.ql.udf.UDFType; import org.apache.hadoop.io.Text; @Description(name = "split_words", - value = "_FUNC_(string query [, string regex]) - Returns an array<text> containing split strings") + value = "_FUNC_(string query [, string regex]) - Returns an array<text> containing splitted strings") @UDFType(deterministic = true, stateful = false) public final class SplitWordsUDF extends UDF { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/text/Unbase91UDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/text/Unbase91UDF.java b/core/src/main/java/hivemall/tools/text/Unbase91UDF.java index a96b3bf..9c277ce 100644 --- a/core/src/main/java/hivemall/tools/text/Unbase91UDF.java +++ b/core/src/main/java/hivemall/tools/text/Unbase91UDF.java @@ -39,8 +39,8 @@ import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.Text; @Description(name = "unbase91", value = "_FUNC_(string) - Convert a BASE91 string to a binary", - extended = "select 
inflate(unbase91(base91(deflate('aaaaaaaaaaaaaaaabbbbccc'))));\n" - + "> aaaaaaaaaaaaaaaabbbbccc") + extended = "SELECT inflate(unbase91(base91(deflate('aaaaaaaaaaaaaaaabbbbccc'))));\n" + + " aaaaaaaaaaaaaaaabbbbccc") @UDFType(deterministic = true, stateful = false) public final class Unbase91UDF extends GenericUDF { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/text/WordNgramsUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/text/WordNgramsUDF.java b/core/src/main/java/hivemall/tools/text/WordNgramsUDF.java index fa8308b..9b3658f 100644 --- a/core/src/main/java/hivemall/tools/text/WordNgramsUDF.java +++ b/core/src/main/java/hivemall/tools/text/WordNgramsUDF.java @@ -37,8 +37,8 @@ import java.util.List; @Description(name = "word_ngrams", value = "_FUNC_(array<string> words, int minSize, int maxSize])" + " - Returns list of n-grams for given words, where `minSize <= n <= maxSize`", - extended = "select word_ngrams(tokenize('Machine learning is fun!', true), 1, 2);\n" + "\n" - + "> [\"machine\",\"machine learning\",\"learning\",\"learning is\",\"is\",\"is fun\",\"fun\"]") + extended = "SELECT word_ngrams(tokenize('Machine learning is fun!', true), 1, 2);\n" + "\n" + + " [\"machine\",\"machine learning\",\"learning\",\"learning is\",\"is\",\"is fun\",\"fun\"]") @UDFType(deterministic = true, stateful = false) public final class WordNgramsUDF extends UDF { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/timeseries/MovingAverageUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/timeseries/MovingAverageUDTF.java b/core/src/main/java/hivemall/tools/timeseries/MovingAverageUDTF.java new file mode 100644 index 0000000..b5267a1 --- /dev/null +++ 
b/core/src/main/java/hivemall/tools/timeseries/MovingAverageUDTF.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.tools.timeseries; + +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.stats.MovingAverage; + +import java.util.Arrays; +import java.util.List; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.Writable; + +// @formatter:off +@Description(name = "moving_avg", + value = "_FUNC_(NUMBER value, const int windowSize)" + + " - Returns moving 
average of a time series using a given window", + extended = "SELECT moving_avg(x, 3) FROM (SELECT explode(array(1.0,2.0,3.0,4.0,5.0,6.0,7.0)) as x) series;\n" + + " 1.0\n" + + " 1.5\n" + + " 2.0\n" + + " 3.0\n" + + " 4.0\n" + + " 5.0\n" + + " 6.0") +// @formatter:on +@UDFType(deterministic = false, stateful = true) +public final class MovingAverageUDTF extends GenericUDTF { + + private PrimitiveObjectInspector valueOI; + + private MovingAverage movingAvg; + + private Writable[] forwardObjs; + private DoubleWritable result; + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length != 2) { + throw new UDFArgumentException( + "Two argument is expected for moving_avg(NUMBER value, const int windowSize): " + + argOIs.length); + } + this.valueOI = HiveUtils.asNumberOI(argOIs[0]); + + int windowSize = HiveUtils.getConstInt(argOIs[1]); + this.movingAvg = new MovingAverage(windowSize); + + this.result = new DoubleWritable(); + this.forwardObjs = new Writable[] {result}; + + List<String> fieldNames = Arrays.asList("avg"); + List<ObjectInspector> fieldOIs = Arrays.<ObjectInspector>asList( + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @Override + public void process(Object[] args) throws HiveException { + double x = HiveUtils.getDouble(args[0], valueOI); + + double avg = movingAvg.add(x); + result.set(avg); + + forward(forwardObjs); + } + + @Override + public void close() throws HiveException {} + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/vector/VectorAddUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/vector/VectorAddUDF.java b/core/src/main/java/hivemall/tools/vector/VectorAddUDF.java index 8442ae3..98fab99 100644 --- 
a/core/src/main/java/hivemall/tools/vector/VectorAddUDF.java +++ b/core/src/main/java/hivemall/tools/vector/VectorAddUDF.java @@ -42,7 +42,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; @Description(name = "vector_add", - value = "_FUNC_(array<NUMBER> x, array<NUMBER> y) - Perform vector ADD operation.") + value = "_FUNC_(array<NUMBER> x, array<NUMBER> y) - Perform vector ADD operation.", + extended = "SELECT vector_add(array(1.0,2.0,3.0), array(2, 3, 4));\n" + "[3.0,5.0,7.0]") @UDFType(deterministic = true, stateful = false) public final class VectorAddUDF extends GenericUDF { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/tools/vector/VectorDotUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/vector/VectorDotUDF.java b/core/src/main/java/hivemall/tools/vector/VectorDotUDF.java index 2aa3c03..b43b562 100644 --- a/core/src/main/java/hivemall/tools/vector/VectorDotUDF.java +++ b/core/src/main/java/hivemall/tools/vector/VectorDotUDF.java @@ -43,7 +43,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn @Description(name = "vector_dot", value = "_FUNC_(array<NUMBER> x, array<NUMBER> y) - Performs vector dot product.", - extended = "_FUNC_(array<NUMBER> x, NUMBER y) - Performs vector multiplication") + extended = "SELECT vector_dot(array(1.0,2.0,3.0),array(2.0,3.0,4.0));\n20\n\n" + + "SELECT vector_dot(array(1.0,2.0,3.0),2);\n[2.0,4.0,6.0]") @UDFType(deterministic = true, stateful = false) public final class VectorDotUDF extends GenericUDF { @@ -65,19 +66,19 @@ public final class VectorDotUDF extends GenericUDF { ObjectInspector argOI1 = argOIs[1]; if (HiveUtils.isNumberListOI(argOI1)) { this.evaluator = new Dot2DVectors(xListOI, HiveUtils.asListOI(argOI1)); + return 
PrimitiveObjectInspectorFactory.javaDoubleObjectInspector; } else if (HiveUtils.isNumberOI(argOI1)) { this.evaluator = new Multiply2D1D(xListOI, argOI1); + return ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector); } else { throw new UDFArgumentException( "Expected array<number> or number for the send argument: " + argOI1.getTypeName()); } - - return ObjectInspectorFactory.getStandardListObjectInspector( - PrimitiveObjectInspectorFactory.javaDoubleObjectInspector); } @Override - public List<Double> evaluate(DeferredObject[] args) throws HiveException { + public Object evaluate(DeferredObject[] args) throws HiveException { final Object arg0 = args[0].get(); final Object arg1 = args[1].get(); if (arg0 == null || arg1 == null) { @@ -90,7 +91,7 @@ public final class VectorDotUDF extends GenericUDF { interface Evaluator extends Serializable { @Nonnull - List<Double> dot(@Nonnull Object x, @Nonnull Object y) throws HiveException; + Object dot(@Nonnull Object x, @Nonnull Object y) throws HiveException; } @@ -144,7 +145,7 @@ public final class VectorDotUDF extends GenericUDF { } @Override - public List<Double> dot(@Nonnull Object x, @Nonnull Object y) throws HiveException { + public Double dot(@Nonnull Object x, @Nonnull Object y) throws HiveException { final int xLen = xListOI.getListLength(x); final int yLen = yListOI.getListLength(y); if (xLen != yLen) { @@ -152,7 +153,7 @@ public final class VectorDotUDF extends GenericUDF { + ", y=" + yListOI.getList(y)); } - final Double[] arr = new Double[xLen]; + double result = 0.d; for (int i = 0; i < xLen; i++) { Object xi = xListOI.getListElement(x, i); Object yi = yListOI.getListElement(y, i); @@ -162,10 +163,10 @@ public final class VectorDotUDF extends GenericUDF { double xd = PrimitiveObjectInspectorUtils.getDouble(xi, xElemOI); double yd = PrimitiveObjectInspectorUtils.getDouble(yi, yElemOI); double v = xd * yd; - arr[i] = Double.valueOf(v); + result += v; } 
- return Arrays.asList(arr); + return Double.valueOf(result); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index f3fe703..12b0e97 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -1211,7 +1211,7 @@ public final class HiveUtils { @Nonnull public static ObjectInspector getObjectInspector(@Nonnull final String typeString, - boolean preferWritable) { + final boolean preferWritable) { TypeInfo typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(typeString); if (preferWritable) { return TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(typeInfo); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/main/java/hivemall/utils/math/MatrixUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/MatrixUtils.java b/core/src/main/java/hivemall/utils/math/MatrixUtils.java index 6c08a61..38329c1 100644 --- a/core/src/main/java/hivemall/utils/math/MatrixUtils.java +++ b/core/src/main/java/hivemall/utils/math/MatrixUtils.java @@ -239,7 +239,6 @@ public final class MatrixUtils { Preconditions.checkArgument(dim >= 1, "Invalid dimension: " + dim); Preconditions.checkArgument(c.length >= dim, "|c| must be greater than " + dim + ": " + c.length); - /* * Toeplitz matrix (symmetric, invertible, k*dimensions by k*dimensions) * @@ -511,8 +510,12 @@ public final class MatrixUtils { } /** - * Find the first singular vector/value of a matrix A based on the Power method. + * Find the first singular vector/value of a matrix A based on the Power method. 
* + * @see http * ://www.cs.yale.edu/homes/el327/datamining2013aFiles/07_singular_value_decomposition.pdf * @param A target matrix http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java index 16196eb..5b7aa8f 100644 --- a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java +++ b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java @@ -85,6 +85,7 @@ public class FieldAwareFactorizationMachineUDTFTest { @Test public void testSample() throws IOException, HiveException { + System.setProperty("https.protocols", "TLSv1,TLSv1.1,TLSv1.2"); run("[Sample.ffm] default option", "https://github.com/myui/ml_dataset/raw/master/ffm/sample.ffm.gz", "-classification -factors 2 -iters 10 -feature_hashing 20 -seed 43", 0.01f); @@ -92,6 +93,7 @@ public class FieldAwareFactorizationMachineUDTFTest { // TODO @Test public void testSampleEnableNorm() throws IOException, HiveException { + System.setProperty("https.protocols", "TLSv1,TLSv1.1,TLSv1.2"); run("[Sample.ffm] default option", "https://github.com/myui/ml_dataset/raw/master/ffm/sample.ffm.gz", "-classification -factors 2 -iters 10 -feature_hashing 20 -seed 43 -enable_norm", http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/test/java/hivemall/ftvec/hashing/FeatureHashingUDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/ftvec/hashing/FeatureHashingUDFTest.java b/core/src/test/java/hivemall/ftvec/hashing/FeatureHashingUDFTest.java index d91eb66..06277c2 100644 ---
a/core/src/test/java/hivemall/ftvec/hashing/FeatureHashingUDFTest.java +++ b/core/src/test/java/hivemall/ftvec/hashing/FeatureHashingUDFTest.java @@ -21,6 +21,9 @@ package hivemall.ftvec.hashing; import hivemall.TestUtils; import hivemall.utils.hashing.MurmurHash3; +import java.io.IOException; +import java.util.Arrays; + import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; @@ -29,9 +32,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.junit.Assert; import org.junit.Test; -import java.io.IOException; -import java.util.Arrays; - public class FeatureHashingUDFTest { @Test http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java index 018de82..b789e71 100644 --- a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java +++ b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java @@ -19,6 +19,7 @@ package hivemall.smile.classification; import static org.junit.Assert.assertEquals; + import hivemall.math.matrix.Matrix; import hivemall.math.matrix.builders.CSRMatrixBuilder; import hivemall.math.matrix.dense.RowMajorDenseMatrix2d; @@ -29,6 +30,10 @@ import hivemall.smile.tools.TreeExportUDF.Evaluator; import hivemall.smile.tools.TreeExportUDF.OutputType; import hivemall.smile.utils.SmileExtUtils; import hivemall.utils.codec.Base91; +import smile.data.AttributeDataset; +import smile.data.parser.ArffParser; +import smile.math.Math; +import smile.validation.LOOCV; import java.io.BufferedInputStream; import java.io.IOException; @@ -43,11 +48,6 
@@ import org.apache.hadoop.io.Text; import org.junit.Assert; import org.junit.Test; -import smile.data.AttributeDataset; -import smile.data.parser.ArffParser; -import smile.math.Math; -import smile.validation.LOOCV; - public class DecisionTreeTest { private static final boolean DEBUG = false; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java index a4a7f05..9d24b54 100644 --- a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java +++ b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java @@ -27,6 +27,8 @@ import hivemall.smile.data.Attribute.NumericAttribute; import hivemall.smile.tools.TreeExportUDF.Evaluator; import hivemall.smile.tools.TreeExportUDF.OutputType; import hivemall.utils.codec.Base91; +import smile.math.Math; +import smile.validation.LOOCV; import java.io.IOException; import java.text.ParseException; @@ -39,9 +41,6 @@ import org.apache.hadoop.io.Text; import org.junit.Assert; import org.junit.Test; -import smile.math.Math; -import smile.validation.LOOCV; - public class RegressionTreeTest { private static final boolean DEBUG = false; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/test/java/hivemall/statistics/MovingAverageUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/statistics/MovingAverageUDTFTest.java b/core/src/test/java/hivemall/statistics/MovingAverageUDTFTest.java deleted file mode 100644 index e755e26..0000000 --- a/core/src/test/java/hivemall/statistics/MovingAverageUDTFTest.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more 
contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.statistics; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import hivemall.TestUtils; -import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.ql.udf.generic.Collector; -import org.apache.hadoop.hive.serde2.io.DoubleWritable; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.junit.Assert; -import org.junit.Test; - -public class MovingAverageUDTFTest { - - @Test - public void test() throws HiveException { - MovingAverageUDTF udtf = new MovingAverageUDTF(); - - ObjectInspector argOI0 = PrimitiveObjectInspectorFactory.javaFloatObjectInspector; - ObjectInspector argOI1 = ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaIntObjectInspector, 3); - - final List<Double> results = new ArrayList<>(); - udtf.initialize(new ObjectInspector[] {argOI0, argOI1}); - udtf.setCollector(new Collector() { - @Override - public void collect(Object input) throws HiveException { - Object[] objs = (Object[]) input; - Assert.assertEquals(1, 
objs.length); - Assert.assertTrue(objs[0] instanceof DoubleWritable); - double x = ((DoubleWritable) objs[0]).get(); - results.add(x); - } - }); - - udtf.process(new Object[] {1.f, null}); - udtf.process(new Object[] {2.f, null}); - udtf.process(new Object[] {3.f, null}); - udtf.process(new Object[] {4.f, null}); - udtf.process(new Object[] {5.f, null}); - udtf.process(new Object[] {6.f, null}); - udtf.process(new Object[] {7.f, null}); - - Assert.assertEquals(Arrays.asList(1.d, 1.5d, 2.d, 3.d, 4.d, 5.d, 6.d), results); - } - - @Test - public void testSerialization() throws HiveException { - TestUtils.testGenericUDTFSerialization(MovingAverageUDTF.class, - new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaFloatObjectInspector, - ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaIntObjectInspector, 3)}, - new Object[][] {{1.f}, {2.f}, {3.f}, {4.f}, {5.f}}); - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/test/java/hivemall/tools/GenerateSeriesUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/tools/GenerateSeriesUDTFTest.java b/core/src/test/java/hivemall/tools/GenerateSeriesUDTFTest.java new file mode 100644 index 0000000..04f432c --- /dev/null +++ b/core/src/test/java/hivemall/tools/GenerateSeriesUDTFTest.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
package hivemall.tools;

import hivemall.TestUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.Collector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit tests for {@code GenerateSeriesUDTF} (generate_series(start, end [, step])).
 *
 * Each test initializes the UDTF with a particular combination of argument
 * ObjectInspectors, captures the emitted rows through a {@link Collector},
 * and compares them against the expected series.
 */
public class GenerateSeriesUDTFTest {

    /**
     * Returns a Collector that copies column 0 of each emitted row into {@code sink}.
     *
     * A defensive copy of the Writable is taken on every row.
     * NOTE(review): a UDTF may reuse its output Writable between rows — the copy
     * is required if GenerateSeriesUDTF does so; it is harmless either way.
     */
    private static Collector intCollector(final List<IntWritable> sink) {
        return new Collector() {
            @Override
            public void collect(Object input) throws HiveException {
                Object[] row = (Object[]) input;
                sink.add(new IntWritable(((IntWritable) row[0]).get()));
            }
        };
    }

    /** Same as {@link #intCollector(List)} but for rows whose column 0 is a LongWritable. */
    private static Collector longCollector(final List<LongWritable> sink) {
        return new Collector() {
            @Override
            public void collect(Object input) throws HiveException {
                Object[] row = (Object[]) input;
                sink.add(new LongWritable(((LongWritable) row[0]).get()));
            }
        };
    }

    /** generate_series with two constant int arguments: [1, 3] emits 1, 2, 3. */
    @Test
    public void testTwoConstArgs() throws HiveException {
        GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

        udtf.initialize(new ObjectInspector[] {
                PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
                    TypeInfoFactory.intTypeInfo, new IntWritable(1)),
                PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
                    TypeInfoFactory.intTypeInfo, new IntWritable(3))});

        final List<IntWritable> actual = new ArrayList<>();
        udtf.setCollector(intCollector(actual));

        udtf.process(new Object[] {new IntWritable(1), new IntWritable(3)});

        Assert.assertEquals(
            Arrays.asList(new IntWritable(1), new IntWritable(2), new IntWritable(3)), actual);
    }

    /** Mixed java-int / writable-int arguments: [1, 3] emits 1, 2, 3. */
    @Test
    public void testTwoIntArgs() throws HiveException {
        GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

        udtf.initialize(
            new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                    PrimitiveObjectInspectorFactory.writableIntObjectInspector});

        final List<IntWritable> actual = new ArrayList<>();
        udtf.setCollector(intCollector(actual));

        udtf.process(new Object[] {1, new IntWritable(3)});

        Assert.assertEquals(
            Arrays.asList(new IntWritable(1), new IntWritable(2), new IntWritable(3)), actual);
    }

    /** A long end argument widens the output column to LongWritable. */
    @Test
    public void testTwoLongArgs() throws HiveException {
        GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

        udtf.initialize(
            new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                    PrimitiveObjectInspectorFactory.writableLongObjectInspector});

        final List<LongWritable> actual = new ArrayList<>();
        udtf.setCollector(longCollector(actual));

        udtf.process(new Object[] {1, new LongWritable(3)});

        Assert.assertEquals(
            Arrays.asList(new LongWritable(1), new LongWritable(2), new LongWritable(3)), actual);
    }

    /** Three int-ish arguments with step 3: [1, 7] emits 1, 4, 7. */
    @Test
    public void testThreeIntArgs() throws HiveException {
        GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

        udtf.initialize(
            new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                    PrimitiveObjectInspectorFactory.writableIntObjectInspector,
                    PrimitiveObjectInspectorFactory.javaLongObjectInspector});

        final List<IntWritable> actual = new ArrayList<>();
        udtf.setCollector(intCollector(actual));

        udtf.process(new Object[] {1, new IntWritable(7), 3L});

        Assert.assertEquals(
            Arrays.asList(new IntWritable(1), new IntWritable(4), new IntWritable(7)), actual);
    }

    /** Three long arguments with step 3: [1, 7] emits 1, 4, 7 as longs. */
    @Test
    public void testThreeLongArgs() throws HiveException {
        GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

        udtf.initialize(
            new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaLongObjectInspector,
                    PrimitiveObjectInspectorFactory.writableLongObjectInspector,
                    PrimitiveObjectInspectorFactory.javaLongObjectInspector});

        final List<LongWritable> actual = new ArrayList<>();
        udtf.setCollector(longCollector(actual));

        udtf.process(new Object[] {1L, new LongWritable(7), 3L});

        Assert.assertEquals(
            Arrays.asList(new LongWritable(1), new LongWritable(4), new LongWritable(7)), actual);
    }

    /** A negative step counts down: 5 .. 1 step -2 emits 5, 3, 1. */
    @Test
    public void testNegativeStepInt() throws HiveException {
        GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

        udtf.initialize(
            new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                    PrimitiveObjectInspectorFactory.writableIntObjectInspector,
                    PrimitiveObjectInspectorFactory.javaLongObjectInspector});

        final List<IntWritable> actual = new ArrayList<>();
        udtf.setCollector(intCollector(actual));

        udtf.process(new Object[] {5, new IntWritable(1), -2L});

        Assert.assertEquals(
            Arrays.asList(new IntWritable(5), new IntWritable(3), new IntWritable(1)), actual);
    }

    /** Negative step with a long start: 5 .. 1 step -2 emits 5, 3, 1 as longs. */
    @Test
    public void testNegativeStepLong() throws HiveException {
        GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

        udtf.initialize(
            new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaLongObjectInspector,
                    PrimitiveObjectInspectorFactory.writableIntObjectInspector,
                    PrimitiveObjectInspectorFactory.javaIntObjectInspector});

        final List<LongWritable> actual = new ArrayList<>();
        udtf.setCollector(longCollector(actual));

        udtf.process(new Object[] {5L, new IntWritable(1), -2});

        Assert.assertEquals(
            Arrays.asList(new LongWritable(5), new LongWritable(3), new LongWritable(1)), actual);
    }

    /** The UDTF must survive a Kryo serialization round-trip after processing a row. */
    @Test
    public void testSerialization() throws HiveException {
        GenerateSeriesUDTF udtf = new GenerateSeriesUDTF();

        udtf.initialize(
            new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaIntObjectInspector,
                    PrimitiveObjectInspectorFactory.writableIntObjectInspector});

        udtf.setCollector(new Collector() {
            @Override
            public void collect(Object input) throws HiveException {
                // rows are discarded; this test only checks the Kryo round-trip below
            }
        });

        udtf.process(new Object[] {1, new IntWritable(3)});

        byte[] serialized = TestUtils.serializeObjectByKryo(udtf);
        TestUtils.deserializeObjectByKryo(serialized, GenerateSeriesUDTF.class);
    }

}
a/core/src/test/java/hivemall/tools/array/ArrayAppendUDFTest.java b/core/src/test/java/hivemall/tools/array/ArrayAppendUDFTest.java index b376abe..1e01274 100644 --- a/core/src/test/java/hivemall/tools/array/ArrayAppendUDFTest.java +++ b/core/src/test/java/hivemall/tools/array/ArrayAppendUDFTest.java @@ -85,22 +85,21 @@ public class ArrayAppendUDFTest { udf.close(); } - @Test - public void testEvaluateReturnNull() throws HiveException, IOException { + public void testEvaluateNullList() throws HiveException, IOException { ArrayAppendUDF udf = new ArrayAppendUDF(); udf.initialize(new ObjectInspector[] { ObjectInspectorFactory.getStandardListObjectInspector( - PrimitiveObjectInspectorFactory.javaDoubleObjectInspector), + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), PrimitiveObjectInspectorFactory.javaDoubleObjectInspector}); DeferredObject[] args = new DeferredObject[] {new GenericUDF.DeferredJavaObject(null), - new GenericUDF.DeferredJavaObject(new Double(3))}; + new GenericUDF.DeferredJavaObject(new Double(3d))}; List<Object> result = udf.evaluate(args); - Assert.assertNull(result); + Assert.assertEquals(Arrays.asList(new DoubleWritable(3d)), result); udf.close(); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/test/java/hivemall/tools/array/ArrayFlattenUDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/tools/array/ArrayFlattenUDFTest.java b/core/src/test/java/hivemall/tools/array/ArrayFlattenUDFTest.java index f69cdd8..11754aa 100644 --- a/core/src/test/java/hivemall/tools/array/ArrayFlattenUDFTest.java +++ b/core/src/test/java/hivemall/tools/array/ArrayFlattenUDFTest.java @@ -18,11 +18,12 @@ */ package hivemall.tools.array; +import hivemall.TestUtils; + import java.io.IOException; import java.util.Arrays; import java.util.List; -import hivemall.TestUtils; import org.apache.hadoop.hive.ql.metadata.HiveException; import 
org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject; @@ -64,4 +65,5 @@ public class ArrayFlattenUDFTest { new Object[] {Arrays.asList(Arrays.asList(0, 1, 2, 3), Arrays.asList(4, 5), Arrays.asList(6, 7))}); } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/test/java/hivemall/tools/array/ArraySliceUDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/tools/array/ArraySliceUDFTest.java b/core/src/test/java/hivemall/tools/array/ArraySliceUDFTest.java index a260936..fbc212a 100644 --- a/core/src/test/java/hivemall/tools/array/ArraySliceUDFTest.java +++ b/core/src/test/java/hivemall/tools/array/ArraySliceUDFTest.java @@ -18,12 +18,13 @@ */ package hivemall.tools.array; +import hivemall.TestUtils; + import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.List; -import hivemall.TestUtils; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject; @@ -129,4 +130,5 @@ public class ArraySliceUDFTest { new Object[] {Arrays.asList("zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"), 2, 5}); } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/test/java/hivemall/tools/array/ArrayToStrUDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/tools/array/ArrayToStrUDFTest.java b/core/src/test/java/hivemall/tools/array/ArrayToStrUDFTest.java new file mode 100644 index 0000000..373217f --- /dev/null +++ b/core/src/test/java/hivemall/tools/array/ArrayToStrUDFTest.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
package hivemall.tools.array;

import hivemall.TestUtils;

import java.io.IOException;
import java.util.Arrays;

import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.Text;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit tests for {@code ArrayToStrUDF}, which joins an array into a single
 * string using an optional separator (defaulting to a comma).
 */
public class ArrayToStrUDFTest {

    /** Joining with an explicit separator, and with a null separator (falls back to ","). */
    @Test
    public void testSimpleCase() throws HiveException, IOException {
        ArrayToStrUDF toStr = new ArrayToStrUDF();

        ObjectInspector listOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaIntObjectInspector);
        toStr.initialize(new ObjectInspector[] {listOI,
                PrimitiveObjectInspectorFactory.writableStringObjectInspector});

        // explicit '#' separator
        Text delimiter = new Text("#");
        DeferredObject elements = new GenericUDF.DeferredJavaObject(Arrays.asList(1, 2, 3));
        DeferredObject[] arguments = new DeferredObject[] {elements,
                new GenericUDF.DeferredJavaObject(delimiter)};
        Assert.assertEquals("1#2#3", toStr.evaluate(arguments));

        // null separator falls back to the default comma
        arguments = new DeferredObject[] {new GenericUDF.DeferredJavaObject(Arrays.asList(1, 2, 3)),
                new GenericUDF.DeferredJavaObject(null)};
        Assert.assertEquals("1,2,3", toStr.evaluate(arguments));

        toStr.close();
    }

    /** With no separator argument at all, elements are joined by ",". */
    @Test
    public void testNoSep() throws HiveException, IOException {
        ArrayToStrUDF toStr = new ArrayToStrUDF();

        toStr.initialize(
            new ObjectInspector[] {ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaIntObjectInspector)});

        DeferredObject[] arguments =
                new DeferredObject[] {new GenericUDF.DeferredJavaObject(Arrays.asList(1, 2, 3))};
        Assert.assertEquals("1,2,3", toStr.evaluate(arguments));

        toStr.close();
    }

    /** Null elements are skipped, not rendered, wherever they appear in the array. */
    @Test
    public void testNull() throws HiveException, IOException {
        ArrayToStrUDF toStr = new ArrayToStrUDF();

        toStr.initialize(
            new ObjectInspector[] {ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaIntObjectInspector)});

        // null in the middle
        DeferredObject[] arguments =
                new DeferredObject[] {new GenericUDF.DeferredJavaObject(Arrays.asList(1, null, 3))};
        Assert.assertEquals("1,3", toStr.evaluate(arguments));

        // null at the head
        arguments =
                new DeferredObject[] {new GenericUDF.DeferredJavaObject(Arrays.asList(null, 2, 3))};
        Assert.assertEquals("2,3", toStr.evaluate(arguments));

        toStr.close();
    }

    /** The UDF must survive serialization (Kryo round-trip via TestUtils). */
    @Test
    public void testSerialization() throws HiveException, IOException {
        TestUtils.testGenericUDFSerialization(ArrayToStrUDF.class,
            new ObjectInspector[] {
                    ObjectInspectorFactory.getStandardListObjectInspector(
                        PrimitiveObjectInspectorFactory.javaIntObjectInspector),
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector},
            new Object[] {Arrays.asList(1, 2, 3), "-"});
    }
}
ac0f735..cc17039 100644 --- a/core/src/test/java/hivemall/tools/array/ArrayUnionUDFTest.java +++ b/core/src/test/java/hivemall/tools/array/ArrayUnionUDFTest.java @@ -74,4 +74,5 @@ public class ArrayUnionUDFTest { PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)}, new Object[] {Arrays.asList(0.d, 1.d), Arrays.asList(2.d, 3.d)}); } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/49496032/core/src/test/java/hivemall/tools/array/ConditionalEmitUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/tools/array/ConditionalEmitUDTFTest.java b/core/src/test/java/hivemall/tools/array/ConditionalEmitUDTFTest.java index ef45983..7045235 100644 --- a/core/src/test/java/hivemall/tools/array/ConditionalEmitUDTFTest.java +++ b/core/src/test/java/hivemall/tools/array/ConditionalEmitUDTFTest.java @@ -18,12 +18,12 @@ */ package hivemall.tools.array; +import hivemall.TestUtils; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import hivemall.TestUtils; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.Collector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -80,4 +80,5 @@ public class ConditionalEmitUDTFTest { {Arrays.asList(true, false, true), Arrays.asList("one", "two", "three")}, {Arrays.asList(true, true, false), Arrays.asList("one", "two", "three")}}); } + }
