http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/utils/lambda/ThrowingConsumer.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lambda/ThrowingConsumer.java b/core/src/main/java/hivemall/utils/lambda/ThrowingConsumer.java new file mode 100644 index 0000000..7efd652 --- /dev/null +++ b/core/src/main/java/hivemall/utils/lambda/ThrowingConsumer.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.lambda; + +import java.util.function.Consumer; + +@FunctionalInterface +public interface ThrowingConsumer<T> extends Consumer<T> { + + @Override + default void accept(final T e) { + try { + accept0(e); + } catch (Throwable ex) { + Throwing.sneakyThrow(ex); + } + } + + void accept0(T e) throws Throwable; + +}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/utils/math/FastMath.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/FastMath.java b/core/src/main/java/hivemall/utils/math/FastMath.java new file mode 100644 index 0000000..d27d6f8 --- /dev/null +++ b/core/src/main/java/hivemall/utils/math/FastMath.java @@ -0,0 +1,466 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.utils.math; + +import hivemall.annotations.Experimental; + +@Experimental +public final class FastMath { + + private FastMath() {} + + @Deprecated + public static float sqrt(final float x) { + return x * invSqrt(x); + } + + @Deprecated + public static double sqrt(final double x) { + return x * invSqrt(x); + } + + /** + * https://en.wikipedia.org/wiki/Fast_inverse_square_root + */ + @Deprecated + public static float invSqrt(final float x) { + final float hx = 0.5f * x; + int i = 0x5f375a86 - (Float.floatToRawIntBits(x) >>> 1); + float y = Float.intBitsToFloat(i); + y *= (1.5f - hx * y * y); // pass 1 + y *= (1.5f - hx * y * y); // pass 2 + y *= (1.5f - hx * y * y); // pass 3 + //y *= (1.5f - hx * y * y); // pass 4 + // more pass for more accuracy + return y; + } + + /** + * https://en.wikipedia.org/wiki/Fast_inverse_square_root + */ + @Deprecated + public static double invSqrt(final double x) { + final double hx = 0.5d * x; + long i = 0x5fe6eb50c7b537a9L - (Double.doubleToRawLongBits(x) >>> 1); + double y = Double.longBitsToDouble(i); + y *= (1.5d - hx * y * y); // pass 1 + y *= (1.5d - hx * y * y); // pass 2 + y *= (1.5d - hx * y * y); // pass 3 + y *= (1.5d - hx * y * y); // pass 4 + // more pass for more accuracy + return y; + } + + public static double log(final double x) { + return JafamaMath.log(x); + } + + /** + * @return log(1+x) + */ + public static double log1p(final double x) { + return JafamaMath.log1p(x); + } + + /** + * https://martin.ankerl.com/2007/02/11/optimized-exponential-functions-for-java/ + * + * @return e^x + */ + public static double exp(final double x) { + return JafamaMath.exp(x); + } + + /** + * @return exp(x)-1 + */ + public static double expm1(final double x) { + return JafamaMath.expm1(x); + } + + public static double sigmoid(final double x) { + return 1 / (1 + exp(-x)); + } + + /* + * Copyright 2012-2015 Jeff Hain + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this 
file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + /* + * ============================================================================= + * Notice of fdlibm package this program is partially derived from: + * + * Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved. + * + * Developed at SunSoft, a Sun Microsystems, Inc. business. + * Permission to use, copy, modify, and distribute this + * software is freely granted, provided that this notice + * is preserved. + * ============================================================================= + */ + + /** + * Based on Jafama (https://github.com/jeffhain/jafama/) version 2.2. + */ + private static final class JafamaMath { + + static final double TWO_POW_52 = twoPow(52); + + /** + * Double.MIN_NORMAL since Java 6. + */ + static final double DOUBLE_MIN_NORMAL = Double.longBitsToDouble(0x0010000000000000L); // 2.2250738585072014E-308 + + // Not storing float/double mantissa size in constants, + // for 23 and 52 are shorter to read and more + // bitwise-explicit than some constant's name. 
+ + static final int MIN_DOUBLE_EXPONENT = -1074; + static final int MAX_DOUBLE_EXPONENT = 1023; + + static final double LOG_2 = StrictMath.log(2.0); + + //-------------------------------------------------------------------------- + // CONSTANTS AND TABLES FOR EXP AND EXPM1 + //-------------------------------------------------------------------------- + + static final double EXP_OVERFLOW_LIMIT = Double.longBitsToDouble(0x40862E42FEFA39EFL); // 7.09782712893383973096e+02 + static final double EXP_UNDERFLOW_LIMIT = Double.longBitsToDouble(0xC0874910D52D3051L); // -7.45133219101941108420e+02 + static final int EXP_LO_DISTANCE_TO_ZERO_POT = 0; + static final int EXP_LO_DISTANCE_TO_ZERO = (1 << EXP_LO_DISTANCE_TO_ZERO_POT); + static final int EXP_LO_TAB_SIZE_POT = 11; + static final int EXP_LO_TAB_SIZE = (1 << EXP_LO_TAB_SIZE_POT) + 1; + static final int EXP_LO_TAB_MID_INDEX = ((EXP_LO_TAB_SIZE - 1) / 2); + static final int EXP_LO_INDEXING = EXP_LO_TAB_MID_INDEX / EXP_LO_DISTANCE_TO_ZERO; + static final int EXP_LO_INDEXING_DIV_SHIFT = EXP_LO_TAB_SIZE_POT - 1 + - EXP_LO_DISTANCE_TO_ZERO_POT; + + static final class MyTExp { + static final double[] expHiTab = new double[1 + (int) EXP_OVERFLOW_LIMIT + - (int) EXP_UNDERFLOW_LIMIT]; + static final double[] expLoPosTab = new double[EXP_LO_TAB_SIZE]; + static final double[] expLoNegTab = new double[EXP_LO_TAB_SIZE]; + + static { + init(); + } + + private static strictfp void init() { + for (int i = (int) EXP_UNDERFLOW_LIMIT; i <= (int) EXP_OVERFLOW_LIMIT; i++) { + expHiTab[i - (int) EXP_UNDERFLOW_LIMIT] = StrictMath.exp(i); + } + for (int i = 0; i < EXP_LO_TAB_SIZE; i++) { + // x: in [-EXP_LO_DISTANCE_TO_ZERO,EXP_LO_DISTANCE_TO_ZERO]. 
+ double x = -EXP_LO_DISTANCE_TO_ZERO + i / (double) EXP_LO_INDEXING; + // exp(x) + expLoPosTab[i] = StrictMath.exp(x); + // 1-exp(-x), accurately computed + expLoNegTab[i] = -StrictMath.expm1(-x); + } + } + } + + //-------------------------------------------------------------------------- + // CONSTANTS AND TABLES FOR LOG AND LOG1P + //-------------------------------------------------------------------------- + + static final int LOG_BITS = 12; + static final int LOG_TAB_SIZE = (1 << LOG_BITS); + + static final class MyTLog { + static final double[] logXLogTab = new double[LOG_TAB_SIZE]; + static final double[] logXTab = new double[LOG_TAB_SIZE]; + static final double[] logXInvTab = new double[LOG_TAB_SIZE]; + + static { + init(); + } + + private static strictfp void init() { + for (int i = 0; i < LOG_TAB_SIZE; i++) { + // Exact to use inverse of tab size, since it is a power of two. + double x = 1 + i * (1.0 / LOG_TAB_SIZE); + logXLogTab[i] = StrictMath.log(x); + logXTab[i] = x; + logXInvTab[i] = 1 / x; + } + } + } + + /** + * @param value A double value. + * @return e^value. + */ + static double exp(final double value) { + // exp(x) = exp([x])*exp(y) + // with [x] the integer part of x, and y = x-[x] + // ===> + // We find an approximation of y, called z. + // ===> + // exp(x) = exp([x])*(exp(z)*exp(epsilon)) + // with epsilon = y - z + // ===> + // We have exp([x]) and exp(z) pre-computed in tables, we "just" have to compute exp(epsilon). + // + // We use the same indexing (cast to int) to compute x integer part and the + // table index corresponding to z, to avoid two int casts. + // Also, to optimize index multiplication and division, we use powers of two, + // so that we can do it with bits shifts. + + if (value > EXP_OVERFLOW_LIMIT) { + return Double.POSITIVE_INFINITY; + } else if (!(value >= EXP_UNDERFLOW_LIMIT)) { + return (value != value) ? 
Double.NaN : 0.0; + } + + final int indexes = (int) (value * EXP_LO_INDEXING); + + final int valueInt; + if (indexes >= 0) { + valueInt = (indexes >> EXP_LO_INDEXING_DIV_SHIFT); + } else { + valueInt = -((-indexes) >> EXP_LO_INDEXING_DIV_SHIFT); + } + final double hiTerm = MyTExp.expHiTab[valueInt - (int) EXP_UNDERFLOW_LIMIT]; + + final int zIndex = indexes - (valueInt << EXP_LO_INDEXING_DIV_SHIFT); + final double y = (value - valueInt); + final double z = zIndex * (1.0 / EXP_LO_INDEXING); + final double eps = y - z; + final double expZ = MyTExp.expLoPosTab[zIndex + EXP_LO_TAB_MID_INDEX]; + final double expEps = (1 + eps + * (1 + eps * (1.0 / 2 + eps * (1.0 / 6 + eps * (1.0 / 24))))); + final double loTerm = expZ * expEps; + + return hiTerm * loTerm; + } + + /** + * Much more accurate than exp(value)-1, for arguments (and results) close to zero. + * + * @param value A double value. + * @return e^value-1. + */ + static double expm1(final double value) { + // If value is far from zero, we use exp(value)-1. + // + // If value is close to zero, we use the following formula: + // exp(value)-1 + // = exp(valueApprox)*exp(epsilon)-1 + // = exp(valueApprox)*(exp(epsilon)-exp(-valueApprox)) + // = exp(valueApprox)*(1+epsilon+epsilon^2/2!+...-exp(-valueApprox)) + // = exp(valueApprox)*((1-exp(-valueApprox))+epsilon+epsilon^2/2!+...) + // exp(valueApprox) and exp(-valueApprox) being stored in tables. + + if (Math.abs(value) < EXP_LO_DISTANCE_TO_ZERO) { + // Taking int part instead of rounding, which takes too long. + int i = (int) (value * EXP_LO_INDEXING); + double delta = value - i * (1.0 / EXP_LO_INDEXING); + return MyTExp.expLoPosTab[i + EXP_LO_TAB_MID_INDEX] + * (MyTExp.expLoNegTab[i + EXP_LO_TAB_MID_INDEX] + delta + * (1 + delta + * (1.0 / 2 + delta + * (1.0 / 6 + delta + * (1.0 / 24 + delta * (1.0 / 120)))))); + } else { + return exp(value) - 1; + } + } + + /** + * @param value A double value. + * @return Value logarithm (base e). 
+ */ + static double log(double value) { + if (value > 0.0) { + if (value == Double.POSITIVE_INFINITY) { + return Double.POSITIVE_INFINITY; + } + + // For normal values not close to 1.0, we use the following formula: + // log(value) + // = log(2^exponent*1.mantissa) + // = log(2^exponent) + log(1.mantissa) + // = exponent * log(2) + log(1.mantissa) + // = exponent * log(2) + log(1.mantissaApprox) + log(1.mantissa/1.mantissaApprox) + // = exponent * log(2) + log(1.mantissaApprox) + log(1+epsilon) + // = exponent * log(2) + log(1.mantissaApprox) + epsilon-epsilon^2/2+epsilon^3/3-epsilon^4/4+... + // with: + // 1.mantissaApprox <= 1.mantissa, + // log(1.mantissaApprox) in table, + // epsilon = (1.mantissa/1.mantissaApprox)-1 + // + // To avoid bad relative error for small results, + // values close to 1.0 are treated aside, with the formula: + // log(x) = z*(2+z^2*((2.0/3)+z^2*((2.0/5)+z^2*((2.0/7)+...)))) + // with z=(x-1)/(x+1) + + double h; + if (value > 0.95) { + if (value < 1.14) { + double z = (value - 1.0) / (value + 1.0); + double z2 = z * z; + return z + * (2 + z2 + * ((2.0 / 3) + z2 + * ((2.0 / 5) + z2 + * ((2.0 / 7) + z2 + * ((2.0 / 9) + z2 * ((2.0 / 11))))))); + } + h = 0.0; + } else if (value < DOUBLE_MIN_NORMAL) { + // Ensuring value is normal. + value *= TWO_POW_52; + // log(x) + // = log(x*2^52)-ln(2^52) + // = log(x*2^52)-52*ln(2) + h = -52 * LOG_2; + } else { + h = 0.0; + } + + int valueBitsHi = (int) (Double.doubleToRawLongBits(value) >> 32); + int valueExp = (valueBitsHi >> 20) - MAX_DOUBLE_EXPONENT; + // Getting the first LOG_BITS bits of the mantissa. 
+ int xIndex = ((valueBitsHi << 12) >>> (32 - LOG_BITS)); + + // 1.mantissa/1.mantissaApprox - 1 + double z = (value * twoPowNormalOrSubnormal(-valueExp)) * MyTLog.logXInvTab[xIndex] + - 1; + + z *= (1 - z * ((1.0 / 2) - z * ((1.0 / 3)))); + + return h + valueExp * LOG_2 + (MyTLog.logXLogTab[xIndex] + z); + + } else if (value == 0.0) { + return Double.NEGATIVE_INFINITY; + } else { // value < 0.0, or value is NaN + return Double.NaN; + } + } + + /** + * Much more accurate than log(1+value), for arguments (and results) close to zero. + * + * @param value A double value. + * @return Logarithm (base e) of (1+value). + */ + static double log1p(final double value) { + if (value > -1.0) { + if (value == Double.POSITIVE_INFINITY) { + return Double.POSITIVE_INFINITY; + } + + // ln'(x) = 1/x + // so + // log(x+epsilon) ~= log(x) + epsilon/x + // + // Let u be 1+value rounded: + // 1+value = u+epsilon + // + // log(1+value) + // = log(u+epsilon) + // ~= log(u) + epsilon/u + // We compute log(u) as done in log(double), and then add the corrective term. + + double valuePlusOne = 1.0 + value; + if (valuePlusOne == 1.0) { + return value; + } else if (Math.abs(value) < 0.15) { + double z = value / (value + 2.0); + double z2 = z * z; + return z + * (2 + z2 + * ((2.0 / 3) + z2 + * ((2.0 / 5) + z2 + * ((2.0 / 7) + z2 + * ((2.0 / 9) + z2 * ((2.0 / 11))))))); + } + + int valuePlusOneBitsHi = (int) (Double.doubleToRawLongBits(valuePlusOne) >> 32) & 0x7FFFFFFF; + int valuePlusOneExp = (valuePlusOneBitsHi >> 20) - MAX_DOUBLE_EXPONENT; + // Getting the first LOG_BITS bits of the mantissa. 
+ int xIndex = ((valuePlusOneBitsHi << 12) >>> (32 - LOG_BITS)); + + // 1.mantissa/1.mantissaApprox - 1 + double z = (valuePlusOne * twoPowNormalOrSubnormal(-valuePlusOneExp)) + * MyTLog.logXInvTab[xIndex] - 1; + + z *= (1 - z * ((1.0 / 2) - z * (1.0 / 3))); + + // Adding epsilon/valuePlusOne to z, + // with + // epsilon = value - (valuePlusOne-1) + // (valuePlusOne + epsilon ~= 1+value (not rounded)) + + return valuePlusOneExp * LOG_2 + MyTLog.logXLogTab[xIndex] + + (z + (value - (valuePlusOne - 1)) / valuePlusOne); + } else if (value == -1.0) { + return Double.NEGATIVE_INFINITY; + } else { // value < -1.0, or value is NaN + return Double.NaN; + } + } + + /** + * @param power Must be in normal or subnormal values range. + */ + private static double twoPowNormalOrSubnormal(final int power) { + if (power <= -MAX_DOUBLE_EXPONENT) { // Not normal. + return Double.longBitsToDouble(0x0008000000000000L >> (-(power + MAX_DOUBLE_EXPONENT))); + } else { // Normal. + return Double.longBitsToDouble(((long) (power + MAX_DOUBLE_EXPONENT)) << 52); + } + } + + /** + * Returns the exact result, provided it's in double range, i.e. if power is in + * [-1074,1023]. + * + * @param power An int power. + * @return 2^power as a double, or +-Infinity in case of overflow. + */ + private static double twoPow(final int power) { + if (power <= -MAX_DOUBLE_EXPONENT) { // Not normal. + if (power >= MIN_DOUBLE_EXPONENT) { // Subnormal. + return Double.longBitsToDouble(0x0008000000000000L >> (-(power + MAX_DOUBLE_EXPONENT))); + } else { // Underflow. + return 0.0; + } + } else if (power > MAX_DOUBLE_EXPONENT) { // Overflow. + return Double.POSITIVE_INFINITY; + } else { // Normal. 
+ return Double.longBitsToDouble(((long) (power + MAX_DOUBLE_EXPONENT)) << 52); + } + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/main/java/hivemall/utils/math/MathUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java index 71d4c29..dd9e892 100644 --- a/core/src/main/java/hivemall/utils/math/MathUtils.java +++ b/core/src/main/java/hivemall/utils/math/MathUtils.java @@ -16,22 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -// -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright ownership. -// The ASF licenses this file to You under the Apache License, Version 2.0 -// (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// package hivemall.utils.math; import java.util.Random; @@ -434,8 +418,10 @@ public final class MathUtils { if (sum == 0.d) { return new float[size]; } + // floating point multiplication is faster than division + final double multiplier = 1.d / sum; for (int i = 0; i < size; i++) { - arr[i] /= sum; + arr[i] *= multiplier; } return arr; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java index 6d053de..479f5cf 100644 --- a/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java +++ b/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java @@ -54,6 +54,8 @@ public class FactorizationMachineUDTFTest { udtf.initialize(argOIs); FactorizationMachineModel model = udtf.initModel(udtf._params); + + Assert.assertFalse(udtf._params.l2norm); Assert.assertTrue("Actual class: " + model.getClass().getName(), model instanceof FMStringFeatureMapModel); @@ -85,6 +87,21 @@ public class FactorizationMachineUDTFTest { Assert.assertTrue("Loss was greater than 0.1: " + loss, loss <= 0.1); } + @Test + public void testEnableL2Norm() throws HiveException, IOException { + FactorizationMachineUDTF udtf = new FactorizationMachineUDTF(); + ObjectInspector[] argOIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector), + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-factors 5 -min 1 -max 5 -iters 1 -init_v gaussian -eta0 0.01 -seed 31 -l2norm")}; + + udtf.initialize(argOIs); + udtf.initModel(udtf._params); + Assert.assertTrue(udtf._params.l2norm); + } + @Nonnull 
private static BufferedReader readFile(@Nonnull String fileName) throws IOException { InputStream is = FactorizationMachineUDTFTest.class.getResourceAsStream(fileName); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/test/java/hivemall/fm/FeatureTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/fm/FeatureTest.java b/core/src/test/java/hivemall/fm/FeatureTest.java index 911a4a5..24ef0d8 100644 --- a/core/src/test/java/hivemall/fm/FeatureTest.java +++ b/core/src/test/java/hivemall/fm/FeatureTest.java @@ -18,6 +18,8 @@ */ package hivemall.fm; +import hivemall.utils.hashing.MurmurHash3; + import org.apache.hadoop.hive.ql.metadata.HiveException; import org.junit.Assert; import org.junit.Test; @@ -41,17 +43,13 @@ public class FeatureTest { Assert.assertEquals(0.3651d, f1.getValue(), 0.d); } - @Test - public void testParseQuantitativeFFMFeature() throws HiveException { - IntFeature f1 = Feature.parseFFMFeature("163:0.3651"); - Assert.assertEquals(163, f1.getField()); - Assert.assertEquals(163, f1.getFeatureIndex()); - Assert.assertEquals("163", f1.getFeature()); - Assert.assertEquals(0.3651d, f1.getValue(), 0.d); + @Test(expected = HiveException.class) + public void testParseQuantitativeFFMFeatureFails1() throws HiveException { + Feature.parseFFMFeature("163:0.3651"); } @Test(expected = HiveException.class) - public void testParseQuantitativeFFMFeatureFails() throws HiveException { + public void testParseQuantitativeFFMFeatureFails2() throws HiveException { Feature.parseFFMFeature("1163:0.3651"); } @@ -63,15 +61,19 @@ public class FeatureTest { Assert.assertEquals(0.3652d, probe.getValue(), 0.d); } + @Test public void testParseFFMFeatureProbe() throws HiveException { - IntFeature probe = Feature.parseFFMFeature("dummyFeature:dummyField:-1"); - Feature.parseFFMFeature("2:1163:0.3651", probe); + IntFeature probe = Feature.parseFFMFeature("dummyField:dummyFeature:-1"); + 
Assert.assertEquals(MurmurHash3.murmurhash3("dummyFeature", Feature.DEFAULT_NUM_FEATURES) + + Feature.DEFAULT_NUM_FIELDS, probe.getFeatureIndex()); + Feature.parseFFMFeature("2:1163:0.3651", probe, -1, Feature.DEFAULT_NUM_FIELDS); Assert.assertEquals(2, probe.getField()); Assert.assertEquals(1163, probe.getFeatureIndex()); Assert.assertEquals("1163", probe.getFeature()); Assert.assertEquals(0.3651d, probe.getValue(), 0.d); } + @Test public void testParseIntFeature() throws HiveException { Feature f = Feature.parseFeature("1163:0.3651", true); Assert.assertTrue(f instanceof IntFeature); @@ -90,4 +92,57 @@ public class FeatureTest { Feature.parseFFMFeature("0:0.3652"); } + @Test + public void testFFMFeatureL2Normalization() throws HiveException { + Feature[] features = new Feature[9]; + // (0, 0, 1, 1, 0, 1, 0, 1, 0) + features[0] = Feature.parseFFMFeature("11:1:0", -1); + features[1] = Feature.parseFFMFeature("22:2:0", -1); + features[2] = Feature.parseFFMFeature("33:3:1", -1); + features[3] = Feature.parseFFMFeature("44:4:1", -1); + features[4] = Feature.parseFFMFeature("55:5:0", -1); + features[5] = Feature.parseFFMFeature("66:6:1", -1); + features[6] = Feature.parseFFMFeature("77:7:0", -1); + features[7] = Feature.parseFFMFeature("88:8:1", -1); + features[8] = Feature.parseFFMFeature("99:9:0", -1); + Assert.assertEquals(features[0].getField(), 11); + Assert.assertEquals(features[1].getField(), 22); + Assert.assertEquals(features[2].getField(), 33); + Assert.assertEquals(features[3].getField(), 44); + Assert.assertEquals(features[4].getField(), 55); + Assert.assertEquals(features[5].getField(), 66); + Assert.assertEquals(features[6].getField(), 77); + Assert.assertEquals(features[7].getField(), 88); + Assert.assertEquals(features[8].getField(), 99); + Assert.assertEquals(features[0].getFeatureIndex(), 1); + Assert.assertEquals(features[1].getFeatureIndex(), 2); + Assert.assertEquals(features[2].getFeatureIndex(), 3); + 
Assert.assertEquals(features[3].getFeatureIndex(), 4); + Assert.assertEquals(features[4].getFeatureIndex(), 5); + Assert.assertEquals(features[5].getFeatureIndex(), 6); + Assert.assertEquals(features[6].getFeatureIndex(), 7); + Assert.assertEquals(features[7].getFeatureIndex(), 8); + Assert.assertEquals(features[8].getFeatureIndex(), 9); + Assert.assertEquals(0.d, features[0].value, 1E-15); + Assert.assertEquals(0.d, features[1].value, 1E-15); + Assert.assertEquals(1.d, features[2].value, 1E-15); + Assert.assertEquals(1.d, features[3].value, 1E-15); + Assert.assertEquals(0.d, features[4].value, 1E-15); + Assert.assertEquals(1.d, features[5].value, 1E-15); + Assert.assertEquals(0.d, features[6].value, 1E-15); + Assert.assertEquals(1.d, features[7].value, 1E-15); + Assert.assertEquals(0.d, features[8].value, 1E-15); + Feature.l2normalize(features); + // (0, 0, 0.5, 0.5, 0, 0.5, 0, 0.5, 0) + Assert.assertEquals(0.d, features[0].value, 1E-15); + Assert.assertEquals(0.d, features[1].value, 1E-15); + Assert.assertEquals(0.5d, features[2].value, 1E-15); + Assert.assertEquals(0.5d, features[3].value, 1E-15); + Assert.assertEquals(0.d, features[4].value, 1E-15); + Assert.assertEquals(0.5d, features[5].value, 1E-15); + Assert.assertEquals(0.d, features[6].value, 1E-15); + Assert.assertEquals(0.5d, features[7].value, 1E-15); + Assert.assertEquals(0.d, features[8].value, 1E-15); + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java index 5b54b1e..585392b 100644 --- a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java +++ b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java @@ -18,6 +18,8 @@ */ 
package hivemall.fm; +import hivemall.utils.lang.NumberUtils; + import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; @@ -37,8 +39,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.junit.Assert; import org.junit.Test; -import hivemall.utils.lang.NumberUtils; - public class FieldAwareFactorizationMachineUDTFTest { private static final boolean DEBUG = false; @@ -85,7 +85,14 @@ public class FieldAwareFactorizationMachineUDTFTest { public void testSample() throws IOException, HiveException { run("[Sample.ffm] default option", "https://github.com/myui/ml_dataset/raw/master/ffm/sample.ffm.gz", - "-classification -factors 2 -iters 10 -feature_hashing 20 -seed 43", 0.1f); + "-classification -factors 2 -iters 10 -feature_hashing 20 -seed 43", 0.01f); + } + + // TODO @Test + public void testSampleEnableNorm() throws IOException, HiveException { + run("[Sample.ffm] default option", + "https://github.com/myui/ml_dataset/raw/master/ffm/sample.ffm.gz", + "-classification -factors 2 -iters 10 -feature_hashing 20 -seed 43 -enable_norm", 0.01f); } private static void run(String testName, String testFile, String testOptions, @@ -104,7 +111,7 @@ public class FieldAwareFactorizationMachineUDTFTest { Assert.assertTrue("Actual class: " + model.getClass().getName(), model instanceof FFMStringFeatureMapModel); - + int lines = 0; BufferedReader data = readFile(testFile); while (true) { //gather features in current line @@ -112,6 +119,7 @@ public class FieldAwareFactorizationMachineUDTFTest { if (input == null) { break; } + lines++; String[] featureStrings = input.split(" "); double y = Double.parseDouble(featureStrings[0]); @@ -140,7 +148,7 @@ public class FieldAwareFactorizationMachineUDTFTest { println("model size=" + udtf._model.getSize()); - double avgLoss = udtf._cvState.getCumulativeLoss() / udtf._t; + double avgLoss = udtf._cvState.getAverageLoss(lines); Assert.assertTrue("Last loss was greater than 
expected: " + avgLoss, avgLoss < lossThreshold); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashTableTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashTableTest.java deleted file mode 100644 index 53814ac..0000000 --- a/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashTableTest.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.utils.collections.maps; - -import hivemall.utils.collections.maps.Int2FloatOpenHashTable; - -import org.junit.Assert; -import org.junit.Test; - -public class Int2FloatOpenHashTableTest { - - @Test - public void testSize() { - Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384); - map.put(1, 3.f); - Assert.assertEquals(3.f, map.get(1), 0.d); - map.put(1, 5.f); - Assert.assertEquals(5.f, map.get(1), 0.d); - Assert.assertEquals(1, map.size()); - } - - @Test - public void testDefaultReturnValue() { - Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384); - Assert.assertEquals(0, map.size()); - Assert.assertEquals(-1.f, map.get(1), 0.d); - float ret = Float.MIN_VALUE; - map.defaultReturnValue(ret); - Assert.assertEquals(ret, map.get(1), 0.d); - } - - @Test - public void testPutAndGet() { - Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(16384); - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(-1.f, map.put(i, Float.valueOf(i + 0.1f)), 0.d); - } - Assert.assertEquals(numEntries, map.size()); - for (int i = 0; i < numEntries; i++) { - Float v = map.get(i); - Assert.assertEquals(i + 0.1f, v.floatValue(), 0.d); - } - } - - @Test - public void testIterator() { - Int2FloatOpenHashTable map = new Int2FloatOpenHashTable(1000); - Int2FloatOpenHashTable.IMapIterator itor = map.entries(); - Assert.assertFalse(itor.hasNext()); - - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(-1.f, map.put(i, Float.valueOf(i + 0.1f)), 0.d); - } - Assert.assertEquals(numEntries, map.size()); - - itor = map.entries(); - Assert.assertTrue(itor.hasNext()); - while (itor.hasNext()) { - Assert.assertFalse(itor.next() == -1); - int k = itor.getKey(); - Float v = itor.getValue(); - Assert.assertEquals(k + 0.1f, v.floatValue(), 0.d); - } - Assert.assertEquals(-1, itor.next()); - } - - @Test - public void testIterator2() { - Int2FloatOpenHashTable map = new 
Int2FloatOpenHashTable(100); - map.put(33, 3.16f); - - Int2FloatOpenHashTable.IMapIterator itor = map.entries(); - Assert.assertTrue(itor.hasNext()); - Assert.assertNotEquals(-1, itor.next()); - Assert.assertEquals(33, itor.getKey()); - Assert.assertEquals(3.16f, itor.getValue(), 0.d); - Assert.assertEquals(-1, itor.next()); - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashMapTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashMapTest.java deleted file mode 100644 index ee36a83..0000000 --- a/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashMapTest.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.utils.collections.maps; - -import org.junit.Assert; -import org.junit.Test; - -public class Int2LongOpenHashMapTest { - - @Test - public void testSize() { - Int2LongOpenHashMap map = new Int2LongOpenHashMap(16384); - map.put(1, 3L); - Assert.assertEquals(3L, map.get(1)); - map.put(1, 5L); - Assert.assertEquals(5L, map.get(1)); - Assert.assertEquals(1, map.size()); - } - - @Test - public void testDefaultReturnValue() { - Int2LongOpenHashMap map = new Int2LongOpenHashMap(16384); - Assert.assertEquals(0, map.size()); - Assert.assertEquals(0L, map.get(1)); - Assert.assertEquals(Long.MIN_VALUE, map.get(1, Long.MIN_VALUE)); - } - - @Test - public void testPutAndGet() { - Int2LongOpenHashMap map = new Int2LongOpenHashMap(16384); - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(0L, map.put(i, i)); - Assert.assertEquals(0L, map.put(-i, -i)); - } - Assert.assertEquals(numEntries * 2 - 1, map.size()); - for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(i, map.get(i)); - Assert.assertEquals(-i, map.get(-i)); - } - } - - @Test - public void testPutRemoveGet() { - Int2LongOpenHashMap map = new Int2LongOpenHashMap(16384); - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(0L, map.put(i, i)); - Assert.assertEquals(0L, map.put(-i, -i)); - if (i % 2 == 0) { - Assert.assertEquals(i, map.remove(i, -1)); - } else { - Assert.assertEquals(i, map.put(i, i)); - } - } - Assert.assertEquals(numEntries + (numEntries / 2) - 1, map.size()); - for (int i = 0; i < numEntries; i++) { - if (i % 2 == 0) { - Assert.assertFalse(map.containsKey(i)); - } else { - Assert.assertEquals(i, map.get(i)); - } - Assert.assertEquals(-i, map.get(-i)); - } - } - - @Test - public void testIterator() { - Int2LongOpenHashMap map = new Int2LongOpenHashMap(1000); - Int2LongOpenHashMap.MapIterator itor = map.entries(); - Assert.assertFalse(itor.hasNext()); - - final int numEntries = 1000000; - 
for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(0L, map.put(i, i)); - Assert.assertEquals(0L, map.put(-i, -i)); - } - Assert.assertEquals(numEntries * 2 - 1, map.size()); - - itor = map.entries(); - Assert.assertTrue(itor.hasNext()); - while (itor.hasNext()) { - Assert.assertTrue(itor.next()); - int k = itor.getKey(); - long v = itor.getValue(); - Assert.assertEquals(k, v); - } - Assert.assertFalse(itor.next()); - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashTableTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashTableTest.java deleted file mode 100644 index c2ce132..0000000 --- a/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashTableTest.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.utils.collections.maps; - -import hivemall.utils.lang.ObjectUtils; - -import java.io.IOException; - -import org.junit.Assert; -import org.junit.Test; - -public class Int2LongOpenHashTableTest { - - @Test - public void testSize() { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); - map.put(1, 3L); - Assert.assertEquals(3L, map.get(1)); - map.put(1, 5L); - Assert.assertEquals(5L, map.get(1)); - Assert.assertEquals(1, map.size()); - } - - @Test - public void testDefaultReturnValue() { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); - Assert.assertEquals(0, map.size()); - Assert.assertEquals(-1L, map.get(1)); - long ret = Long.MIN_VALUE; - map.defaultReturnValue(ret); - Assert.assertEquals(ret, map.get(1)); - } - - @Test - public void testPutAndGet() { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(-1L, map.put(i, i)); - } - Assert.assertEquals(numEntries, map.size()); - for (int i = 0; i < numEntries; i++) { - long v = map.get(i); - Assert.assertEquals(i, v); - } - } - - @Test - public void testPutRemoveGet() { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); - map.defaultReturnValue(0L); - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(0L, map.put(i, i)); - Assert.assertEquals(0L, map.put(-i, -i)); - if (i % 2 == 0) { - Assert.assertEquals(i, map.remove(i)); - } else { - Assert.assertEquals(i, map.put(i, i)); - } - } - Assert.assertEquals(numEntries + (numEntries / 2) - 1, map.size()); - for (int i = 0; i < numEntries; i++) { - if (i % 2 == 0) { - Assert.assertFalse(map.containsKey(i)); - } else { - Assert.assertEquals(i, map.get(i)); - } - Assert.assertEquals(-i, map.get(-i)); - } - } - - @Test - public void testSerde() throws IOException, ClassNotFoundException { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); - final int 
numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(-1L, map.put(i, i)); - } - - byte[] b = ObjectUtils.toCompressedBytes(map); - map = new Int2LongOpenHashTable(16384); - ObjectUtils.readCompressedObject(b, map); - - Assert.assertEquals(numEntries, map.size()); - for (int i = 0; i < numEntries; i++) { - long v = map.get(i); - Assert.assertEquals(i, v); - } - } - - @Test - public void testIterator() { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(1000); - Int2LongOpenHashTable.MapIterator itor = map.entries(); - Assert.assertFalse(itor.hasNext()); - - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(-1L, map.put(i, i)); - } - Assert.assertEquals(numEntries, map.size()); - - itor = map.entries(); - Assert.assertTrue(itor.hasNext()); - while (itor.hasNext()) { - Assert.assertFalse(itor.next() == -1); - int k = itor.getKey(); - long v = itor.getValue(); - Assert.assertEquals(k, v); - } - Assert.assertEquals(-1, itor.next()); - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashTableTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashTableTest.java deleted file mode 100644 index 46a3938..0000000 --- a/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashTableTest.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.utils.collections.maps; - -import hivemall.utils.collections.maps.IntOpenHashTable; - -import org.junit.Assert; -import org.junit.Test; - -public class IntOpenHashTableTest { - - @Test - public void testSize() { - IntOpenHashTable<Float> map = new IntOpenHashTable<Float>(16384); - map.put(1, Float.valueOf(3.f)); - Assert.assertEquals(Float.valueOf(3.f), map.get(1)); - map.put(1, Float.valueOf(5.f)); - Assert.assertEquals(Float.valueOf(5.f), map.get(1)); - Assert.assertEquals(1, map.size()); - } - - @Test - public void testPutAndGet() { - IntOpenHashTable<Integer> map = new IntOpenHashTable<Integer>(16384); - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertNull(map.put(i, i)); - } - Assert.assertEquals(numEntries, map.size()); - for (int i = 0; i < numEntries; i++) { - Integer v = map.get(i); - Assert.assertEquals(i, v.intValue()); - } - } - - @Test - public void testIterator() { - IntOpenHashTable<Integer> map = new IntOpenHashTable<Integer>(1000); - IntOpenHashTable.IMapIterator<Integer> itor = map.entries(); - Assert.assertFalse(itor.hasNext()); - - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertNull(map.put(i, i)); - } - Assert.assertEquals(numEntries, map.size()); - - itor = map.entries(); - Assert.assertTrue(itor.hasNext()); - while (itor.hasNext()) { - Assert.assertFalse(itor.next() == -1); - int k = itor.getKey(); - Integer v = itor.getValue(); - Assert.assertEquals(k, v.intValue()); - } - Assert.assertEquals(-1, itor.next()); - } - -} 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/test/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTableTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTableTest.java new file mode 100644 index 0000000..a47c373 --- /dev/null +++ b/core/src/test/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTableTest.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.utils.collections.maps; + +import hivemall.utils.lang.ObjectUtils; + +import java.io.IOException; + +import org.junit.Assert; +import org.junit.Test; + +public class Long2DoubleOpenHashTableTest { + + @Test + public void testSize() { + Long2DoubleOpenHashTable map = new Long2DoubleOpenHashTable(16384); + map.put(1L, 3); + Assert.assertEquals(3, map.get(1L), 1E-15); + map.put(1L, 5); + Assert.assertEquals(5, map.get(1L), 1E-15); + Assert.assertEquals(1, map.size()); + } + + @Test + public void testDefaultReturnValue() { + Long2DoubleOpenHashTable map = new Long2DoubleOpenHashTable(16384); + map.defaultReturnValue(-1); + Assert.assertEquals(0, map.size()); + Assert.assertEquals(-1, map.get(1L), 1E-15); + int ret = Integer.MAX_VALUE; + map.defaultReturnValue(ret); + Assert.assertEquals(ret, map.get(1L), 1E-15); + } + + @Test + public void testPutAndGet() { + Long2DoubleOpenHashTable map = new Long2DoubleOpenHashTable(16384); + map.defaultReturnValue(-1); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1L, map.put(i, i), 1E-15); + } + Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(i, map.get(i), 1E-15); + } + + map.clear(); + int i = 0; + for (long j = 1L + Integer.MAX_VALUE; i < 10000; j += 99L, i++) { + map.put(j, i); + } + Assert.assertEquals(i, map.size()); + i = 0; + for (long j = 1L + Integer.MAX_VALUE; i < 10000; j += 99L, i++) { + Assert.assertEquals(i, map.get(j), 1E-15); + } + } + + @Test + public void testSerde() throws IOException, ClassNotFoundException { + Long2DoubleOpenHashTable map = new Long2DoubleOpenHashTable(16384); + map.defaultReturnValue(-1); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1, map.put(i, i), 1E-15); + } + + byte[] b = ObjectUtils.toCompressedBytes(map); + map = new Long2DoubleOpenHashTable(16384); + ObjectUtils.readCompressedObject(b, map); + + 
Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(i, map.get(i), 1E-15); + } + } + + @Test + public void testIterator() { + Long2DoubleOpenHashTable map = new Long2DoubleOpenHashTable(1000); + map.defaultReturnValue(-1); + Long2DoubleOpenHashTable.IMapIterator itor = map.entries(); + Assert.assertFalse(itor.hasNext()); + + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1, map.put(i, i), 1E-15); + } + Assert.assertEquals(numEntries, map.size()); + + itor = map.entries(); + Assert.assertTrue(itor.hasNext()); + while (itor.hasNext()) { + Assert.assertFalse(itor.next() == -1); + long k = itor.getKey(); + double v = itor.getValue(); + Assert.assertEquals(k, v, 1E-15); + } + Assert.assertEquals(-1, itor.next()); + } + + @Test + public void testPutRemoveGet() { + Long2DoubleOpenHashTable map = new Long2DoubleOpenHashTable(16384); + map.defaultReturnValue(-1); + map.defaultReturnValue(-2); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-2, map.put(i, i), 1E-15); + } + Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(i, map.remove(i), 1E-15); + } + Assert.assertEquals(0, map.size()); + map.defaultReturnValue(-1); + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1, map.get(i), 1E-15); + } + map.put(1, Integer.MAX_VALUE); + Assert.assertEquals(Integer.MAX_VALUE, map.get(1), 1E-15); + } + + @Test + public void testPutRemoveGet2() { + Long2DoubleOpenHashTable map = new Long2DoubleOpenHashTable(16384); + map.defaultReturnValue(-1); + map.defaultReturnValue(-2); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-2, map.put(i, i), 1E-15); + } + Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(i, map.remove(i), 1E-15); + } + Assert.assertEquals(0, 
map.size()); + map.defaultReturnValue(-1); + for (int i = numEntries, len = numEntries + (numEntries / 2); i < len; i++) { + Assert.assertEquals(-1, map.put(i, i), 1E-15); + } + Assert.assertEquals(numEntries / 2, map.size()); + for (int i = numEntries, len = numEntries + (numEntries / 2); i < len; i++) { + Assert.assertEquals(i, map.get(i), 1E-15); + } + for (int i = numEntries + (numEntries / 2), j = 0; j < numEntries; i++, j++) { + Assert.assertEquals(-1, map.put(i, i), 1E-15); + } + for (int i = numEntries + (numEntries / 2), j = 0; j < numEntries; i++, j++) { + Assert.assertEquals(i, map.get(i), 1E-15); + } + } + + @Test + public void testShrink() { + Long2DoubleOpenHashTable map = new Long2DoubleOpenHashTable(16384); + map.defaultReturnValue(-1); + final int numEntries = 65536; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1, map.put(i, i), 1E-15); + Assert.assertEquals(i, map.remove(i), 1E-15); + } + Assert.assertEquals(0, map.size()); + map.defaultReturnValue(-2); + for (int i = 0, len = 2 * numEntries; i < len; i++) { + Assert.assertEquals(-2, map.put(i, i), 1E-15); + } + Assert.assertEquals(numEntries * 2, map.size()); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/test/java/hivemall/utils/collections/maps/Long2FloatOpenHashTableTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/maps/Long2FloatOpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/maps/Long2FloatOpenHashTableTest.java new file mode 100644 index 0000000..c1f74a6 --- /dev/null +++ b/core/src/test/java/hivemall/utils/collections/maps/Long2FloatOpenHashTableTest.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.collections.maps; + +import hivemall.utils.lang.ObjectUtils; + +import java.io.IOException; + +import org.junit.Assert; +import org.junit.Test; + +public class Long2FloatOpenHashTableTest { + + @Test + public void testSize() { + Long2FloatOpenHashTable map = new Long2FloatOpenHashTable(16384); + map.put(1L, 3); + Assert.assertEquals(3, map.get(1L), 1E-6f); + map.put(1L, 5); + Assert.assertEquals(5, map.get(1L), 1E-6f); + Assert.assertEquals(1, map.size()); + } + + @Test + public void testDefaultReturnValue() { + Long2FloatOpenHashTable map = new Long2FloatOpenHashTable(16384); + map.defaultReturnValue(-1); + Assert.assertEquals(0, map.size()); + Assert.assertEquals(-1, map.get(1L), 1E-6f); + int ret = Integer.MAX_VALUE; + map.defaultReturnValue(ret); + Assert.assertEquals(ret, map.get(1L), 1E-6f); + } + + @Test + public void testPutAndGet() { + Long2FloatOpenHashTable map = new Long2FloatOpenHashTable(16384); + map.defaultReturnValue(-1); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1L, map.put(i, i), 1E-6f); + } + Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(i, map.get(i), 1E-6f); + } + + map.clear(); + int i = 0; + for (long j = 1L + Integer.MAX_VALUE; i < 10000; j += 99L, i++) { + map.put(j, i); + } + Assert.assertEquals(i, map.size()); + i 
= 0; + for (long j = 1L + Integer.MAX_VALUE; i < 10000; j += 99L, i++) { + Assert.assertEquals(i, map.get(j), 1E-6f); + } + } + + @Test + public void testSerde() throws IOException, ClassNotFoundException { + Long2FloatOpenHashTable map = new Long2FloatOpenHashTable(16384); + map.defaultReturnValue(-1); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1, map.put(i, i), 1E-6f); + } + + byte[] b = ObjectUtils.toCompressedBytes(map); + map = new Long2FloatOpenHashTable(16384); + ObjectUtils.readCompressedObject(b, map); + + Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(i, map.get(i), 1E-6f); + } + } + + @Test + public void testIterator() { + Long2FloatOpenHashTable map = new Long2FloatOpenHashTable(1000); + map.defaultReturnValue(-1); + Long2FloatOpenHashTable.IMapIterator itor = map.entries(); + Assert.assertFalse(itor.hasNext()); + + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1, map.put(i, i), 1E-6f); + } + Assert.assertEquals(numEntries, map.size()); + + itor = map.entries(); + Assert.assertTrue(itor.hasNext()); + while (itor.hasNext()) { + Assert.assertFalse(itor.next() == -1); + long k = itor.getKey(); + float v = itor.getValue(); + Assert.assertEquals(k, v, 1E-6f); + } + Assert.assertEquals(-1, itor.next()); + } + + @Test + public void testPutRemoveGet() { + Long2FloatOpenHashTable map = new Long2FloatOpenHashTable(16384); + map.defaultReturnValue(-1); + map.defaultReturnValue(-2); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-2, map.put(i, i), 1E-6f); + } + Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(i, map.remove(i), 1E-6f); + } + Assert.assertEquals(0, map.size()); + map.defaultReturnValue(-1); + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1, map.get(i), 1E-6f); + } 
+ map.put(1, Integer.MAX_VALUE); + Assert.assertEquals(Integer.MAX_VALUE, map.get(1), 1E-6f); + } + + @Test + public void testPutRemoveGet2() { + Long2FloatOpenHashTable map = new Long2FloatOpenHashTable(16384); + map.defaultReturnValue(-1); + map.defaultReturnValue(-2); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-2, map.put(i, i), 1E-6f); + } + Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(i, map.remove(i), 1E-6f); + } + Assert.assertEquals(0, map.size()); + map.defaultReturnValue(-1); + for (int i = numEntries, len = numEntries + (numEntries / 2); i < len; i++) { + Assert.assertEquals(-1, map.put(i, i), 1E-6f); + } + Assert.assertEquals(numEntries / 2, map.size()); + for (int i = numEntries, len = numEntries + (numEntries / 2); i < len; i++) { + Assert.assertEquals(i, map.get(i), 1E-6f); + } + for (int i = numEntries + (numEntries / 2), j = 0; j < numEntries; i++, j++) { + Assert.assertEquals(-1, map.put(i, i), 1E-6f); + } + for (int i = numEntries + (numEntries / 2), j = 0; j < numEntries; i++, j++) { + Assert.assertEquals(i, map.get(i), 1E-6f); + } + } + + @Test + public void testShrink() { + Long2FloatOpenHashTable map = new Long2FloatOpenHashTable(16384); + map.defaultReturnValue(-1); + final int numEntries = 65536; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1, map.put(i, i), 1E-6f); + Assert.assertEquals(i, map.remove(i), 1E-6f); + } + Assert.assertEquals(0, map.size()); + map.defaultReturnValue(-2); + for (int i = 0, len = 2 * numEntries; i < len; i++) { + Assert.assertEquals(-2, map.put(i, i), 1E-6f); + } + Assert.assertEquals(numEntries * 2, map.size()); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/test/java/hivemall/utils/collections/maps/Long2IntOpenHashTableTest.java ---------------------------------------------------------------------- diff --git 
a/core/src/test/java/hivemall/utils/collections/maps/Long2IntOpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/maps/Long2IntOpenHashTableTest.java index ca43383..1b1503e 100644 --- a/core/src/test/java/hivemall/utils/collections/maps/Long2IntOpenHashTableTest.java +++ b/core/src/test/java/hivemall/utils/collections/maps/Long2IntOpenHashTableTest.java @@ -18,7 +18,6 @@ */ package hivemall.utils.collections.maps; -import hivemall.utils.collections.maps.Long2IntOpenHashTable; import hivemall.utils.lang.ObjectUtils; import java.io.IOException; @@ -73,6 +72,34 @@ public class Long2IntOpenHashTableTest { } @Test + public void testIncr() { + Long2IntOpenHashTable map = new Long2IntOpenHashTable(16384); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1L, map.put(i, i)); + if (i % 2 == 0) { + Assert.assertEquals(i, map.remove(i)); + } + } + Assert.assertEquals(numEntries / 2, map.size()); + for (int i = 0; i < numEntries; i++) { + if (i % 2 == 0) { + Assert.assertEquals(-1, map.incr(i, 10)); + } else { + Assert.assertEquals(i, map.incr(i, 10)); + } + } + Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + if (i % 2 == 0) { + Assert.assertEquals(10, map.get(i)); + } else { + Assert.assertEquals(i + 10, map.get(i)); + } + } + } + + @Test public void testSerde() throws IOException, ClassNotFoundException { Long2IntOpenHashTable map = new Long2IntOpenHashTable(16384); final int numEntries = 1000000; @@ -112,4 +139,70 @@ public class Long2IntOpenHashTableTest { } Assert.assertEquals(-1, itor.next()); } + + @Test + public void testPutRemoveGet() { + Long2IntOpenHashTable map = new Long2IntOpenHashTable(16384); + map.defaultReturnValue(-2); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-2, map.put(i, i)); + } + Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(i, 
map.remove(i)); + } + Assert.assertEquals(0, map.size()); + map.defaultReturnValue(-1); + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1, map.get(i)); + } + map.put(1, Integer.MAX_VALUE); + Assert.assertEquals(Integer.MAX_VALUE, map.get(1)); + } + + @Test + public void testPutRemoveGet2() { + Long2IntOpenHashTable map = new Long2IntOpenHashTable(16384); + map.defaultReturnValue(-2); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-2, map.put(i, i)); + } + Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(i, map.remove(i)); + } + Assert.assertEquals(0, map.size()); + map.defaultReturnValue(-1); + for (int i = numEntries, len = numEntries + (numEntries / 2); i < len; i++) { + Assert.assertEquals(-1, map.put(i, i)); + } + Assert.assertEquals(numEntries / 2, map.size()); + for (int i = numEntries, len = numEntries + (numEntries / 2); i < len; i++) { + Assert.assertEquals(i, map.get(i)); + } + for (int i = numEntries + (numEntries / 2), j = 0; j < numEntries; i++, j++) { + Assert.assertEquals(-1, map.put(i, i)); + } + for (int i = numEntries + (numEntries / 2), j = 0; j < numEntries; i++, j++) { + Assert.assertEquals(i, map.get(i)); + } + } + + @Test + public void testShrink() { + Long2IntOpenHashTable map = new Long2IntOpenHashTable(16384); + final int numEntries = 65536; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1, map.put(i, i)); + Assert.assertEquals(i, map.remove(i)); + } + Assert.assertEquals(0, map.size()); + map.defaultReturnValue(-2); + for (int i = 0, len = 2 * numEntries; i < len; i++) { + Assert.assertEquals(-2, map.put(i, i)); + } + Assert.assertEquals(numEntries * 2, map.size()); + } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/test/java/hivemall/utils/collections/maps/OpenHashMapTest.java ---------------------------------------------------------------------- diff --git 
a/core/src/test/java/hivemall/utils/collections/maps/OpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/maps/OpenHashMapTest.java deleted file mode 100644 index aa48a98..0000000 --- a/core/src/test/java/hivemall/utils/collections/maps/OpenHashMapTest.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.utils.collections.maps; - -import hivemall.utils.collections.IMapIterator; -import hivemall.utils.collections.maps.OpenHashMap; -import hivemall.utils.lang.mutable.MutableInt; - -import java.util.Map; - -import org.junit.Assert; -import org.junit.Test; - -public class OpenHashMapTest { - - @Test - public void testPutAndGet() { - Map<Object, Object> map = new OpenHashMap<Object, Object>(16384); - final int numEntries = 5000000; - for (int i = 0; i < numEntries; i++) { - map.put(Integer.toString(i), i); - } - Assert.assertEquals(numEntries, map.size()); - for (int i = 0; i < numEntries; i++) { - Object v = map.get(Integer.toString(i)); - Assert.assertEquals(i, v); - } - map.put(Integer.toString(1), Integer.MAX_VALUE); - Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1))); - Assert.assertEquals(numEntries, map.size()); - } - - @Test - public void testIterator() { - OpenHashMap<String, Integer> map = new OpenHashMap<String, Integer>(1000); - IMapIterator<String, Integer> itor = map.entries(); - Assert.assertFalse(itor.hasNext()); - - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - map.put(Integer.toString(i), i); - } - - itor = map.entries(); - Assert.assertTrue(itor.hasNext()); - while (itor.hasNext()) { - Assert.assertFalse(itor.next() == -1); - String k = itor.getKey(); - Integer v = itor.getValue(); - Assert.assertEquals(Integer.valueOf(k), v); - } - Assert.assertEquals(-1, itor.next()); - } - - @Test - public void testIteratorGetProbe() { - OpenHashMap<String, MutableInt> map = new OpenHashMap<String, MutableInt>(100); - IMapIterator<String, MutableInt> itor = map.entries(); - Assert.assertFalse(itor.hasNext()); - - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - map.put(Integer.toString(i), new MutableInt(i)); - } - - final MutableInt probe = new MutableInt(); - itor = map.entries(); - Assert.assertTrue(itor.hasNext()); - while (itor.hasNext()) { - 
Assert.assertFalse(itor.next() == -1); - String k = itor.getKey(); - itor.getValue(probe); - Assert.assertEquals(Integer.valueOf(k).intValue(), probe.intValue()); - } - Assert.assertEquals(-1, itor.next()); - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/test/java/hivemall/utils/collections/maps/OpenHashTableTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/maps/OpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/maps/OpenHashTableTest.java index 708c164..cf784fb 100644 --- a/core/src/test/java/hivemall/utils/collections/maps/OpenHashTableTest.java +++ b/core/src/test/java/hivemall/utils/collections/maps/OpenHashTableTest.java @@ -19,7 +19,6 @@ package hivemall.utils.collections.maps; import hivemall.utils.collections.IMapIterator; -import hivemall.utils.collections.maps.OpenHashTable; import hivemall.utils.lang.ObjectUtils; import hivemall.utils.lang.mutable.MutableInt; @@ -48,6 +47,67 @@ public class OpenHashTableTest { } @Test + public void testPutRemoveGet() { + OpenHashTable<Object, Object> map = new OpenHashTable<Object, Object>(16384); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertNull(map.put(Integer.toString(i), i)); + } + Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(i, map.remove(Integer.toString(i))); + } + Assert.assertEquals(0, map.size()); + for (int i = 0; i < numEntries; i++) { + Assert.assertNull(map.get(Integer.toString(i))); + } + map.put(Integer.toString(1), Integer.MAX_VALUE); + Assert.assertEquals(Integer.MAX_VALUE, map.get(Integer.toString(1))); + } + + @Test + public void testPutRemoveGet2() { + OpenHashTable<Object, Object> map = new OpenHashTable<Object, Object>(16384); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + 
Assert.assertNull(map.put(Integer.toString(i), i)); + } + Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(i, map.remove(Integer.toString(i))); + } + Assert.assertEquals(0, map.size()); + for (int i = numEntries, len = numEntries + (numEntries / 2); i < len; i++) { + Assert.assertNull(map.put(Integer.toString(i), i)); + } + Assert.assertEquals(numEntries / 2, map.size()); + for (int i = numEntries, len = numEntries + (numEntries / 2); i < len; i++) { + Assert.assertEquals(i, map.get(Integer.toString(i))); + } + for (int i = numEntries + (numEntries / 2), j = 0; j < numEntries; i++, j++) { + Assert.assertNull(map.put(Integer.toString(i), i)); + } + for (int i = numEntries + (numEntries / 2), j = 0; j < numEntries; i++, j++) { + Assert.assertEquals(i, map.get(Integer.toString(i))); + } + } + + @Test + public void testShrink() { + OpenHashTable<Object, Object> map = new OpenHashTable<Object, Object>(16384); + final int numEntries = 65536; + for (int i = 0; i < numEntries; i++) { + Assert.assertNull(map.put(Integer.toString(i), i)); + Assert.assertEquals(i, map.remove(Integer.toString(i))); + } + Assert.assertEquals(0, map.size()); + for (int i = 0, len = 2 * numEntries; i < len; i++) { + Assert.assertNull(map.put(Integer.toString(i), i)); + } + Assert.assertEquals(numEntries * 2, map.size()); + } + + @Test public void testIterator() { OpenHashTable<String, Integer> map = new OpenHashTable<String, Integer>(1000); IMapIterator<String, Integer> itor = map.entries(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/test/java/hivemall/utils/lambda/ThrowingTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/lambda/ThrowingTest.java b/core/src/test/java/hivemall/utils/lambda/ThrowingTest.java new file mode 100644 index 0000000..8eab9f3 --- /dev/null +++ 
b/core/src/test/java/hivemall/utils/lambda/ThrowingTest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.lambda; + +import static hivemall.utils.lambda.Throwing.rethrow; + +import java.io.IOException; +import java.util.Arrays; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +public class ThrowingTest { + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testRethrow() { + thrown.expect(IOException.class); + thrown.expectMessage("i=3"); + + Arrays.asList(1, 2, 3).forEach(rethrow(e -> { + int i = e.intValue(); + if (i == 3) { + throw new IOException("i=" + i); + } + })); + } + + @Test(expected = IOException.class) + public void testSneakyThrow() { + Throwing.sneakyThrow(new IOException()); + } + + @Test + public void testThrowingConsumer() { + thrown.expect(IOException.class); + thrown.expectMessage("i=3"); + + Arrays.asList(1, 2, 3).forEach((ThrowingConsumer<Integer>) e -> { + int i = e.intValue(); + if (i == 3) { + throw new IOException("i=" + i); + } + }); + } + +} 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/core/src/test/java/hivemall/utils/math/FastMathTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/math/FastMathTest.java b/core/src/test/java/hivemall/utils/math/FastMathTest.java new file mode 100644 index 0000000..6892b72 --- /dev/null +++ b/core/src/test/java/hivemall/utils/math/FastMathTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.utils.math; + +import java.util.Random; + +import org.junit.Assert; +import org.junit.Test; + +public class FastMathTest { + private static final boolean DEBUG = false; + + @SuppressWarnings("deprecation") + @Test + public void testFastInverseSquareRootFloat() { + final Random rnd = new Random(43L); + for (int i = 0; i < 100; i++) { + float v = rnd.nextFloat() * (rnd.nextInt(10000) + 1); + Assert.assertEquals(Math.sqrt(v), FastMath.sqrt(v), 1E-5d); + } + } + + @SuppressWarnings("deprecation") + @Test + public void testFastInverseSquareRootDouble() { + final Random rnd = new Random(43L); + for (int i = 0; i < 100; i++) { + double v = rnd.nextDouble() * (rnd.nextInt(10000) + 1); + Assert.assertEquals(Math.sqrt(v), FastMath.sqrt(v), 1E-10d); + } + } + + @Test + public void testSigmoid() { + final Random rnd = new Random(43L); + for (int i = 0; i < 100; i++) { + double v = rnd.nextGaussian() * (rnd.nextInt(10000) + 1); + Assert.assertEquals(Double.toString(v), MathUtils.sigmoid(v), FastMath.sigmoid(v), + 1E-8d); + } + } + + @SuppressWarnings("deprecation") + @Test + public void testSqrtPerformance() { + double result1 = 0d; + // warm up for Math.sqrt + for (double x = 1d; x < 4_000_000d; x += 0.25d) { + result1 += Math.sqrt(x); + } + long startTime = System.nanoTime(); + for (double x = 1d; x < 4_000_000d; x += 0.25d) { + result1 += Math.sqrt(x); + } + long elaspedTimeForSqrt = System.nanoTime() - startTime; + + // warm up for FastMath.sqrt + double result2 = 0d; + for (double x = 1d; x < 4_000_000d; x += 0.25d) { + result2 += FastMath.sqrt(x); + } + startTime = System.nanoTime(); + for (double x = 1d; x < 4_000_000d; x += 0.25D) { + result2 += FastMath.sqrt(x); + } + long elaspedTimeForFastSqrt = System.nanoTime() - startTime; + + if (DEBUG) { + System.out.println("elaspedTimeForFastSqrt=" + elaspedTimeForFastSqrt + + " and elaspedTimeForSqrt=" + elaspedTimeForSqrt); + } + + Assert.assertFalse(result1 == 0d); + Assert.assertFalse(result2 == 0d); 
+ Assert.assertEquals(result1, result2, 1E-5d); + + /* + Assert.assertTrue( + "Expected elaspedTimeForFastSqrt < elaspedTimeForSqrt while elaspedTimeForFastSqrt=" + + elaspedTimeForFastSqrt + " and elaspedTimeForSqrt=" + elaspedTimeForSqrt, + elaspedTimeForFastSqrt < elaspedTimeForSqrt); + */ + } + + public static void main(String[] args) { + FastMathTest test = new FastMathTest(); + for (int i = 1; i <= 10; i++) { + System.out.println("-- " + i); + test.testSqrtPerformance(); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/mixserv/pom.xml ---------------------------------------------------------------------- diff --git a/mixserv/pom.xml b/mixserv/pom.xml index 0e0e83c..b7300cb 100644 --- a/mixserv/pom.xml +++ b/mixserv/pom.xml @@ -151,7 +151,7 @@ <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-shade-plugin</artifactId> - <version>2.3</version> + <version>3.1.0</version> <executions> <execution> <id>jar-with-dependencies</id> http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad15923a/nlp/pom.xml ---------------------------------------------------------------------- diff --git a/nlp/pom.xml b/nlp/pom.xml index 021cd6d..3941872 100644 --- a/nlp/pom.xml +++ b/nlp/pom.xml @@ -160,7 +160,7 @@ <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-shade-plugin</artifactId> - <version>2.3</version> + <version>3.1.0</version> <executions> <execution> <id>jar-with-dependencies</id>
