Close #105: [HIVEMALL-24-2] Make ffm_predict function more scalable by creating its UDAF implementation
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/3410ba64 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/3410ba64 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/3410ba64 Branch: refs/heads/master Commit: 3410ba642665baded960e3c6b8036c58ef1006c2 Parents: 7205de1 Author: Makoto Yui <m...@apache.org> Authored: Mon Sep 11 14:36:12 2017 +0900 Committer: Makoto Yui <m...@apache.org> Committed: Mon Sep 11 14:36:12 2017 +0900 ---------------------------------------------------------------------- .travis.yml | 2 +- .../java/hivemall/common/ConversionState.java | 21 +- core/src/main/java/hivemall/fm/Entry.java | 242 +++++++--- .../java/hivemall/fm/FFMPredictGenericUDAF.java | 262 +++++++++++ .../main/java/hivemall/fm/FFMPredictUDF.java | 187 -------- .../java/hivemall/fm/FFMPredictionModel.java | 349 -------------- .../hivemall/fm/FFMStringFeatureMapModel.java | 315 ++++++++----- .../java/hivemall/fm/FMHyperParameters.java | 74 +-- .../java/hivemall/fm/FMIntFeatureMapModel.java | 6 +- .../java/hivemall/fm/FMPredictGenericUDAF.java | 15 + .../hivemall/fm/FMStringFeatureMapModel.java | 8 +- .../hivemall/fm/FactorizationMachineUDTF.java | 6 +- core/src/main/java/hivemall/fm/Feature.java | 76 ++- .../fm/FieldAwareFactorizationMachineModel.java | 161 ++++++- .../fm/FieldAwareFactorizationMachineUDTF.java | 158 ++++--- core/src/main/java/hivemall/fm/IntFeature.java | 6 +- .../ftvec/pairing/FeaturePairsUDTF.java | 155 ++++-- .../ftvec/ranking/PositiveOnlyFeedback.java | 8 +- .../ftvec/trans/AddFieldIndicesUDF.java | 89 ++++ .../ftvec/trans/CategoricalFeaturesUDF.java | 121 +++-- .../hivemall/ftvec/trans/FFMFeaturesUDF.java | 47 +- .../ftvec/trans/QuantifiedFeaturesUDTF.java | 7 +- .../ftvec/trans/QuantitativeFeaturesUDF.java | 101 +++- .../ftvec/trans/VectorizeFeaturesUDF.java | 110 +++-- .../main/java/hivemall/mf/FactorizedModel.java | 18 +- .../hivemall/model/AbstractPredictionModel.java | 6 +- .../java/hivemall/model/NewSparseModel.java | 2 +- .../main/java/hivemall/model/SparseModel.java | 2 +- .../tools/array/ArrayAvgGenericUDAF.java | 17 +- .../java/hivemall/utils/buffer/HeapBuffer.java | 37 +- .../maps/Int2FloatOpenHashTable.java | 11 +- .../collections/maps/Int2IntOpenHashTable.java | 9 +- .../collections/maps/Int2LongOpenHashMap.java | 346 ++++++++++++++ .../collections/maps/Int2LongOpenHashTable.java | 114 +++-- .../utils/collections/maps/IntOpenHashMap.java | 467 ------------------- .../collections/maps/IntOpenHashTable.java | 142 ++++-- .../maps/Long2DoubleOpenHashTable.java | 9 +- .../maps/Long2FloatOpenHashTable.java | 11 +- .../collections/maps/Long2IntOpenHashTable.java | 9 +- .../utils/collections/maps/OpenHashMap.java | 128 +++-- .../utils/collections/maps/OpenHashTable.java | 12 +- .../java/hivemall/utils/hadoop/HiveUtils.java | 74 ++- .../java/hivemall/utils/hashing/HashUtils.java | 89 ++++ .../java/hivemall/utils/lang/NumberUtils.java | 68 +++ .../java/hivemall/utils/lang/Primitives.java | 24 - .../java/hivemall/utils/math/MathUtils.java | 33 +- .../hivemall/fm/FFMPredictionModelTest.java | 65 --- core/src/test/java/hivemall/fm/FeatureTest.java | 7 +- .../FieldAwareFactorizationMachineUDTFTest.java | 66 +-- .../smile/tools/TreePredictUDFv1Test.java | 1 + .../maps/Int2FloatOpenHashMapTest.java | 98 ---- .../maps/Int2FloatOpenHashTableTest.java | 98 ++++ .../maps/Int2LongOpenHashMapTest.java | 66 +-- .../maps/Int2LongOpenHashTableTest.java | 130 ++++++ .../collections/maps/IntOpenHashMapTest.java | 75 --- .../collections/maps/IntOpenHashTableTest.java | 23 + .../maps/Long2IntOpenHashMapTest.java | 115 ----- .../maps/Long2IntOpenHashTableTest.java | 115 +++++ docs/gitbook/getting_started/input-format.md | 31 +- pom.xml | 18 + resources/ddl/define-all-as-permanent.hive | 5 +- resources/ddl/define-all.hive | 5 +- resources/ddl/define-all.spark | 5 +- resources/ddl/define-udfs.td.hql | 3 + spark/spark-2.0/pom.xml | 10 +- spark/spark-2.1/pom.xml | 10 +- spark/spark-common/pom.xml | 8 +- 67 files changed, 2962 insertions(+), 2146 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/.travis.yml ---------------------------------------------------------------------- diff --git a/.travis.yml b/.travis.yml index 323e36a..96f8f4e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,7 +19,7 @@ env: language: java jdk: - openjdk7 - - oraclejdk7 +# - oraclejdk7 - oraclejdk8 branches: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/common/ConversionState.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/common/ConversionState.java b/core/src/main/java/hivemall/common/ConversionState.java index 7b5923f..435bf75 100644 --- a/core/src/main/java/hivemall/common/ConversionState.java +++ b/core/src/main/java/hivemall/common/ConversionState.java @@ -99,18 +99,25 @@ public final class ConversionState { if (changeRate < convergenceRate) { if (readyToFinishIterations) { // NOTE: never be true at the first iteration where prevLosses == Double.POSITIVE_INFINITY - logger.info("Training converged at " + curIter + "-th iteration. [curLosses=" - + currLosses + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate - + ']'); + if (logger.isInfoEnabled()) { + logger.info("Training converged at " + curIter + "-th iteration. [curLosses=" + + currLosses + ", prevLosses=" + prevLosses + ", changeRate=" + + changeRate + ']'); + } return true; } else { + if (logger.isInfoEnabled()) { + logger.info("Iteration #" + curIter + " [curLosses=" + currLosses + + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate + + ", #trainingExamples=" + observedTrainingExamples + ']'); + } this.readyToFinishIterations = true; } } else { - if (logger.isDebugEnabled()) { - logger.debug("Iteration #" + curIter + " [curLosses=" + currLosses - + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate - + ", #trainingExamples=" + observedTrainingExamples + ']'); + if (logger.isInfoEnabled()) { + logger.info("Iteration #" + curIter + " [curLosses=" + currLosses + ", prevLosses=" + + prevLosses + ", changeRate=" + changeRate + ", #trainingExamples=" + + observedTrainingExamples + ']'); } this.readyToFinishIterations = false; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/Entry.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/Entry.java b/core/src/main/java/hivemall/fm/Entry.java index 1882f85..974ab5b 100644 --- a/core/src/main/java/hivemall/fm/Entry.java +++ b/core/src/main/java/hivemall/fm/Entry.java @@ -20,17 +20,27 @@ package hivemall.fm; import hivemall.utils.buffer.HeapBuffer; import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.Preconditions; import hivemall.utils.lang.SizeOf; +import hivemall.utils.math.MathUtils; +import java.util.Arrays; + +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; class Entry { @Nonnull protected final HeapBuffer _buf; + @Nonnegative protected final int _size; + @Nonnegative protected final int _factors; + // temporary variables used only in training phase + protected int _key; + @Nonnegative protected long _offset; Entry(@Nonnull HeapBuffer buf, int factors) { @@ -39,128 +49,210 @@ class Entry { this._factors = factors; } - Entry(@Nonnull HeapBuffer buf, int factors, long offset) { - this(buf, factors, Entry.sizeOf(factors), offset); + Entry(@Nonnull HeapBuffer buf, int key, @Nonnegative long offset) { + this(buf, 1, key, offset); + } + + Entry(@Nonnull HeapBuffer buf, int factors, int key, @Nonnegative long offset) { + this(buf, factors, Entry.sizeOf(factors), key, offset); } - private Entry(@Nonnull HeapBuffer buf, int factors, int size, long offset) { + private Entry(@Nonnull HeapBuffer buf, int factors, int size, int key, @Nonnegative long offset) { this._buf = buf; this._size = size; this._factors = factors; - setOffset(offset); + this._key = key; + this._offset = offset; } - int getSize() { + final int getSize() { return _size; } - long getOffset() { + final int getKey() { + return _key; + } + + final long getOffset() { return _offset; } - void setOffset(long offset) { + final void setOffset(final long offset) { this._offset = offset; } - float getW() { + final float getW() { return _buf.getFloat(_offset); } - void setW(final float value) { + final void setW(final float value) { _buf.putFloat(_offset, value); } - void getV(@Nonnull final float[] Vf) { - final long offset = _offset + SizeOf.FLOAT; + final void getV(@Nonnull final float[] Vf) { + final long offset = _offset; final int len = Vf.length; - for (int i = 0; i < len; i++) { - Vf[i] = _buf.getFloat(offset + SizeOf.FLOAT * i); + for (int f = 0; f < len; f++) { + long index = offset + SizeOf.FLOAT * f; + Vf[f] = _buf.getFloat(index); } } - void setV(@Nonnull final float[] Vf) { - final long offset = _offset + SizeOf.FLOAT; + final void setV(@Nonnull final float[] Vf) { + final long offset = _offset; final int len = Vf.length; - for (int i = 0; i < len; i++) { - _buf.putFloat(offset + SizeOf.FLOAT * i, Vf[i]); + for (int f = 0; f < len; f++) { + long index = offset + SizeOf.FLOAT * f; + _buf.putFloat(index, Vf[f]); } } - float getV(final int f) { - return _buf.getFloat(_offset + SizeOf.FLOAT + SizeOf.FLOAT * f); + final float getV(final int f) { + long index = _offset + SizeOf.FLOAT * f; + return _buf.getFloat(index); } - void setV(final int f, final float value) { - long index = _offset + SizeOf.FLOAT + SizeOf.FLOAT * f; + final void setV(final int f, final float value) { + long index = _offset + SizeOf.FLOAT * f; _buf.putFloat(index, value); } - double getSumOfSquaredGradientsV() { + double getSumOfSquaredGradients(@Nonnegative int f) { throw new UnsupportedOperationException(); } - void addGradientV(float grad) { + void addGradient(@Nonnegative int f, float grad) { throw new UnsupportedOperationException(); } - float updateZ(float gradW, float alpha) { + final float updateZ(final float gradW, final float alpha) { + float w = getW(); + return updateZ(0, w, gradW, alpha); + } + + float updateZ(@Nonnegative int f, float W, float gradW, float alpha) { throw new UnsupportedOperationException(); } - double updateN(float gradW) { + final double updateN(final float gradW) { + return updateN(0, gradW); + } + + double updateN(@Nonnegative int f, float gradW) { throw new UnsupportedOperationException(); } - static int sizeOf(int factors) { - return SizeOf.FLOAT + SizeOf.FLOAT * factors; + boolean removable() { + if (!isEntryW(_key)) {// entry for V + final long offset = _offset; + for (int f = 0; f < _factors; f++) { + final float Vf = _buf.getFloat(offset + SizeOf.FLOAT * f); + if (!MathUtils.closeToZero(Vf, 1E-9f)) { + return false; + } + } + } + return true; + } + + void clear() {}; + + static int sizeOf(@Nonnegative final int factors) { + Preconditions.checkArgument(factors >= 1, "Factors must be greather than 0: " + factors); + return SizeOf.FLOAT * factors; + } + + static boolean isEntryW(final int i) { + return i < 0; + } + + @Override + public String toString() { + if (Entry.isEntryW(_key)) { + return "W=" + getW(); + } else { + float[] Vf = new float[_factors]; + getV(Vf); + return "V=" + Arrays.toString(Vf); + } } - static class AdaGradEntry extends Entry { + static final class AdaGradEntry extends Entry { final long _gg_offset; - AdaGradEntry(@Nonnull HeapBuffer buf, int factors, long offset) { - super(buf, factors, AdaGradEntry.sizeOf(factors), offset); - this._gg_offset = _offset + SizeOf.FLOAT + SizeOf.FLOAT * _factors; + AdaGradEntry(@Nonnull HeapBuffer buf, int key, @Nonnegative long offset) { + this(buf, 1, key, offset); } - private AdaGradEntry(@Nonnull HeapBuffer buf, int factors, int size, long offset) { - super(buf, factors, size, offset); - this._gg_offset = _offset + SizeOf.FLOAT + SizeOf.FLOAT * _factors; + AdaGradEntry(@Nonnull HeapBuffer buf, @Nonnegative int factors, int key, + @Nonnegative long offset) { + super(buf, factors, AdaGradEntry.sizeOf(factors), key, offset); + this._gg_offset = _offset + Entry.sizeOf(factors); } @Override - double getSumOfSquaredGradientsV() { - return _buf.getDouble(_gg_offset); + double getSumOfSquaredGradients(@Nonnegative final int f) { + Preconditions.checkArgument(f >= 0); + + long offset = _gg_offset + SizeOf.DOUBLE * f; + return _buf.getDouble(offset); } @Override - void addGradientV(float grad) { - double v = _buf.getDouble(_gg_offset); + void addGradient(@Nonnegative final int f, final float grad) { + Preconditions.checkArgument(f >= 0); + + long offset = _gg_offset + SizeOf.DOUBLE * f; + double v = _buf.getDouble(offset); v += grad * grad; - _buf.putDouble(_gg_offset, v); + _buf.putDouble(offset, v); } - static int sizeOf(int factors) { - return Entry.sizeOf(factors) + SizeOf.DOUBLE; + @Override + void clear() { + for (int f = 0; f < _factors; f++) { + long offset = _gg_offset + SizeOf.DOUBLE * f; + _buf.putDouble(offset, 0.d); + } + } + + static int sizeOf(@Nonnegative final int factors) { + return Entry.sizeOf(factors) + SizeOf.DOUBLE * factors; + } + + @Override + public String toString() { + final double[] gg = new double[_factors]; + for (int f = 0; f < _factors; f++) { + gg[f] = getSumOfSquaredGradients(f); + } + return super.toString() + ", gg=" + Arrays.toString(gg); } } - static final class FTRLEntry extends AdaGradEntry { + static final class FTRLEntry extends Entry { final long _z_offset; - FTRLEntry(@Nonnull HeapBuffer buf, int factors, long offset) { - super(buf, factors, FTRLEntry.sizeOf(factors), offset); - this._z_offset = _gg_offset + SizeOf.DOUBLE; + FTRLEntry(@Nonnull HeapBuffer buf, int key, long offset) { + this(buf, 1, key, offset); + } + + FTRLEntry(@Nonnull HeapBuffer buf, @Nonnegative int factors, int key, long offset) { + super(buf, factors, FTRLEntry.sizeOf(factors), key, offset); + this._z_offset = _offset + Entry.sizeOf(factors); } @Override - float updateZ(float gradW, float alpha) { - final float W = getW(); - final float z = getZ(); - final double n = getN(); + float updateZ(final int f, final float W, final float gradW, final float alpha) { + Preconditions.checkArgument(f >= 0); + + final long zOffset = offsetZ(f); + + final float z = _buf.getFloat(zOffset); + final double n = _buf.getFloat(offsetN(f)); // implicit cast to float double gg = gradW * gradW; float sigma = (float) ((Math.sqrt(n + gg) - Math.sqrt(n)) / alpha); @@ -171,44 +263,56 @@ class Entry { + gradW + ", sigma=" + sigma + ", W=" + W + ", n=" + n + ", gg=" + gg + ", alpha=" + alpha); } - setZ(newZ); + _buf.putFloat(zOffset, newZ); return newZ; } - private float getZ() { - return _buf.getFloat(_z_offset); - } - - private void setZ(final float value) { - _buf.putFloat(_z_offset, value); - } - @Override - double updateN(final float gradW) { - final double n = getN(); + double updateN(final int f, final float gradW) { + Preconditions.checkArgument(f >= 0); + + final long nOffset = offsetN(f); + final double n = _buf.getFloat(nOffset); final double newN = n + gradW * gradW; if (!NumberUtils.isFinite(newN)) { throw new IllegalStateException("Got newN " + newN + " where n=" + n + ", gradW=" + gradW); } - setN(newN); + _buf.putFloat(nOffset, NumberUtils.castToFloat(newN)); // cast may throw ArithmeticException return newN; } - private double getN() { - long index = _z_offset + SizeOf.FLOAT; - return _buf.getDouble(index); + private long offsetZ(@Nonnegative final int f) { + return _z_offset + SizeOf.FLOAT * f; } - private void setN(final double value) { - long index = _z_offset + SizeOf.FLOAT; - _buf.putDouble(index, value); + private long offsetN(@Nonnegative final int f) { + return _z_offset + SizeOf.FLOAT * (_factors + f); } - static int sizeOf(int factors) { - return AdaGradEntry.sizeOf(factors) + SizeOf.FLOAT + SizeOf.DOUBLE; + @Override + void clear() { + for (int f = 0; f < _factors; f++) { + _buf.putFloat(offsetZ(f), 0.f); + _buf.putFloat(offsetN(f), 0.f); + } } + static int sizeOf(@Nonnegative final int factors) { + return Entry.sizeOf(factors) + (SizeOf.FLOAT + SizeOf.FLOAT) * factors; + } + + @Override + public String toString() { + final float[] Z = new float[_factors]; + final float[] N = new float[_factors]; + for (int f = 0; f < _factors; f++) { + Z[f] = _buf.getFloat(offsetZ(f)); + N[f] = _buf.getFloat(offsetN(f)); + } + return super.toString() + ", Z=" + Arrays.toString(Z) + ", N=" + Arrays.toString(N); + } } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java new file mode 100644 index 0000000..7cbd688 --- /dev/null +++ b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.fm; + +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.SizeOf; + +import javax.annotation.Nonnull; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; + +@Description(name = "ffm_predict", + value = "_FUNC_(float Wi, array<float> Vifj, array<float> Vjfi, float Xi, float Xj)" + + " - Returns a prediction value in Double") +public final class FFMPredictGenericUDAF extends AbstractGenericUDAFResolver { + + private FFMPredictGenericUDAF() {} + + @Override + public Evaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException { + if (typeInfo.length != 5) { + throw new UDFArgumentLengthException( + "Expected argument length is 5 but given argument length was " + typeInfo.length); + } + if (!HiveUtils.isNumberTypeInfo(typeInfo[0])) { + throw new UDFArgumentTypeException(0, + "Number type is expected for the first argument Wi: " + typeInfo[0].getTypeName()); + } + if (typeInfo[1].getCategory() != Category.LIST) { + throw new UDFArgumentTypeException(1, + "List type is expected for the second argument Vifj: " + typeInfo[1].getTypeName()); + } + if (typeInfo[2].getCategory() != Category.LIST) { + throw new UDFArgumentTypeException(2, + "List type is expected for the third argument Vjfi: " + typeInfo[2].getTypeName()); + } + ListTypeInfo typeInfo1 = (ListTypeInfo) typeInfo[1]; + if (!HiveUtils.isFloatingPointTypeInfo(typeInfo1.getListElementTypeInfo())) { + throw new UDFArgumentTypeException(1, + "Double or Float type is expected for the element type of list Vifj: " + + typeInfo1.getTypeName()); + } + ListTypeInfo typeInfo2 = (ListTypeInfo) typeInfo[2]; + if (!HiveUtils.isFloatingPointTypeInfo(typeInfo2.getListElementTypeInfo())) { + throw new UDFArgumentTypeException(2, + "Double or Float type is expected for the element type of list Vjfi: " + + typeInfo1.getTypeName()); + } + if (!HiveUtils.isNumberTypeInfo(typeInfo[3])) { + throw new UDFArgumentTypeException(3, + "Number type is expected for the third argument Xi: " + typeInfo[3].getTypeName()); + } + if (!HiveUtils.isNumberTypeInfo(typeInfo[4])) { + throw new UDFArgumentTypeException(4, + "Number type is expected for the third argument Xi: " + typeInfo[4].getTypeName()); + } + return new Evaluator(); + } + + public static final class Evaluator extends GenericUDAFEvaluator { + + // input OI + private PrimitiveObjectInspector wiOI; + private ListObjectInspector vijOI, vjiOI; + private PrimitiveObjectInspector vijElemOI, vjiElemOI; + private PrimitiveObjectInspector xiOI, xjOI; + + // merge input OI + private DoubleObjectInspector mergeInputOI; + + public Evaluator() {} + + @Override + public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException { + assert (parameters.length == 5); + super.init(mode, parameters); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + this.wiOI = HiveUtils.asDoubleCompatibleOI(parameters[0]); + this.vijOI = HiveUtils.asListOI(parameters[1]); + this.vijElemOI = HiveUtils.asFloatingPointOI(vijOI.getListElementObjectInspector()); + this.vjiOI = HiveUtils.asListOI(parameters[2]); + this.vjiElemOI = HiveUtils.asFloatingPointOI(vjiOI.getListElementObjectInspector()); + this.xiOI = HiveUtils.asDoubleCompatibleOI(parameters[3]); + this.xjOI = HiveUtils.asDoubleCompatibleOI(parameters[4]); + } else {// from partial aggregation + this.mergeInputOI = HiveUtils.asDoubleOI(parameters[0]); + } + + return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + } + + @Override + public FFMPredictAggregationBuffer getNewAggregationBuffer() throws HiveException { + FFMPredictAggregationBuffer myAggr = new FFMPredictAggregationBuffer(); + reset(myAggr); + return myAggr; + } + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg; + myAggr.reset(); + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + Object[] parameters) throws HiveException { + final FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg; + + if (parameters[0] == null) {// Wi is null + if (parameters[3] == null || parameters[4] == null) { + // both Xi and Xj are nonnull => <Vifj, Vjfi> Xi Xj + return; + } + if (parameters[1] == null || parameters[2] == null) { + // vi, vj can be null where feature index does not exist in the prediction model + return; + } + + // (i, j, xi, xj) => (wi, vi, vj, xi, xj) + float[] vij = HiveUtils.asFloatArray(parameters[1], vijOI, vijElemOI, false); + float[] vji = HiveUtils.asFloatArray(parameters[2], vjiOI, vjiElemOI, false); + double xi = PrimitiveObjectInspectorUtils.getDouble(parameters[3], xiOI); + double xj = PrimitiveObjectInspectorUtils.getDouble(parameters[4], xjOI); + + myAggr.addViVjXiXj(vij, vji, xi, xj); + } else { + final double wi = PrimitiveObjectInspectorUtils.getDouble(parameters[0], wiOI); + + if (parameters[3] == null && parameters[4] == null) {// Xi and Xj are null => global bias `w0` + // (i=0, j=null, xi=null, xj=null) => (wi, vi=?, vj=null, xi=null, xj=null) + myAggr.addW0(wi); + } else if (parameters[4] == null) {// Only Xi is nonnull => linear combination `wi` * `xi` + // (i, j=null, xi, xj=null) => (wi, vi, vj=null, xi, xj=null) + double xi = PrimitiveObjectInspectorUtils.getDouble(parameters[3], xiOI); + myAggr.addWiXi(wi, xi); + } + } + } + + @Override + public DoubleWritable terminatePartial( + @SuppressWarnings("deprecation") AggregationBuffer agg) throws HiveException { + FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg; + double sum = myAggr.get(); + return new DoubleWritable(sum); + } + + @Override + public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) + throws HiveException { + if (partial == null) { + return; + } + + FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg; + double sum = mergeInputOI.get(partial); + myAggr.merge(sum); + } + + @Override + public DoubleWritable terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg; + double result = myAggr.get(); + return new DoubleWritable(result); + } + + } + + @AggregationType(estimable = true) + public static final class FFMPredictAggregationBuffer extends AbstractAggregationBuffer { + + private double sum; + + FFMPredictAggregationBuffer() { + super(); + } + + void reset() { + this.sum = 0.d; + } + + void merge(double o_sum) { + this.sum += o_sum; + } + + double get() { + return sum; + } + + void addW0(final double W0) { + this.sum += W0; + } + + void addWiXi(final double Wi, final double Xi) { + this.sum += (Wi * Xi); + } + + void addViVjXiXj(@Nonnull final float[] Vij, @Nonnull final float[] Vji, final double Xi, + final double Xj) throws UDFArgumentException { + if (Vij.length != Vji.length) { + throw new UDFArgumentException("Mismatch in the number of factors"); + } + + final int factors = Vij.length; + + // compute inner product <Vifj, Vjfi> + double prod = 0.d; + for (int f = 0; f < factors; f++) { + prod += (Vij[f] * Vji[f]); + } + + this.sum += (prod * Xi * Xj); + } + + @Override + public int estimate() { + return SizeOf.DOUBLE; + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/FFMPredictUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FFMPredictUDF.java b/core/src/main/java/hivemall/fm/FFMPredictUDF.java deleted file mode 100644 index 48745d9..0000000 --- a/core/src/main/java/hivemall/fm/FFMPredictUDF.java +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.fm; - -import hivemall.annotations.Experimental; -import hivemall.utils.hadoop.HiveUtils; -import hivemall.utils.lang.NumberUtils; - -import java.io.IOException; -import java.util.Arrays; - -import javax.annotation.Nonnull; -import javax.annotation.Nullable; - -import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.UDFArgumentException; -import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.ql.udf.UDFType; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; -import org.apache.hadoop.hive.serde2.io.DoubleWritable; -import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; -import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; -import org.apache.hadoop.io.Text; - -/** - * @since v0.5-rc.1 - */ -@Description(name = "ffm_predict", - value = "_FUNC_(string modelId, string model, array<string> features)" - + " returns a prediction result in double from a Field-aware Factorization Machine") -@UDFType(deterministic = true, stateful = false) -@Experimental -public final class FFMPredictUDF extends GenericUDF { - - private StringObjectInspector _modelIdOI; - private StringObjectInspector _modelOI; - private ListObjectInspector _featureListOI; - - private DoubleWritable _result; - @Nullable - private String _cachedModeId; - @Nullable - private FFMPredictionModel _cachedModel; - @Nullable - private Feature[] _probes; - - public FFMPredictUDF() {} - - @Override - public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { - if (argOIs.length != 3) { - throw new UDFArgumentException("_FUNC_ takes 3 arguments"); - } - this._modelIdOI = HiveUtils.asStringOI(argOIs[0]); - this._modelOI = HiveUtils.asStringOI(argOIs[1]); - this._featureListOI = HiveUtils.asListOI(argOIs[2]); - - this._result = new DoubleWritable(); - return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; - } - - @Override - public Object evaluate(DeferredObject[] args) throws HiveException { - String modelId = _modelIdOI.getPrimitiveJavaObject(args[0].get()); - if (modelId == null) { - throw new HiveException("modelId is not set"); - } - - final FFMPredictionModel model; - if (modelId.equals(_cachedModeId)) { - model = this._cachedModel; - } else { - Text serModel = _modelOI.getPrimitiveWritableObject(args[1].get()); - if (serModel == null) { - throw new HiveException("Model is null for model ID: " + modelId); - } - byte[] b = serModel.getBytes(); - final int length = serModel.getLength(); - try { - model = FFMPredictionModel.deserialize(b, length); - b = null; - } catch (ClassNotFoundException e) { - throw new HiveException(e); - } catch (IOException e) { - throw new HiveException(e); - } - this._cachedModeId = modelId; - this._cachedModel = model; - } - - int numFeatures = model.getNumFeatures(); - int numFields = model.getNumFields(); - - Object arg2 = args[2].get(); - // [workaround] - // java.lang.ClassCastException: org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray - // cannot be cast to [Ljava.lang.Object; - if (arg2 instanceof LazyBinaryArray) { - arg2 = ((LazyBinaryArray) arg2).getList(); - } - Feature[] x = Feature.parseFFMFeatures(arg2, _featureListOI, _probes, numFeatures, - numFields); - if (x == null || x.length == 0) { - return null; // return NULL if there are no features - } - this._probes = x; - - double predicted = predict(x, model); - _result.set(predicted); - return _result; - } - - private static double predict(@Nonnull final Feature[] x, - @Nonnull final FFMPredictionModel model) throws HiveException { - // w0 - double ret = model.getW0(); - // W - for (Feature e : x) { - double xi = e.getValue(); - float wi = model.getW(e); - double wx = wi * xi; - ret += wx; - } - // V - final int factors = model.getNumFactors(); - final float[] vij = new float[factors]; - final float[] vji = new float[factors]; - for (int i = 0; i < x.length; ++i) { - final Feature ei = x[i]; - final double xi = ei.getValue(); - final int iField = ei.getField(); - for (int j = i + 1; j < x.length; ++j) { - final Feature ej = x[j]; - final double xj = ej.getValue(); - final int jField = ej.getField(); - if (!model.getV(ei, jField, vij)) { - continue; - } - if (!model.getV(ej, iField, vji)) { - continue; - } - for (int f = 0; f < factors; f++) { - float vijf = vij[f]; - float vjif = vji[f]; - ret += vijf * vjif * xi * xj; - } - } - } - if (!NumberUtils.isFinite(ret)) { - throw new HiveException("Detected " + ret + " in ffm_predict"); - } - return ret; - } - - @Override - public void close() throws IOException { - super.close(); - // clean up to help GC - this._cachedModel = null; - this._probes = null; - } - - @Override - public String getDisplayString(String[] args) { - return "ffm_predict(" + Arrays.toString(args) + ")"; - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/FFMPredictionModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FFMPredictionModel.java b/core/src/main/java/hivemall/fm/FFMPredictionModel.java deleted file mode 100644 index befbec9..0000000 --- a/core/src/main/java/hivemall/fm/FFMPredictionModel.java +++ /dev/null @@ -1,349 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.fm; - -import hivemall.utils.buffer.HeapBuffer; -import hivemall.utils.codec.VariableByteCodec; -import hivemall.utils.codec.ZigZagLEB128Codec; -import hivemall.utils.collections.maps.Int2LongOpenHashTable; -import hivemall.utils.collections.maps.IntOpenHashTable; -import hivemall.utils.io.CompressionStreamFactory.CompressionAlgorithm; -import hivemall.utils.io.IOUtils; -import hivemall.utils.lang.ArrayUtils; -import hivemall.utils.lang.HalfFloat; -import hivemall.utils.lang.ObjectUtils; - -import java.io.DataInput; -import java.io.DataOutput; -import java.io.Externalizable; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import java.util.Arrays; - -import javax.annotation.Nonnull; -import javax.annotation.Nullable; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -public final class FFMPredictionModel implements Externalizable { - private static final Log LOG = LogFactory.getLog(FFMPredictionModel.class); - - private static final byte HALF_FLOAT_ENTRY = 1; - private static final byte W_ONLY_HALF_FLOAT_ENTRY = 2; - private static final byte FLOAT_ENTRY = 3; - private static final byte W_ONLY_FLOAT_ENTRY = 4; - - /** - * maps feature to feature weight pointer - */ - private Int2LongOpenHashTable _map; - private HeapBuffer _buf; - - private double _w0; - private int _factors; - private int _numFeatures; - private int _numFields; - - public FFMPredictionModel() {}// for Externalizable - - public FFMPredictionModel(@Nonnull Int2LongOpenHashTable map, @Nonnull HeapBuffer buf, - double w0, int factor, int numFeatures, int numFields) { - this._map = map; - this._buf = buf; - this._w0 = w0; - this._factors = factor; - this._numFeatures = numFeatures; - this._numFields = numFields; - } - - public int getNumFactors() { - return _factors; - } - - public double getW0() { - return _w0; - } - - public int getNumFeatures() { - return _numFeatures; - } - - public int getNumFields() { - return _numFields; - } - - public int getActualNumFeatures() { - return _map.size(); - } - - public long approxBytesConsumed() { - int size = _map.size(); - - // [map] size * (|state| + |key| + |entry|) - long bytes = size * (1L + 4L + 4L + (4L * _factors)); - int rest = _map.capacity() - size; - if (rest > 0) { - bytes += rest * 1L; - } - // w0, factors, numFeatures, numFields, used, size - bytes += (8 + 4 + 4 + 4 + 4 + 4); - return bytes; - } - - @Nullable - private Entry getEntry(final int key) { - final long ptr = _map.get(key); - if (ptr == -1L) { - return null; - } - return new Entry(_buf, _factors, ptr); - } - - public float getW(@Nonnull final Feature x) { - int j = x.getFeatureIndex(); - - Entry entry = getEntry(j); - if (entry == null) { - return 0.f; - } - return entry.getW(); - } - - /** - * @return true if V exists - */ - public boolean getV(@Nonnull final Feature x, @Nonnull final int yField, @Nonnull float[] dst) { - int j = Feature.toIntFeature(x, yField, _numFields); - - Entry entry = getEntry(j); - if (entry == null) { - return false; - } - - entry.getV(dst); - if (ArrayUtils.equals(dst, 0.f)) { - return false; // treat as null - } - return true; - } - - @Override - public void writeExternal(@Nonnull ObjectOutput out) throws IOException { - out.writeDouble(_w0); - final int factors = _factors; - out.writeInt(factors); - out.writeInt(_numFeatures); - out.writeInt(_numFields); - - int used = _map.size(); - out.writeInt(used); - - final int[] keys = _map.getKeys(); - final int size = keys.length; - out.writeInt(size); - - final byte[] states = _map.getStates(); - writeStates(states, out); - - final long[] values = _map.getValues(); - - final HeapBuffer buf = _buf; - final Entry e = new Entry(buf, factors); - final float[] Vf = new float[factors]; - for (int i = 0; i < size; i++) { - if (states[i] != IntOpenHashTable.FULL) { - continue; - } - ZigZagLEB128Codec.writeSignedInt(keys[i], out); - e.setOffset(values[i]); - writeEntry(e, factors, Vf, out); - } - - // help GC - this._map = null; - this._buf = null; - } - - private static void writeEntry(@Nonnull final Entry e, final int factors, - @Nonnull final float[] Vf, @Nonnull final DataOutput out) throws IOException { - final float W = e.getW(); - e.getV(Vf); - - if (ArrayUtils.almostEquals(Vf, 0.f)) { - if (HalfFloat.isRepresentable(W)) { - out.writeByte(W_ONLY_HALF_FLOAT_ENTRY); - out.writeShort(HalfFloat.floatToHalfFloat(W)); - } else { - out.writeByte(W_ONLY_FLOAT_ENTRY); - out.writeFloat(W); - } - } else if (isRepresentableAsHalfFloat(W, Vf)) { - out.writeByte(HALF_FLOAT_ENTRY); - out.writeShort(HalfFloat.floatToHalfFloat(W)); - for (int i = 0; i < factors; i++) { - out.writeShort(HalfFloat.floatToHalfFloat(Vf[i])); - } - } else { - out.writeByte(FLOAT_ENTRY); - out.writeFloat(W); - IOUtils.writeFloats(Vf, factors, out); - } - } - - private static boolean isRepresentableAsHalfFloat(final float W, @Nonnull final float[] Vf) { - if (!HalfFloat.isRepresentable(W)) { - return false; - } - for (float V : Vf) { - if (!HalfFloat.isRepresentable(V)) { - return false; - } - } - return true; - } - - @Nonnull - static void writeStates(@Nonnull final byte[] status, @Nonnull final DataOutput out) - throws IOException { - // write empty states's indexes differentially - final int size = status.length; - int cardinarity = 0; - for (int i = 0; i < size; i++) { - if (status[i] != IntOpenHashTable.FULL) { - cardinarity++; - } - } - out.writeInt(cardinarity); - if (cardinarity == 0) { - return; - } - int prev = 0; - for (int i = 0; i < size; i++) { - if (status[i] != IntOpenHashTable.FULL) { - int diff = i - prev; - assert (diff >= 0); - VariableByteCodec.encodeUnsignedInt(diff, out); - prev = i; - } - } - } - - @Override - public void readExternal(@Nonnull final ObjectInput in) throws IOException, - ClassNotFoundException { - this._w0 = in.readDouble(); - final int factors = in.readInt(); - this._factors = factors; - this._numFeatures = in.readInt(); - this._numFields = in.readInt(); - - final int used = in.readInt(); - final int size = in.readInt(); - final int[] keys = new int[size]; - final long[] values = new long[size]; - final byte[] states = new byte[size]; - readStates(in, states); - - final int entrySize = Entry.sizeOf(factors); - int numChunks = (entrySize * used) / HeapBuffer.DEFAULT_CHUNK_BYTES + 1; - final HeapBuffer buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE, numChunks); - final Entry e = new Entry(buf, factors); - final float[] Vf = new float[factors]; - for (int i = 0; i < size; i++) { - if (states[i] != IntOpenHashTable.FULL) { - continue; - } - keys[i] = ZigZagLEB128Codec.readSignedInt(in); - long ptr = buf.allocate(entrySize); - e.setOffset(ptr); - readEntry(in, factors, Vf, e); - values[i] = ptr; - } - - this._map = new Int2LongOpenHashTable(keys, values, states, used); - this._buf = buf; - } - - @Nonnull - private static void readEntry(@Nonnull final DataInput in, final int factors, - @Nonnull final float[] Vf, @Nonnull Entry dst) throws IOException { - final byte type = in.readByte(); - switch (type) { - case HALF_FLOAT_ENTRY: { - float W = HalfFloat.halfFloatToFloat(in.readShort()); - dst.setW(W); - for (int i = 0; i < factors; i++) { - Vf[i] = HalfFloat.halfFloatToFloat(in.readShort()); - } - dst.setV(Vf); - break; - } - case W_ONLY_HALF_FLOAT_ENTRY: { - float W = HalfFloat.halfFloatToFloat(in.readShort()); - dst.setW(W); - break; - } - case FLOAT_ENTRY: { - float W = in.readFloat(); - dst.setW(W); - IOUtils.readFloats(in, Vf); - dst.setV(Vf); - break; - } - case W_ONLY_FLOAT_ENTRY: { - float W = in.readFloat(); - dst.setW(W); - break; - } - default: - throw new IOException("Unexpected Entry type: " + type); - } - } - - @Nonnull - static void readStates(@Nonnull final DataInput in, @Nonnull final byte[] status) - throws IOException { - // read non-empty states differentially - final int cardinarity = in.readInt(); - Arrays.fill(status, IntOpenHashTable.FULL); - int prev = 0; - for (int j = 0; j < cardinarity; j++) { - int i = VariableByteCodec.decodeUnsignedInt(in) + prev; - status[i] = IntOpenHashTable.FREE; - prev = i; - } - } - - public byte[] serialize() throws IOException { - LOG.info("FFMPredictionModel#serialize(): " + _buf.toString()); - return ObjectUtils.toCompressedBytes(this, CompressionAlgorithm.lzma2, true); - } - - public static FFMPredictionModel deserialize(@Nonnull final byte[] serializedObj, final int len) - throws ClassNotFoundException, IOException { - FFMPredictionModel model = new FFMPredictionModel(); - ObjectUtils.readCompressedObject(serializedObj, len, model, CompressionAlgorithm.lzma2, - true); - LOG.info("FFMPredictionModel#deserialize(): " + model._buf.toString()); - return model; - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java index 4f445fa..22b0541 100644 --- a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java +++ b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java @@ -22,13 +22,20 @@ import hivemall.fm.Entry.AdaGradEntry; import hivemall.fm.Entry.FTRLEntry; import hivemall.fm.FMHyperParameters.FFMHyperParameters; import hivemall.utils.buffer.HeapBuffer; +import hivemall.utils.collections.lists.LongArrayList; import hivemall.utils.collections.maps.Int2LongOpenHashTable; +import hivemall.utils.collections.maps.Int2LongOpenHashTable.MapIterator; import hivemall.utils.lang.NumberUtils; -import hivemall.utils.math.MathUtils; +import java.text.NumberFormat; +import java.util.Locale; + +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; import javax.annotation.Nullable; +import org.roaringbitmap.RoaringBitmap; + public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachineModel { private static final int DEFAULT_MAPSIZE = 65536; @@ -36,37 +43,55 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi private float _w0; @Nonnull private final Int2LongOpenHashTable _map; + @Nonnull private final HeapBuffer _buf; + @Nonnull + private final LongArrayList _freelistW; + @Nonnull + private final LongArrayList _freelistV; + + private boolean _initV; + @Nonnull + private RoaringBitmap _removedV; + // hyperparams - private final int _numFeatures; private final int _numFields; - // FTEL - private final float _alpha; - private final float _beta; - private final float _lambda1; - private final float _lamdda2; + private final int _entrySizeW; + private final int _entrySizeV; - private final int _entrySize; + // statistics + private long _bytesAllocated, _bytesUsed; + private int _numAllocatedW, _numReusedW, _numRemovedW; + private int _numAllocatedV, _numReusedV, _numRemovedV; public FFMStringFeatureMapModel(@Nonnull FFMHyperParameters params) { super(params); this._w0 = 0.f; this._map = new Int2LongOpenHashTable(DEFAULT_MAPSIZE); this._buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE); - this._numFeatures = params.numFeatures; + this._freelistW = new LongArrayList(); + this._freelistV = new LongArrayList(); + this._initV = true; + this._removedV = new RoaringBitmap(); this._numFields = params.numFields; - this._alpha = params.alphaFTRL; - this._beta = params.betaFTRL; - this._lambda1 = params.lambda1; - this._lamdda2 = params.lamdda2; - this._entrySize = entrySize(_factor, _useFTRL, _useAdaGrad); + this._entrySizeW = entrySize(1, _useFTRL, _useAdaGrad); + this._entrySizeV = entrySize(_factor, _useFTRL, _useAdaGrad); } - @Nonnull - FFMPredictionModel toPredictionModel() { - return new FFMPredictionModel(_map, _buf, _w0, _factor, _numFeatures, _numFields); + private static int entrySize(@Nonnegative int factors, boolean ftrl, boolean adagrad) { + if (ftrl) { + return FTRLEntry.sizeOf(factors); + } else if (adagrad) { + return AdaGradEntry.sizeOf(factors); + } else { + return Entry.sizeOf(factors); + } + } + + void disableInitV() { + this._initV = false; } @Override @@ -86,7 +111,7 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi @Override public float getW(@Nonnull final Feature x) { - int j = x.getFeatureIndex(); + int j = Feature.toIntFeature(x); Entry entry = getEntry(j); if (entry == null) { @@ -97,12 +122,11 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi @Override protected void setW(@Nonnull final Feature x, final float nextWi) { - final int j = x.getFeatureIndex(); + final int j = Feature.toIntFeature(x); Entry entry = getEntry(j); if (entry == null) { - float[] V = initV(); - entry = newEntry(nextWi, V); + entry = newEntry(j, nextWi); long ptr = entry.getOffset(); _map.put(j, ptr); } else { @@ -110,53 +134,6 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi } } - @Override - void updateWi(final double dloss, @Nonnull final Feature x, final float eta) { - final double Xi = x.getValue(); - float gradWi = (float) (dloss * Xi); - - final Entry theta = getEntry(x); - float wi = theta.getW(); - - float nextWi = wi - eta * (gradWi + 2.f * _lambdaW * wi); - if (!NumberUtils.isFinite(nextWi)) { - throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature() - + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss - + ", eta=" + eta); - } - theta.setW(nextWi); - } - - /** - * Update Wi using Follow-the-Regularized-Leader - */ - boolean updateWiFTRL(final double dloss, @Nonnull final Feature x, final float eta) { - final double Xi = x.getValue(); - float gradWi = (float) (dloss * Xi); - - final Entry theta = getEntry(x); - float wi = theta.getW(); - - final float z = theta.updateZ(gradWi, _alpha); - final double n = theta.updateN(gradWi); - - if (Math.abs(z) <= _lambda1) { - removeEntry(x); - return wi != 0; - } - - final float nextWi = (float) ((MathUtils.sign(z) * _lambda1 - z) / ((_beta + Math.sqrt(n)) - / _alpha + _lamdda2)); - if (!NumberUtils.isFinite(nextWi)) { - throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature() - + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss - + ", eta=" + eta + ", n=" + n + ", z=" + z); - } - theta.setW(nextWi); - return (nextWi != 0) || (wi != 0); - } - - /** * @return V_x,yField,f */ @@ -166,10 +143,16 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi Entry entry = getEntry(j); if (entry == null) { + if (_initV == false) { + return 0.f; + } else if (_removedV.contains(j)) { + return 0.f; + } float[] V = initV(); - entry = newEntry(V); + entry = newEntry(j, V); long ptr = entry.getOffset(); _map.put(j, ptr); + return V[f]; } return entry.getV(f); } @@ -181,8 +164,13 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi Entry entry = getEntry(j); if (entry == null) { + if (_initV == false) { + return; + } else if (_removedV.contains(j)) { + return; + } float[] V = initV(); - entry = newEntry(V); + entry = newEntry(j, V); long ptr = entry.getOffset(); _map.put(j, ptr); } @@ -190,13 +178,12 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi } @Override - protected Entry getEntry(@Nonnull final Feature x) { - final int j = x.getFeatureIndex(); + protected Entry getEntryW(@Nonnull final Feature x) { + final int j = Feature.toIntFeature(x); Entry entry = getEntry(j); if (entry == null) { - float[] V = initV(); - entry = newEntry(V); + entry = newEntry(j, 0.f); long ptr = entry.getOffset(); _map.put(j, ptr); } @@ -204,51 +191,92 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi } @Override - protected Entry getEntry(@Nonnull final Feature x, @Nonnull final int yField) { + protected Entry getEntryV(@Nonnull final Feature x, @Nonnull final int yField) { final int j = Feature.toIntFeature(x, yField, _numFields); Entry entry = getEntry(j); if (entry == null) { + if (_initV == false) { + return null; + } else if (_removedV.contains(j)) { + return null; + } float[] V = initV(); - entry = newEntry(V); + entry = newEntry(j, V); long ptr = entry.getOffset(); _map.put(j, ptr); } return entry; } - protected void removeEntry(@Nonnull final Feature x) { - int j = x.getFeatureIndex(); - _map.remove(j); + @Override + protected void removeEntry(@Nonnull final Entry entry) { + final int j = entry.getKey(); + final long ptr = _map.remove(j); + if (ptr == -1L) { + return; // should never be happen. + } + entry.clear(); + if (Entry.isEntryW(j)) { + _freelistW.add(ptr); + this._numRemovedW++; + this._bytesUsed -= _entrySizeW; + } else { + _removedV.add(j); + _freelistV.add(ptr); + this._numRemovedV++; + this._bytesUsed -= _entrySizeV; + } } @Nonnull - protected final Entry newEntry(final float W, @Nonnull final float[] V) { - Entry entry = newEntry(); - entry.setW(W); - entry.setV(V); - return entry; - } + protected final Entry newEntry(final int key, final float W) { + final long ptr; + if (_freelistW.isEmpty()) { + ptr = _buf.allocate(_entrySizeW); + this._numAllocatedW++; + this._bytesAllocated += _entrySizeW; + this._bytesUsed += _entrySizeW; + } else {// reuse removed entry + ptr = _freelistW.remove(); + this._numReusedW++; + } + final Entry entry; + if (_useFTRL) { + entry = new FTRLEntry(_buf, key, ptr); + } else if (_useAdaGrad) { + entry = new AdaGradEntry(_buf, key, ptr); + } else { + entry = new Entry(_buf, key, ptr); + } - @Nonnull - protected final Entry newEntry(@Nonnull final float[] V) { - Entry entry = newEntry(); - entry.setV(V); + entry.setW(W); return entry; } @Nonnull - private Entry newEntry() { + protected final Entry newEntry(final int key, @Nonnull final float[] V) { + final long ptr; + if (_freelistV.isEmpty()) { + ptr = _buf.allocate(_entrySizeV); + this._numAllocatedV++; + this._bytesAllocated += _entrySizeV; + this._bytesUsed += _entrySizeV; + } else {// reuse removed entry + ptr = _freelistV.remove(); + this._numReusedV++; + } + final Entry entry; if (_useFTRL) { - long ptr = _buf.allocate(_entrySize); - return new FTRLEntry(_buf, _factor, ptr); + entry = new FTRLEntry(_buf, _factor, key, ptr); } else if (_useAdaGrad) { - long ptr = _buf.allocate(_entrySize); - return new AdaGradEntry(_buf, _factor, ptr); + entry = new AdaGradEntry(_buf, _factor, key, ptr); } else { - long ptr = _buf.allocate(_entrySize); - return new Entry(_buf, _factor, ptr); + entry = new Entry(_buf, _factor, key, ptr); } + + entry.setV(V); + return entry; } @Nullable @@ -257,28 +285,95 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi if (ptr == -1L) { return null; } - return getEntry(ptr); + return getEntry(key, ptr); } @Nonnull - private Entry getEntry(long ptr) { - if (_useFTRL) { - return new FTRLEntry(_buf, _factor, ptr); - } else if (_useAdaGrad) { - return new AdaGradEntry(_buf, _factor, ptr); + private Entry getEntry(final int key, @Nonnegative final long ptr) { + if (Entry.isEntryW(key)) { + if (_useFTRL) { + return new FTRLEntry(_buf, key, ptr); + } else if (_useAdaGrad) { + return new AdaGradEntry(_buf, key, ptr); + } else { + return new Entry(_buf, key, ptr); + } } else { - return new Entry(_buf, _factor, ptr); + if (_useFTRL) { + return new FTRLEntry(_buf, _factor, key, ptr); + } else if (_useAdaGrad) { + return new AdaGradEntry(_buf, _factor, key, ptr); + } else { + return new Entry(_buf, _factor, key, ptr); + } } } - private static int entrySize(int factors, boolean ftrl, boolean adagrad) { - if (ftrl) { - return FTRLEntry.sizeOf(factors); - } else if (adagrad) { - return AdaGradEntry.sizeOf(factors); - } else { - return Entry.sizeOf(factors); + @Nonnull + String getStatistics() { + final NumberFormat fmt = NumberFormat.getIntegerInstance(Locale.US); + return "FFMStringFeatureMapModel [bytesAllocated=" + + NumberUtils.prettySize(_bytesAllocated) + ", bytesUsed=" + + NumberUtils.prettySize(_bytesUsed) + ", numAllocatedW=" + + fmt.format(_numAllocatedW) + ", numReusedW=" + fmt.format(_numReusedW) + + ", numRemovedW=" + fmt.format(_numRemovedW) + ", numAllocatedV=" + + fmt.format(_numAllocatedV) + ", numReusedV=" + fmt.format(_numReusedV) + + ", numRemovedV=" + fmt.format(_numRemovedV) + "]"; + } + + @Override + public String toString() { + return getStatistics(); + } + + @Nonnull + EntryIterator entries() { + return new EntryIterator(this); + } + + static final class EntryIterator { + + @Nonnull + private final MapIterator dictItor; + @Nonnull + private final Entry entryProbeW; + @Nonnull + private final Entry entryProbeV; + + EntryIterator(@Nonnull FFMStringFeatureMapModel model) { + this.dictItor = model._map.entries(); + this.entryProbeW = new Entry(model._buf, 1); + this.entryProbeV = new Entry(model._buf, model._factor); + } + + @Nonnull + Entry getEntryProbeW() { + return entryProbeW; } + + @Nonnull + Entry getEntryProbeV() { + return entryProbeV; + } + + boolean hasNext() { + return dictItor.hasNext(); + } + + boolean next() { + return dictItor.next() != -1; + } + + int getEntryIndex() { + return dictItor.getKey(); + } + + @Nonnull + void getEntry(@Nonnull final Entry probe) { + long offset = dictItor.getValue(); + probe.setOffset(offset); + } + } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/FMHyperParameters.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FMHyperParameters.java b/core/src/main/java/hivemall/fm/FMHyperParameters.java index accb99a..15c1c56 100644 --- a/core/src/main/java/hivemall/fm/FMHyperParameters.java +++ b/core/src/main/java/hivemall/fm/FMHyperParameters.java @@ -143,16 +143,15 @@ class FMHyperParameters { int numFields = Feature.DEFAULT_NUM_FIELDS; // adagrad - boolean useAdaGrad = true; - float eta0_V = 1.f; + boolean useAdaGrad = false; float eps = 1.f; // FTRL - boolean useFTRL = true; - float alphaFTRL = 0.1f; // Learning Rate + boolean useFTRL = false; + float alphaFTRL = 0.2f; // Learning Rate float betaFTRL = 1.f; // Smoothing parameter for AdaGrad - float lambda1 = 0.1f; // L1 Regularization - float lamdda2 = 0.01f; // L2 Regularization + float lambda1 = 0.001f; // L1 Regularization + float lamdda2 = 0.0001f; // L2 Regularization FFMHyperParameters() { super(); @@ -171,42 +170,59 @@ class FMHyperParameters { // feature hashing if (numFeatures == -1) { - int hashbits = Primitives.parseInt(cl.getOptionValue("feature_hashing"), - Feature.DEFAULT_FEATURE_BITS); - if (hashbits < 18 || hashbits > 31) { - throw new UDFArgumentException("-feature_hashing MUST be in range [18,31]: " - + hashbits); + int hashbits = Primitives.parseInt(cl.getOptionValue("feature_hashing"), -1); + if (hashbits != -1) { + if (hashbits < 18 || hashbits > 31) { + throw new UDFArgumentException( + "-feature_hashing MUST be in range [18,31]: " + hashbits); + } + this.numFeatures = 1 << hashbits; } - this.numFeatures = 1 << hashbits; } this.numFields = Primitives.parseInt(cl.getOptionValue("num_fields"), numFields); if (numFields <= 1) { throw new UDFArgumentException("-num_fields MUST be greater than 1: " + numFields); } - // adagrad - this.useAdaGrad = !cl.hasOption("disable_adagrad"); - this.eta0_V = Primitives.parseFloat(cl.getOptionValue("eta0_V"), eta0_V); - this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), eps); - - // FTRL - this.useFTRL = !cl.hasOption("disable_ftrl"); - this.alphaFTRL = Primitives.parseFloat(cl.getOptionValue("alphaFTRL"), alphaFTRL); - if (alphaFTRL == 0.f) { - throw new UDFArgumentException("-alphaFTRL SHOULD NOT be 0"); + // optimizer + final String optimizer = cl.getOptionValue("optimizer", "ftrl").toLowerCase(); + switch (optimizer) { + case "ftrl": { + this.useFTRL = true; + this.useAdaGrad = false; + this.alphaFTRL = Primitives.parseFloat(cl.getOptionValue("alphaFTRL"), + alphaFTRL); + if (alphaFTRL == 0.f) { + throw new UDFArgumentException("-alphaFTRL SHOULD NOT be 0"); + } + this.betaFTRL = Primitives.parseFloat(cl.getOptionValue("betaFTRL"), betaFTRL); + this.lambda1 = Primitives.parseFloat(cl.getOptionValue("lambda1"), lambda1); + this.lamdda2 = Primitives.parseFloat(cl.getOptionValue("lamdda2"), lamdda2); + break; + } + case "adagrad": { + this.useAdaGrad = true; + this.useFTRL = false; + this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), eps); + break; + } + case "sgd": + // fall through + default: { + this.useFTRL = false; + this.useAdaGrad = false; + break; + } } - this.betaFTRL = Primitives.parseFloat(cl.getOptionValue("betaFTRL"), betaFTRL); - this.lambda1 = Primitives.parseFloat(cl.getOptionValue("lambda1"), lambda1); - this.lamdda2 = Primitives.parseFloat(cl.getOptionValue("lamdda2"), lamdda2); } @Override public String toString() { return "FFMHyperParameters [globalBias=" + globalBias + ", linearCoeff=" + linearCoeff - + ", numFields=" + numFields + ", useAdaGrad=" + useAdaGrad + ", eta0_V=" - + eta0_V + ", eps=" + eps + ", useFTRL=" + useFTRL + ", alphaFTRL=" + alphaFTRL - + ", betaFTRL=" + betaFTRL + ", lambda1=" + lambda1 + ", lamdda2=" + lamdda2 - + "], " + super.toString(); + + ", numFields=" + numFields + ", useAdaGrad=" + useAdaGrad + ", eps=" + eps + + ", useFTRL=" + useFTRL + ", alphaFTRL=" + alphaFTRL + ", betaFTRL=" + + betaFTRL + ", lambda1=" + lambda1 + ", lamdda2=" + lamdda2 + "], " + + super.toString(); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java index 19ac287..be39b0b 100644 --- a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java +++ b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java @@ -19,7 +19,7 @@ package hivemall.fm; import hivemall.utils.collections.maps.Int2FloatOpenHashTable; -import hivemall.utils.collections.maps.IntOpenHashMap; +import hivemall.utils.collections.maps.IntOpenHashTable; import java.util.Arrays; @@ -33,7 +33,7 @@ public final class FMIntFeatureMapModel extends FactorizationMachineModel { // LEARNING PARAMS private float _w0; private final Int2FloatOpenHashTable _w; - private final IntOpenHashMap<float[]> _V; + private final IntOpenHashTable<float[]> _V; private int _minIndex, _maxIndex; @@ -42,7 +42,7 @@ public final class FMIntFeatureMapModel extends FactorizationMachineModel { this._w0 = 0.f; this._w = new Int2FloatOpenHashTable(DEFAULT_MAPSIZE); _w.defaultReturnValue(0.f); - this._V = new IntOpenHashMap<float[]>(DEFAULT_MAPSIZE); + this._V = new IntOpenHashTable<float[]>(DEFAULT_MAPSIZE); this._minIndex = 0; this._maxIndex = 0; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/FMPredictGenericUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FMPredictGenericUDAF.java b/core/src/main/java/hivemall/fm/FMPredictGenericUDAF.java index 667befb..730cc49 100644 --- a/core/src/main/java/hivemall/fm/FMPredictGenericUDAF.java +++ b/core/src/main/java/hivemall/fm/FMPredictGenericUDAF.java @@ -18,6 +18,9 @@ */ package hivemall.fm; +import static org.apache.hadoop.hive.ql.util.JavaDataModel.JAVA64_ARRAY_META; +import static org.apache.hadoop.hive.ql.util.JavaDataModel.JAVA64_REF; +import static org.apache.hadoop.hive.ql.util.JavaDataModel.PRIMITIVES2; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.hadoop.WritableUtils; @@ -35,6 +38,7 @@ import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; @@ -234,6 +238,7 @@ public final class FMPredictGenericUDAF extends AbstractGenericUDAFResolver { } + @AggregationType(estimable = true) public static class FMPredictAggregationBuffer extends AbstractAggregationBuffer { private double ret; @@ -328,6 +333,16 @@ public final class FMPredictGenericUDAF extends AbstractGenericUDAFResolver { } return predict; } + + @Override + public int estimate() { + if (sumVjXj == null) { + return PRIMITIVES2 + 2 * JAVA64_REF; + } else { + // model.array() = JAVA64_ARRAY_META + JAVA64_REF + return PRIMITIVES2 + 2 * (JAVA64_ARRAY_META + PRIMITIVES2 * sumVjXj.length); + } + } } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java index cd99046..4eec280 100644 --- a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java +++ b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java @@ -19,7 +19,7 @@ package hivemall.fm; import hivemall.utils.collections.IMapIterator; -import hivemall.utils.collections.maps.OpenHashTable; +import hivemall.utils.collections.maps.OpenHashMap; import javax.annotation.Nonnull; @@ -28,12 +28,12 @@ public final class FMStringFeatureMapModel extends FactorizationMachineModel { // LEARNING PARAMS private float _w0; - private final OpenHashTable<String, Entry> _map; + private final OpenHashMap<String, Entry> _map; public FMStringFeatureMapModel(@Nonnull FMHyperParameters params) { super(params); this._w0 = 0.f; - this._map = new OpenHashTable<String, FMStringFeatureMapModel.Entry>(DEFAULT_MAPSIZE); + this._map = new OpenHashMap<String, FMStringFeatureMapModel.Entry>(DEFAULT_MAPSIZE); } @Override @@ -42,7 +42,7 @@ public final class FMStringFeatureMapModel extends FactorizationMachineModel { } IMapIterator<String, Entry> entries() { - return _map.entries(); + return _map.entries(true); } @Override http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java index 65b6ba7..24210a8 100644 --- a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java +++ b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java @@ -117,8 +117,8 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { opts.addOption("c", "classification", false, "Act as classification"); opts.addOption("seed", true, "Seed value [default: -1 (random)]"); opts.addOption("iters", "iterations", true, "The number of iterations [default: 1]"); - opts.addOption("p", "num_features", true, "The size of feature dimensions"); - opts.addOption("factor", "factors", true, "The number of the latent variables [default: 5]"); + opts.addOption("p", "num_features", true, "The size of feature dimensions [default: -1]"); + opts.addOption("f", "factors", true, "The number of the latent variables [default: 5]"); opts.addOption("sigma", true, "The standard deviation for initializing V [default: 0.1]"); opts.addOption("lambda0", "lambda", true, "The initial lambda value for regularization [default: 0.01]"); @@ -376,7 +376,7 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { double loss = _lossFunction.loss(p, y); _cvState.incrLoss(loss); - if (MathUtils.closeToZero(lossGrad)) { + if (MathUtils.closeToZero(lossGrad, 1E-9d)) { return; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/Feature.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/Feature.java b/core/src/main/java/hivemall/fm/Feature.java index 2966a02..8ae6f20 100644 --- a/core/src/main/java/hivemall/fm/Feature.java +++ b/core/src/main/java/hivemall/fm/Feature.java @@ -23,6 +23,7 @@ import hivemall.utils.lang.NumberUtils; import java.nio.ByteBuffer; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -30,7 +31,7 @@ import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; public abstract class Feature { - public static final int DEFAULT_NUM_FIELDS = 1024; + public static final int DEFAULT_NUM_FIELDS = 256; public static final int DEFAULT_FEATURE_BITS = 21; public static final int DEFAULT_NUM_FEATURES = 1 << 21; // 2^21 @@ -51,10 +52,11 @@ public abstract class Feature { throw new UnsupportedOperationException(); } - public void setFeatureIndex(int i) { + public void setFeatureIndex(@Nonnegative int i) { throw new UnsupportedOperationException(); } + @Nonnegative public int getFeatureIndex() { throw new UnsupportedOperationException(); } @@ -127,6 +129,7 @@ public abstract class Feature { } } + @Nullable public static Feature[] parseFFMFeatures(@Nonnull final Object arg, @Nonnull final ListObjectInspector listOI, @Nullable final Feature[] probes, final int numFeatures, final int numFields) throws HiveException { @@ -176,6 +179,9 @@ public abstract class Feature { int index = parseFeatureIndex(fv); return new IntFeature(index, 1.d); } else { + if ("0".equals(fv)) { + throw new HiveException("Index value should not be 0: " + fv); + } return new StringFeature(/* index */fv, 1.d); } } else { @@ -187,6 +193,9 @@ public abstract class Feature { return new IntFeature(index, value); } else { double value = parseFeatureValue(valueStr); + if ("0".equals(indexStr)) { + throw new HiveException("Index value should not be 0: " + fv); + } return new StringFeature(/* index */indexStr, value); } } @@ -198,6 +207,12 @@ public abstract class Feature { } @Nonnull + static IntFeature parseFFMFeature(@Nonnull final String fv, final int numFeatures) + throws HiveException { + return parseFFMFeature(fv, -1, DEFAULT_NUM_FIELDS); + } + + @Nonnull static IntFeature parseFFMFeature(@Nonnull final String fv, final int numFeatures, final int numFields) throws HiveException { final int pos1 = fv.indexOf(':'); @@ -219,25 +234,26 @@ public abstract class Feature { } else { index = MurmurHash3.murmurhash3(lead, numFields); } - short field = (short) index; + short field = NumberUtils.castToShort(index); double value = parseFeatureValue(rest); return new IntFeature(index, field, value); } - final String indexStr = rest.substring(0, pos2); - final int index; + final short field; - if (NumberUtils.isDigits(indexStr) && NumberUtils.isDigits(lead)) { - index = parseFeatureIndex(indexStr); - if (index >= (numFeatures + numFields)) { - throw new HiveException("Feature index MUST be less than " - + (numFeatures + numFields) + " but was " + index); - } + if (NumberUtils.isDigits(lead)) { field = parseField(lead, numFields); } else { + field = NumberUtils.castToShort(MurmurHash3.murmurhash3(lead, numFields)); + } + + final int index; + final String indexStr = rest.substring(0, pos2); + if (numFeatures == -1 && NumberUtils.isDigits(indexStr)) { + index = parseFeatureIndex(indexStr); + } else { // +NUM_FIELD to avoid conflict to quantitative features index = MurmurHash3.murmurhash3(indexStr, numFeatures) + numFields; - field = (short) MurmurHash3.murmurhash3(lead, numFields); } String valueStr = rest.substring(pos2 + 1); double value = parseFeatureValue(valueStr); @@ -253,6 +269,9 @@ public abstract class Feature { int index = parseFeatureIndex(fv); probe.setFeatureIndex(index); } else { + if ("0".equals(fv)) { + throw new HiveException("Index value should not be 0: " + fv); + } probe.setFeature(fv); } probe.value = 1.d; @@ -264,6 +283,9 @@ public abstract class Feature { probe.setFeatureIndex(index); probe.value = parseFeatureValue(valueStr); } else { + if ("0".equals(indexStr)) { + throw new HiveException("Index value should not be 0: " + fv); + } probe.setFeature(indexStr); probe.value = parseFeatureValue(valueStr); } @@ -296,27 +318,26 @@ public abstract class Feature { } else { index = MurmurHash3.murmurhash3(lead, numFields); } - short field = (short) index; + short field = NumberUtils.castToShort(index); probe.setField(field); probe.setFeatureIndex(index); probe.value = parseFeatureValue(rest); return; } - String indexStr = rest.substring(0, pos2); - final int index; final short field; - if (NumberUtils.isDigits(indexStr) && NumberUtils.isDigits(lead)) { - index = parseFeatureIndex(indexStr); - if (index >= (numFeatures + numFields)) { - throw new HiveException("Feature index MUST be less than " - + (numFeatures + numFields) + " but was " + index); - } + if (NumberUtils.isDigits(lead)) { field = parseField(lead, numFields); } else { + field = NumberUtils.castToShort(MurmurHash3.murmurhash3(lead, numFields)); + } + final int index; + final String indexStr = rest.substring(0, pos2); + if (numFeatures == -1 && NumberUtils.isDigits(indexStr)) { + index = parseFeatureIndex(indexStr); + } else { // +NUM_FIELD to avoid conflict to quantitative features index = MurmurHash3.murmurhash3(indexStr, numFeatures) + numFields; - field = (short) MurmurHash3.murmurhash3(lead, numFields); } probe.setField(field); probe.setFeatureIndex(index); @@ -325,7 +346,6 @@ public abstract class Feature { probe.value = parseFeatureValue(valueStr); } - private static int parseFeatureIndex(@Nonnull final String indexStr) throws HiveException { final int index; try { @@ -333,7 +353,7 @@ public abstract class Feature { } catch (NumberFormatException e) { throw new HiveException("Invalid index value: " + indexStr, e); } - if (index < 0) { + if (index <= 0) { throw new HiveException("Feature index MUST be greater than 0: " + indexStr); } return index; @@ -361,7 +381,13 @@ public abstract class Feature { return field; } - public static int toIntFeature(@Nonnull final Feature x, final int yField, final int numFields) { + public static int toIntFeature(@Nonnull final Feature x) { + int index = x.getFeatureIndex(); + return -index; + } + + public static int toIntFeature(@Nonnull final Feature x, @Nonnegative final int yField, + @Nonnegative final int numFields) { int index = x.getFeatureIndex(); return index * numFields + yField; }