Repository: incubator-hivemall
Updated Branches:
  refs/heads/master e4e94f8b0 -> f6765dff7
Fixed a Kryo serialization error in select_k_best UDF Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/f6765dff Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/f6765dff Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/f6765dff Branch: refs/heads/master Commit: f6765dff7be67e1a3327709bbb9bfdc6eba7b97f Parents: e4e94f8 Author: Makoto Yui <m...@apache.org> Authored: Wed Apr 11 14:53:37 2018 +0900 Committer: Makoto Yui <m...@apache.org> Committed: Wed Apr 11 15:00:48 2018 +0900 ---------------------------------------------------------------------- .../hivemall/tools/array/SelectKBestUDF.java | 22 ++++---- core/src/test/java/hivemall/TestUtils.java | 56 ++++++++++++++++++++ .../tools/array/SelectKBestUDFTest.java | 29 ++++++++-- 3 files changed, 93 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f6765dff/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java b/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java index b363166..ff37217 100644 --- a/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java +++ b/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java @@ -24,7 +24,6 @@ import hivemall.utils.lang.Preconditions; import java.io.IOException; import java.util.AbstractMap; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.List; @@ -83,24 +82,26 @@ public final class SelectKBestUDF extends GenericUDF { this.featuresOI = HiveUtils.asListOI(OIs[0]); this.featureOI = HiveUtils.asDoubleCompatibleOI(featuresOI.getListElementObjectInspector()); this.importanceListOI = 
HiveUtils.asListOI(OIs[1]); - this.importanceElemOI = HiveUtils.asDoubleCompatibleOI(importanceListOI.getListElementObjectInspector()); + this.importanceElemOI = + HiveUtils.asDoubleCompatibleOI(importanceListOI.getListElementObjectInspector()); this._k = HiveUtils.getConstInt(OIs[2]); Preconditions.checkArgument(_k >= 1, UDFArgumentException.class); - final DoubleWritable[] array = new DoubleWritable[_k]; - for (int i = 0; i < array.length; i++) { - array[i] = new DoubleWritable(); + final List<DoubleWritable> result = new ArrayList<>(_k); + for (int i = 0; i < _k; i++) { + result.add(new DoubleWritable()); } - this._result = Arrays.asList(array); + this._result = result; - return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + return ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); } @Override public List<DoubleWritable> evaluate(DeferredObject[] dObj) throws HiveException { final double[] features = HiveUtils.asDoubleArray(dObj[0].get(), featuresOI, featureOI); - final double[] importanceList = HiveUtils.asDoubleArray(dObj[1].get(), importanceListOI, - importanceElemOI); + final double[] importanceList = + HiveUtils.asDoubleArray(dObj[1].get(), importanceListOI, importanceElemOI); Preconditions.checkNotNull(features, UDFArgumentException.class); Preconditions.checkNotNull(importanceList, UDFArgumentException.class); @@ -110,7 +111,8 @@ public final class SelectKBestUDF extends GenericUDF { int[] topKIndices = _topKIndices; if (topKIndices == null) { - final List<Map.Entry<Integer, Double>> list = new ArrayList<Map.Entry<Integer, Double>>(); + final List<Map.Entry<Integer, Double>> list = + new ArrayList<Map.Entry<Integer, Double>>(); for (int i = 0; i < importanceList.length; i++) { list.add(new AbstractMap.SimpleEntry<Integer, Double>(i, importanceList[i])); } 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f6765dff/core/src/test/java/hivemall/TestUtils.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/TestUtils.java b/core/src/test/java/hivemall/TestUtils.java new file mode 100644 index 0000000..12d921e --- /dev/null +++ b/core/src/test/java/hivemall/TestUtils.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall; + +import java.io.ByteArrayOutputStream; + +import javax.annotation.Nonnull; + +import org.apache.hadoop.hive.ql.exec.Utilities; +import org.apache.hive.com.esotericsoftware.kryo.Kryo; +import org.apache.hive.com.esotericsoftware.kryo.io.Input; +import org.apache.hive.com.esotericsoftware.kryo.io.Output; + +public final class TestUtils { + + @Nonnull + public static byte[] serializeObjectByKryo(@Nonnull Object obj) { + Kryo kryo = getKryo(); + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + Output output = new Output(bos); + kryo.writeObject(output, obj); + output.close(); + return bos.toByteArray(); + } + + @Nonnull + public static <T> T deserializeObjectByKryo(@Nonnull byte[] in, @Nonnull Class<T> clazz) { + Kryo kryo = getKryo(); + Input inp = new Input(in); + T t = kryo.readObject(inp, clazz); + inp.close(); + return t; + } + + @Nonnull + private static Kryo getKryo() { + return Utilities.runtimeSerializationKryo.get(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f6765dff/core/src/test/java/hivemall/tools/array/SelectKBestUDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/tools/array/SelectKBestUDFTest.java b/core/src/test/java/hivemall/tools/array/SelectKBestUDFTest.java index 15366a7..49848af 100644 --- a/core/src/test/java/hivemall/tools/array/SelectKBestUDFTest.java +++ b/core/src/test/java/hivemall/tools/array/SelectKBestUDFTest.java @@ -18,10 +18,12 @@ */ package hivemall.tools.array; +import hivemall.TestUtils; import hivemall.utils.hadoop.WritableUtils; import java.util.List; +import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -37,8 +39,8 @@ public class SelectKBestUDFTest { public void test() throws Exception { final 
SelectKBestUDF selectKBest = new SelectKBestUDF(); final int k = 2; - final double[] data = new double[] {250.29999999999998, 170.90000000000003, 73.2, - 12.199999999999996}; + final double[] data = + new double[] {250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996}; final double[] importanceList = new double[] {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589}; @@ -48,8 +50,10 @@ public class SelectKBestUDFTest { new GenericUDF.DeferredJavaObject(k)}; selectKBest.initialize(new ObjectInspector[] { - ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), - ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), ObjectInspectorUtils.getConstantObjectInspector( PrimitiveObjectInspectorFactory.javaIntObjectInspector, k)}); final List<DoubleWritable> resultObj = selectKBest.evaluate(dObjs); @@ -66,4 +70,21 @@ public class SelectKBestUDFTest { Assert.assertArrayEquals(answer, result, 0.d); selectKBest.close(); } + + @Test + public void testSerialization() throws HiveException { + final SelectKBestUDF selectKBest = new SelectKBestUDF(); + final int k = 2; + selectKBest.initialize(new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaIntObjectInspector, k)}); + + byte[] serialized = TestUtils.serializeObjectByKryo(selectKBest); + 
TestUtils.deserializeObjectByKryo(serialized, SelectKBestUDF.class); + } + }