Changed FFM prediction model to a scalable format
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/550bb4e6 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/550bb4e6 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/550bb4e6 Branch: refs/heads/HIVEMALL-24-2 Commit: 550bb4e6f0f69cfd25506c34caf67e3c014d5750 Parents: 36a6ca2 Author: Makoto Yui <[email protected]> Authored: Tue Jul 25 20:03:29 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Tue Jul 25 20:03:29 2017 +0900 ---------------------------------------------------------------------- core/src/main/java/hivemall/fm/Entry.java | 2 +- .../java/hivemall/fm/FFMPredictGenericUDAF.java | 21 +- .../main/java/hivemall/fm/FFMPredictUDF.java | 187 ---------- .../java/hivemall/fm/FFMPredictionModel.java | 349 ------------------- .../hivemall/fm/FFMStringFeatureMapModel.java | 53 ++- .../fm/FieldAwareFactorizationMachineUDTF.java | 82 +++-- .../hivemall/fm/FFMPredictionModelTest.java | 65 ---- resources/ddl/define-all-as-permanent.hive | 2 +- resources/ddl/define-all.hive | 2 +- resources/ddl/define-all.spark | 2 +- resources/ddl/define-udfs.td.hql | 2 + 11 files changed, 111 insertions(+), 656 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/Entry.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/Entry.java b/core/src/main/java/hivemall/fm/Entry.java index 1882f85..209112c 100644 --- a/core/src/main/java/hivemall/fm/Entry.java +++ b/core/src/main/java/hivemall/fm/Entry.java @@ -58,7 +58,7 @@ class Entry { return _offset; } - void setOffset(long offset) { + void setOffset(final long offset) { this._offset = offset; } 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java index 91d1b6b..a37a1b8 100644 --- a/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java +++ b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java @@ -19,7 +19,12 @@ package hivemall.fm; import hivemall.utils.hadoop.HiveUtils; -import hivemall.utils.hadoop.WritableUtils; + +import java.util.ArrayList; +import java.util.List; + +import javax.annotation.Nonnull; + import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; @@ -30,20 +35,18 @@ import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; import org.apache.hadoop.hive.serde2.io.DoubleWritable; -import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; -import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; -import java.util.ArrayList; -import java.util.List; - @Description( name = "ffm_predict", value = "_FUNC_(Float Wi, Float Wj, array<float> Vifj, array<float> Vjfi, float Xi, float Xj)" http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/FFMPredictUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FFMPredictUDF.java b/core/src/main/java/hivemall/fm/FFMPredictUDF.java deleted file mode 100644 index 48745d9..0000000 --- a/core/src/main/java/hivemall/fm/FFMPredictUDF.java +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.fm; - -import hivemall.annotations.Experimental; -import hivemall.utils.hadoop.HiveUtils; -import hivemall.utils.lang.NumberUtils; - -import java.io.IOException; -import java.util.Arrays; - -import javax.annotation.Nonnull; -import javax.annotation.Nullable; - -import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.UDFArgumentException; -import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.ql.udf.UDFType; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; -import org.apache.hadoop.hive.serde2.io.DoubleWritable; -import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; -import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; -import org.apache.hadoop.io.Text; - -/** - * @since v0.5-rc.1 - */ -@Description(name = "ffm_predict", - value = "_FUNC_(string modelId, string model, array<string> features)" - + " returns a prediction result in double from a Field-aware Factorization Machine") -@UDFType(deterministic = true, stateful = false) -@Experimental -public final class FFMPredictUDF extends GenericUDF { - - private StringObjectInspector _modelIdOI; - private StringObjectInspector _modelOI; - private ListObjectInspector _featureListOI; - - private DoubleWritable _result; - @Nullable - private String _cachedModeId; - @Nullable - private FFMPredictionModel _cachedModel; - @Nullable - private Feature[] _probes; - - public FFMPredictUDF() {} - - @Override - public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { - if (argOIs.length != 3) { - throw new UDFArgumentException("_FUNC_ takes 3 arguments"); - } - this._modelIdOI = HiveUtils.asStringOI(argOIs[0]); - 
this._modelOI = HiveUtils.asStringOI(argOIs[1]); - this._featureListOI = HiveUtils.asListOI(argOIs[2]); - - this._result = new DoubleWritable(); - return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; - } - - @Override - public Object evaluate(DeferredObject[] args) throws HiveException { - String modelId = _modelIdOI.getPrimitiveJavaObject(args[0].get()); - if (modelId == null) { - throw new HiveException("modelId is not set"); - } - - final FFMPredictionModel model; - if (modelId.equals(_cachedModeId)) { - model = this._cachedModel; - } else { - Text serModel = _modelOI.getPrimitiveWritableObject(args[1].get()); - if (serModel == null) { - throw new HiveException("Model is null for model ID: " + modelId); - } - byte[] b = serModel.getBytes(); - final int length = serModel.getLength(); - try { - model = FFMPredictionModel.deserialize(b, length); - b = null; - } catch (ClassNotFoundException e) { - throw new HiveException(e); - } catch (IOException e) { - throw new HiveException(e); - } - this._cachedModeId = modelId; - this._cachedModel = model; - } - - int numFeatures = model.getNumFeatures(); - int numFields = model.getNumFields(); - - Object arg2 = args[2].get(); - // [workaround] - // java.lang.ClassCastException: org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray - // cannot be cast to [Ljava.lang.Object; - if (arg2 instanceof LazyBinaryArray) { - arg2 = ((LazyBinaryArray) arg2).getList(); - } - Feature[] x = Feature.parseFFMFeatures(arg2, _featureListOI, _probes, numFeatures, - numFields); - if (x == null || x.length == 0) { - return null; // return NULL if there are no features - } - this._probes = x; - - double predicted = predict(x, model); - _result.set(predicted); - return _result; - } - - private static double predict(@Nonnull final Feature[] x, - @Nonnull final FFMPredictionModel model) throws HiveException { - // w0 - double ret = model.getW0(); - // W - for (Feature e : x) { - double xi = e.getValue(); - float wi = 
model.getW(e); - double wx = wi * xi; - ret += wx; - } - // V - final int factors = model.getNumFactors(); - final float[] vij = new float[factors]; - final float[] vji = new float[factors]; - for (int i = 0; i < x.length; ++i) { - final Feature ei = x[i]; - final double xi = ei.getValue(); - final int iField = ei.getField(); - for (int j = i + 1; j < x.length; ++j) { - final Feature ej = x[j]; - final double xj = ej.getValue(); - final int jField = ej.getField(); - if (!model.getV(ei, jField, vij)) { - continue; - } - if (!model.getV(ej, iField, vji)) { - continue; - } - for (int f = 0; f < factors; f++) { - float vijf = vij[f]; - float vjif = vji[f]; - ret += vijf * vjif * xi * xj; - } - } - } - if (!NumberUtils.isFinite(ret)) { - throw new HiveException("Detected " + ret + " in ffm_predict"); - } - return ret; - } - - @Override - public void close() throws IOException { - super.close(); - // clean up to help GC - this._cachedModel = null; - this._probes = null; - } - - @Override - public String getDisplayString(String[] args) { - return "ffm_predict(" + Arrays.toString(args) + ")"; - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/FFMPredictionModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FFMPredictionModel.java b/core/src/main/java/hivemall/fm/FFMPredictionModel.java deleted file mode 100644 index befbec9..0000000 --- a/core/src/main/java/hivemall/fm/FFMPredictionModel.java +++ /dev/null @@ -1,349 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.fm; - -import hivemall.utils.buffer.HeapBuffer; -import hivemall.utils.codec.VariableByteCodec; -import hivemall.utils.codec.ZigZagLEB128Codec; -import hivemall.utils.collections.maps.Int2LongOpenHashTable; -import hivemall.utils.collections.maps.IntOpenHashTable; -import hivemall.utils.io.CompressionStreamFactory.CompressionAlgorithm; -import hivemall.utils.io.IOUtils; -import hivemall.utils.lang.ArrayUtils; -import hivemall.utils.lang.HalfFloat; -import hivemall.utils.lang.ObjectUtils; - -import java.io.DataInput; -import java.io.DataOutput; -import java.io.Externalizable; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import java.util.Arrays; - -import javax.annotation.Nonnull; -import javax.annotation.Nullable; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -public final class FFMPredictionModel implements Externalizable { - private static final Log LOG = LogFactory.getLog(FFMPredictionModel.class); - - private static final byte HALF_FLOAT_ENTRY = 1; - private static final byte W_ONLY_HALF_FLOAT_ENTRY = 2; - private static final byte FLOAT_ENTRY = 3; - private static final byte W_ONLY_FLOAT_ENTRY = 4; - - /** - * maps feature to feature weight pointer - */ - private Int2LongOpenHashTable _map; - private HeapBuffer _buf; - - private double _w0; - private int _factors; - private int _numFeatures; - private int _numFields; - - public FFMPredictionModel() {}// for Externalizable - - public FFMPredictionModel(@Nonnull Int2LongOpenHashTable 
map, @Nonnull HeapBuffer buf, - double w0, int factor, int numFeatures, int numFields) { - this._map = map; - this._buf = buf; - this._w0 = w0; - this._factors = factor; - this._numFeatures = numFeatures; - this._numFields = numFields; - } - - public int getNumFactors() { - return _factors; - } - - public double getW0() { - return _w0; - } - - public int getNumFeatures() { - return _numFeatures; - } - - public int getNumFields() { - return _numFields; - } - - public int getActualNumFeatures() { - return _map.size(); - } - - public long approxBytesConsumed() { - int size = _map.size(); - - // [map] size * (|state| + |key| + |entry|) - long bytes = size * (1L + 4L + 4L + (4L * _factors)); - int rest = _map.capacity() - size; - if (rest > 0) { - bytes += rest * 1L; - } - // w0, factors, numFeatures, numFields, used, size - bytes += (8 + 4 + 4 + 4 + 4 + 4); - return bytes; - } - - @Nullable - private Entry getEntry(final int key) { - final long ptr = _map.get(key); - if (ptr == -1L) { - return null; - } - return new Entry(_buf, _factors, ptr); - } - - public float getW(@Nonnull final Feature x) { - int j = x.getFeatureIndex(); - - Entry entry = getEntry(j); - if (entry == null) { - return 0.f; - } - return entry.getW(); - } - - /** - * @return true if V exists - */ - public boolean getV(@Nonnull final Feature x, @Nonnull final int yField, @Nonnull float[] dst) { - int j = Feature.toIntFeature(x, yField, _numFields); - - Entry entry = getEntry(j); - if (entry == null) { - return false; - } - - entry.getV(dst); - if (ArrayUtils.equals(dst, 0.f)) { - return false; // treat as null - } - return true; - } - - @Override - public void writeExternal(@Nonnull ObjectOutput out) throws IOException { - out.writeDouble(_w0); - final int factors = _factors; - out.writeInt(factors); - out.writeInt(_numFeatures); - out.writeInt(_numFields); - - int used = _map.size(); - out.writeInt(used); - - final int[] keys = _map.getKeys(); - final int size = keys.length; - out.writeInt(size); - - 
final byte[] states = _map.getStates(); - writeStates(states, out); - - final long[] values = _map.getValues(); - - final HeapBuffer buf = _buf; - final Entry e = new Entry(buf, factors); - final float[] Vf = new float[factors]; - for (int i = 0; i < size; i++) { - if (states[i] != IntOpenHashTable.FULL) { - continue; - } - ZigZagLEB128Codec.writeSignedInt(keys[i], out); - e.setOffset(values[i]); - writeEntry(e, factors, Vf, out); - } - - // help GC - this._map = null; - this._buf = null; - } - - private static void writeEntry(@Nonnull final Entry e, final int factors, - @Nonnull final float[] Vf, @Nonnull final DataOutput out) throws IOException { - final float W = e.getW(); - e.getV(Vf); - - if (ArrayUtils.almostEquals(Vf, 0.f)) { - if (HalfFloat.isRepresentable(W)) { - out.writeByte(W_ONLY_HALF_FLOAT_ENTRY); - out.writeShort(HalfFloat.floatToHalfFloat(W)); - } else { - out.writeByte(W_ONLY_FLOAT_ENTRY); - out.writeFloat(W); - } - } else if (isRepresentableAsHalfFloat(W, Vf)) { - out.writeByte(HALF_FLOAT_ENTRY); - out.writeShort(HalfFloat.floatToHalfFloat(W)); - for (int i = 0; i < factors; i++) { - out.writeShort(HalfFloat.floatToHalfFloat(Vf[i])); - } - } else { - out.writeByte(FLOAT_ENTRY); - out.writeFloat(W); - IOUtils.writeFloats(Vf, factors, out); - } - } - - private static boolean isRepresentableAsHalfFloat(final float W, @Nonnull final float[] Vf) { - if (!HalfFloat.isRepresentable(W)) { - return false; - } - for (float V : Vf) { - if (!HalfFloat.isRepresentable(V)) { - return false; - } - } - return true; - } - - @Nonnull - static void writeStates(@Nonnull final byte[] status, @Nonnull final DataOutput out) - throws IOException { - // write empty states's indexes differentially - final int size = status.length; - int cardinarity = 0; - for (int i = 0; i < size; i++) { - if (status[i] != IntOpenHashTable.FULL) { - cardinarity++; - } - } - out.writeInt(cardinarity); - if (cardinarity == 0) { - return; - } - int prev = 0; - for (int i = 0; i < size; i++) { 
- if (status[i] != IntOpenHashTable.FULL) { - int diff = i - prev; - assert (diff >= 0); - VariableByteCodec.encodeUnsignedInt(diff, out); - prev = i; - } - } - } - - @Override - public void readExternal(@Nonnull final ObjectInput in) throws IOException, - ClassNotFoundException { - this._w0 = in.readDouble(); - final int factors = in.readInt(); - this._factors = factors; - this._numFeatures = in.readInt(); - this._numFields = in.readInt(); - - final int used = in.readInt(); - final int size = in.readInt(); - final int[] keys = new int[size]; - final long[] values = new long[size]; - final byte[] states = new byte[size]; - readStates(in, states); - - final int entrySize = Entry.sizeOf(factors); - int numChunks = (entrySize * used) / HeapBuffer.DEFAULT_CHUNK_BYTES + 1; - final HeapBuffer buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE, numChunks); - final Entry e = new Entry(buf, factors); - final float[] Vf = new float[factors]; - for (int i = 0; i < size; i++) { - if (states[i] != IntOpenHashTable.FULL) { - continue; - } - keys[i] = ZigZagLEB128Codec.readSignedInt(in); - long ptr = buf.allocate(entrySize); - e.setOffset(ptr); - readEntry(in, factors, Vf, e); - values[i] = ptr; - } - - this._map = new Int2LongOpenHashTable(keys, values, states, used); - this._buf = buf; - } - - @Nonnull - private static void readEntry(@Nonnull final DataInput in, final int factors, - @Nonnull final float[] Vf, @Nonnull Entry dst) throws IOException { - final byte type = in.readByte(); - switch (type) { - case HALF_FLOAT_ENTRY: { - float W = HalfFloat.halfFloatToFloat(in.readShort()); - dst.setW(W); - for (int i = 0; i < factors; i++) { - Vf[i] = HalfFloat.halfFloatToFloat(in.readShort()); - } - dst.setV(Vf); - break; - } - case W_ONLY_HALF_FLOAT_ENTRY: { - float W = HalfFloat.halfFloatToFloat(in.readShort()); - dst.setW(W); - break; - } - case FLOAT_ENTRY: { - float W = in.readFloat(); - dst.setW(W); - IOUtils.readFloats(in, Vf); - dst.setV(Vf); - break; - } - case 
W_ONLY_FLOAT_ENTRY: { - float W = in.readFloat(); - dst.setW(W); - break; - } - default: - throw new IOException("Unexpected Entry type: " + type); - } - } - - @Nonnull - static void readStates(@Nonnull final DataInput in, @Nonnull final byte[] status) - throws IOException { - // read non-empty states differentially - final int cardinarity = in.readInt(); - Arrays.fill(status, IntOpenHashTable.FULL); - int prev = 0; - for (int j = 0; j < cardinarity; j++) { - int i = VariableByteCodec.decodeUnsignedInt(in) + prev; - status[i] = IntOpenHashTable.FREE; - prev = i; - } - } - - public byte[] serialize() throws IOException { - LOG.info("FFMPredictionModel#serialize(): " + _buf.toString()); - return ObjectUtils.toCompressedBytes(this, CompressionAlgorithm.lzma2, true); - } - - public static FFMPredictionModel deserialize(@Nonnull final byte[] serializedObj, final int len) - throws ClassNotFoundException, IOException { - FFMPredictionModel model = new FFMPredictionModel(); - ObjectUtils.readCompressedObject(serializedObj, len, model, CompressionAlgorithm.lzma2, - true); - LOG.info("FFMPredictionModel#deserialize(): " + model._buf.toString()); - return model; - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java index 4f445fa..2264063 100644 --- a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java +++ b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java @@ -23,6 +23,7 @@ import hivemall.fm.Entry.FTRLEntry; import hivemall.fm.FMHyperParameters.FFMHyperParameters; import hivemall.utils.buffer.HeapBuffer; import hivemall.utils.collections.maps.Int2LongOpenHashTable; +import hivemall.utils.collections.maps.Int2LongOpenHashTable.IMapIterator; import 
hivemall.utils.lang.NumberUtils; import hivemall.utils.math.MathUtils; @@ -39,7 +40,7 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi private final HeapBuffer _buf; // hyperparams - private final int _numFeatures; + // private final int _numFeatures; private final int _numFields; // FTEL @@ -55,7 +56,7 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi this._w0 = 0.f; this._map = new Int2LongOpenHashTable(DEFAULT_MAPSIZE); this._buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE); - this._numFeatures = params.numFeatures; + // this._numFeatures = params.numFeatures; this._numFields = params.numFields; this._alpha = params.alphaFTRL; this._beta = params.betaFTRL; @@ -64,11 +65,6 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi this._entrySize = entrySize(_factor, _useFTRL, _useAdaGrad); } - @Nonnull - FFMPredictionModel toPredictionModel() { - return new FFMPredictionModel(_map, _buf, _w0, _factor, _numFeatures, _numFields); - } - @Override public int getSize() { return _map.size(); @@ -271,6 +267,49 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi } } + @Nonnull + EntryIterator entries() { + return new EntryIterator(this); + } + + static final class EntryIterator { + + @Nonnull + private final IMapIterator dictItor; + @Nonnull + private final Entry entryProbe; + + EntryIterator(@Nonnull FFMStringFeatureMapModel model) { + this.dictItor = model._map.entries(); + this.entryProbe = new Entry(model._buf, model._factor); + } + + @Nonnull + Entry getEntryProbe() { + return entryProbe; + } + + boolean hasNext() { + return dictItor.hasNext(); + } + + int next() { + return dictItor.next(); + } + + int getEntryIndex() { + return dictItor.getKey(); + } + + @Nonnull + void getEntry(@Nonnull final Entry probe) { + long offset = dictItor.getValue(); + probe.setOffset(offset); + } + + } + + private static int entrySize(int factors, boolean 
ftrl, boolean adagrad) { if (ftrl) { return FTRLEntry.sizeOf(factors); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java index 67dbf87..d1c9e73 100644 --- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java +++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java @@ -18,25 +18,23 @@ */ package hivemall.fm; +import hivemall.fm.FFMStringFeatureMapModel.EntryIterator; import hivemall.fm.FMHyperParameters.FFMHyperParameters; import hivemall.utils.collections.arrays.DoubleArray3D; import hivemall.utils.collections.lists.IntArrayList; import hivemall.utils.hadoop.HadoopUtils; -import hivemall.utils.hadoop.Text3; -import hivemall.utils.lang.NumberUtils; +import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.math.MathUtils; -import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Arrays; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; @@ -44,6 +42,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.FloatWritable; +import 
org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; /** @@ -56,7 +56,6 @@ import org.apache.hadoop.io.Text; name = "train_ffm", value = "_FUNC_(array<string> x, double y [, const string options]) - Returns a prediction model") public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachineUDTF { - private static final Log LOG = LogFactory.getLog(FieldAwareFactorizationMachineUDTF.class); // ---------------------------------------- // Learning hyper-parameters/options @@ -150,8 +149,14 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi fieldNames.add("model_id"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); - fieldNames.add("model"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + fieldNames.add("i"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + + fieldNames.add("Wi"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + fieldNames.add("Vi"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector)); return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } @@ -267,39 +272,46 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi this._fieldList = null; this._sumVfX = null; - Text modelId = new Text(); - String taskId = HadoopUtils.getUniqueTaskIdString(); - modelId.set(taskId); + final int factors = _factors; + final IntWritable idx = new IntWritable(); + final FloatWritable Wi = new FloatWritable(0.f); + final FloatWritable[] Vi = HiveUtils.newFloatArray(factors, 0.f); + + final Object[] forwardObjs = new Object[4]; + String modelId = HadoopUtils.getUniqueTaskIdString(); + forwardObjs[0] = new Text(modelId); + forwardObjs[1] = idx; + forwardObjs[2] = Wi; + forwardObjs[3] = null; // Vi + + // W0 + idx.set(0); + Wi.set(_ffmModel.getW0()); + 
forward(forwardObjs); - FFMPredictionModel predModel = _ffmModel.toPredictionModel(); - this._ffmModel = null; // help GC + forwardObjs[3] = Arrays.asList(Vi); - if (LOG.isInfoEnabled()) { - LOG.info("Serializing a model '" + modelId + "'... Configured # features: " - + _numFeatures + ", Configured # fields: " + _numFields - + ", Actual # features: " + predModel.getActualNumFeatures() - + ", Estimated uncompressed bytes: " - + NumberUtils.prettySize(predModel.approxBytesConsumed())); - } + final EntryIterator itor = _ffmModel.entries(); + final Entry entry = itor.getEntryProbe(); + final float[] Vf = new float[factors]; + while (itor.next() != -1) { + // set i + int i = itor.getEntryIndex(); + idx.set(i); - byte[] serialized; - try { - serialized = predModel.serialize(); - predModel = null; - } catch (IOException e) { - throw new HiveException("Failed to serialize a model", e); - } + itor.getEntry(entry); - if (LOG.isInfoEnabled()) { - LOG.info("Forwarding a serialized/compressed model '" + modelId + "' of size: " - + NumberUtils.prettySize(serialized.length)); - } + // set Wi + Wi.set(entry.getW()); - Text modelObj = new Text3(serialized); - serialized = null; - Object[] forwardObjs = new Object[] {modelId, modelObj}; + // set Vif + entry.getV(Vf); + for (int f = 0; f < factors; f++) { + Vi[f].set(Vf[f]); + } - forward(forwardObjs); + forward(forwardObjs); + } } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java b/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java deleted file mode 100644 index 076387f..0000000 --- a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.fm; - -import hivemall.utils.buffer.HeapBuffer; -import hivemall.utils.collections.maps.Int2LongOpenHashTable; - -import java.io.IOException; - -import org.junit.Assert; -import org.junit.Test; - -public class FFMPredictionModelTest { - - @Test - public void testSerialize() throws IOException, ClassNotFoundException { - final int factors = 3; - final int entrySize = Entry.sizeOf(factors); - - HeapBuffer buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE); - Int2LongOpenHashTable map = Int2LongOpenHashTable.newInstance(); - - Entry e1 = new Entry(buf, factors, buf.allocate(entrySize)); - e1.setW(1f); - e1.setV(new float[] {1f, -1f, -1f}); - - Entry e2 = new Entry(buf, factors, buf.allocate(entrySize)); - e2.setW(2f); - e2.setV(new float[] {1f, 2f, -1f}); - - Entry e3 = new Entry(buf, factors, buf.allocate(entrySize)); - e3.setW(3f); - e3.setV(new float[] {1f, 2f, 3f}); - - map.put(1, e1.getOffset()); - map.put(2, e2.getOffset()); - map.put(3, e3.getOffset()); - - FFMPredictionModel expected = new FFMPredictionModel(map, buf, 0.d, 3, - Feature.DEFAULT_NUM_FEATURES, Feature.DEFAULT_NUM_FIELDS); - byte[] b = expected.serialize(); - - FFMPredictionModel actual = FFMPredictionModel.deserialize(b, b.length); - Assert.assertEquals(3, actual.getNumFactors()); - 
Assert.assertEquals(Feature.DEFAULT_NUM_FEATURES, actual.getNumFeatures()); - Assert.assertEquals(Feature.DEFAULT_NUM_FIELDS, actual.getNumFields()); - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/resources/ddl/define-all-as-permanent.hive ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive index feb1a08..f3065c0 100644 --- a/resources/ddl/define-all-as-permanent.hive +++ b/resources/ddl/define-all-as-permanent.hive @@ -620,7 +620,7 @@ DROP FUNCTION IF EXISTS train_ffm; CREATE FUNCTION train_ffm as 'hivemall.fm.FieldAwareFactorizationMachineUDTF' USING JAR '${hivemall_jar}'; DROP FUNCTION IF EXISTS ffm_predict; -CREATE FUNCTION ffm_predict as 'hivemall.fm.FFMPredictUDF' USING JAR '${hivemall_jar}'; +CREATE FUNCTION ffm_predict as 'hivemall.fm.FFMPredictGenericUDAF' USING JAR '${hivemall_jar}'; --------------------------- -- Anomaly Detection ------ http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/resources/ddl/define-all.hive ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive index 310f9f4..305afc9 100644 --- a/resources/ddl/define-all.hive +++ b/resources/ddl/define-all.hive @@ -612,7 +612,7 @@ drop temporary function if exists train_ffm; create temporary function train_ffm as 'hivemall.fm.FieldAwareFactorizationMachineUDTF'; drop temporary function if exists ffm_predict; -create temporary function ffm_predict as 'hivemall.fm.FFMPredictUDF'; +create temporary function ffm_predict as 'hivemall.fm.FFMPredictGenericUDAF'; --------------------------- -- Anomaly Detection ------ http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/resources/ddl/define-all.spark ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all.spark 
b/resources/ddl/define-all.spark index 42b235b..4e18b8f 100644 --- a/resources/ddl/define-all.spark +++ b/resources/ddl/define-all.spark @@ -596,7 +596,7 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_ffm") sqlContext.sql("CREATE TEMPORARY FUNCTION train_ffm AS 'hivemall.fm.FieldAwareFactorizationMachineUDTF'") sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS ffm_predict") -sqlContext.sql("CREATE TEMPORARY FUNCTION ffm_predict AS 'hivemall.fm.FFMPredictUDF'") +sqlContext.sql("CREATE TEMPORARY FUNCTION ffm_predict AS 'hivemall.fm.FFMPredictGenericUDAF'") /** * Anomaly Detection http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/resources/ddl/define-udfs.td.hql ---------------------------------------------------------------------- diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql index dd694e3..bc5e3db 100644 --- a/resources/ddl/define-udfs.td.hql +++ b/resources/ddl/define-udfs.td.hql @@ -174,6 +174,8 @@ create temporary function dimsum_mapper as 'hivemall.knn.similarity.DIMSUMMapper create temporary function train_classifier as 'hivemall.classifier.GeneralClassifierUDTF'; create temporary function train_regressor as 'hivemall.regression.GeneralRegressorUDTF'; create temporary function tree_export as 'hivemall.smile.tools.TreeExportUDF'; +create temporary function train_ffm as 'hivemall.fm.FieldAwareFactorizationMachineUDTF'; +create temporary function ffm_predict as 'hivemall.fm.FFMPredictGenericUDAF'; -- NLP features create temporary function tokenize_ja as 'hivemall.nlp.tokenizer.KuromojiUDF';
