Changed the FFM prediction model to a scalable format

Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/550bb4e6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/550bb4e6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/550bb4e6

Branch: refs/heads/HIVEMALL-24-2
Commit: 550bb4e6f0f69cfd25506c34caf67e3c014d5750
Parents: 36a6ca2
Author: Makoto Yui <[email protected]>
Authored: Tue Jul 25 20:03:29 2017 +0900
Committer: Makoto Yui <[email protected]>
Committed: Tue Jul 25 20:03:29 2017 +0900

----------------------------------------------------------------------
 core/src/main/java/hivemall/fm/Entry.java       |   2 +-
 .../java/hivemall/fm/FFMPredictGenericUDAF.java |  21 +-
 .../main/java/hivemall/fm/FFMPredictUDF.java    | 187 ----------
 .../java/hivemall/fm/FFMPredictionModel.java    | 349 -------------------
 .../hivemall/fm/FFMStringFeatureMapModel.java   |  53 ++-
 .../fm/FieldAwareFactorizationMachineUDTF.java  |  82 +++--
 .../hivemall/fm/FFMPredictionModelTest.java     |  65 ----
 resources/ddl/define-all-as-permanent.hive      |   2 +-
 resources/ddl/define-all.hive                   |   2 +-
 resources/ddl/define-all.spark                  |   2 +-
 resources/ddl/define-udfs.td.hql                |   2 +
 11 files changed, 111 insertions(+), 656 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/Entry.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/Entry.java b/core/src/main/java/hivemall/fm/Entry.java
index 1882f85..209112c 100644
--- a/core/src/main/java/hivemall/fm/Entry.java
+++ b/core/src/main/java/hivemall/fm/Entry.java
@@ -58,7 +58,7 @@ class Entry {
         return _offset;
     }
 
-    void setOffset(long offset) {
+    void setOffset(final long offset) {
         this._offset = offset;
     }
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
index 91d1b6b..a37a1b8 100644
--- a/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
+++ b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java
@@ -19,7 +19,12 @@
 package hivemall.fm;
 
 import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.hadoop.WritableUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import javax.annotation.Nonnull;
+
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
@@ -30,20 +35,18 @@ import 
org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
 import 
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
-import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
-import org.apache.hadoop.hive.serde2.objectinspector.*;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
-import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector;
 import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-import java.util.ArrayList;
-import java.util.List;
-
 @Description(
         name = "ffm_predict",
         value = "_FUNC_(Float Wi, Float Wj, array<float> Vifj, array<float> 
Vjfi, float Xi, float Xj)"

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/FFMPredictUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMPredictUDF.java b/core/src/main/java/hivemall/fm/FFMPredictUDF.java
deleted file mode 100644
index 48745d9..0000000
--- a/core/src/main/java/hivemall/fm/FFMPredictUDF.java
+++ /dev/null
@@ -1,187 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.fm;
-
-import hivemall.annotations.Experimental;
-import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.lang.NumberUtils;
-
-import java.io.IOException;
-import java.util.Arrays;
-
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.io.DoubleWritable;
-import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
-import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
-import org.apache.hadoop.io.Text;
-
-/**
- * @since v0.5-rc.1
- */
-@Description(name = "ffm_predict",
-        value = "_FUNC_(string modelId, string model, array<string> features)"
-                + " returns a prediction result in double from a Field-aware 
Factorization Machine")
-@UDFType(deterministic = true, stateful = false)
-@Experimental
-public final class FFMPredictUDF extends GenericUDF {
-
-    private StringObjectInspector _modelIdOI;
-    private StringObjectInspector _modelOI;
-    private ListObjectInspector _featureListOI;
-
-    private DoubleWritable _result;
-    @Nullable
-    private String _cachedModeId;
-    @Nullable
-    private FFMPredictionModel _cachedModel;
-    @Nullable
-    private Feature[] _probes;
-
-    public FFMPredictUDF() {}
-
-    @Override
-    public ObjectInspector initialize(ObjectInspector[] argOIs) throws 
UDFArgumentException {
-        if (argOIs.length != 3) {
-            throw new UDFArgumentException("_FUNC_ takes 3 arguments");
-        }
-        this._modelIdOI = HiveUtils.asStringOI(argOIs[0]);
-        this._modelOI = HiveUtils.asStringOI(argOIs[1]);
-        this._featureListOI = HiveUtils.asListOI(argOIs[2]);
-
-        this._result = new DoubleWritable();
-        return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
-    }
-
-    @Override
-    public Object evaluate(DeferredObject[] args) throws HiveException {
-        String modelId = _modelIdOI.getPrimitiveJavaObject(args[0].get());
-        if (modelId == null) {
-            throw new HiveException("modelId is not set");
-        }
-
-        final FFMPredictionModel model;
-        if (modelId.equals(_cachedModeId)) {
-            model = this._cachedModel;
-        } else {
-            Text serModel = _modelOI.getPrimitiveWritableObject(args[1].get());
-            if (serModel == null) {
-                throw new HiveException("Model is null for model ID: " + 
modelId);
-            }
-            byte[] b = serModel.getBytes();
-            final int length = serModel.getLength();
-            try {
-                model = FFMPredictionModel.deserialize(b, length);
-                b = null;
-            } catch (ClassNotFoundException e) {
-                throw new HiveException(e);
-            } catch (IOException e) {
-                throw new HiveException(e);
-            }
-            this._cachedModeId = modelId;
-            this._cachedModel = model;
-        }
-
-        int numFeatures = model.getNumFeatures();
-        int numFields = model.getNumFields();
-
-        Object arg2 = args[2].get();
-        // [workaround]
-        // java.lang.ClassCastException: 
org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray
-        // cannot be cast to [Ljava.lang.Object;
-        if (arg2 instanceof LazyBinaryArray) {
-            arg2 = ((LazyBinaryArray) arg2).getList();
-        }
-        Feature[] x = Feature.parseFFMFeatures(arg2, _featureListOI, _probes, 
numFeatures,
-            numFields);
-        if (x == null || x.length == 0) {
-            return null; // return NULL if there are no features
-        }
-        this._probes = x;
-
-        double predicted = predict(x, model);
-        _result.set(predicted);
-        return _result;
-    }
-
-    private static double predict(@Nonnull final Feature[] x,
-            @Nonnull final FFMPredictionModel model) throws HiveException {
-        // w0
-        double ret = model.getW0();
-        // W
-        for (Feature e : x) {
-            double xi = e.getValue();
-            float wi = model.getW(e);
-            double wx = wi * xi;
-            ret += wx;
-        }
-        // V        
-        final int factors = model.getNumFactors();
-        final float[] vij = new float[factors];
-        final float[] vji = new float[factors];
-        for (int i = 0; i < x.length; ++i) {
-            final Feature ei = x[i];
-            final double xi = ei.getValue();
-            final int iField = ei.getField();
-            for (int j = i + 1; j < x.length; ++j) {
-                final Feature ej = x[j];
-                final double xj = ej.getValue();
-                final int jField = ej.getField();
-                if (!model.getV(ei, jField, vij)) {
-                    continue;
-                }
-                if (!model.getV(ej, iField, vji)) {
-                    continue;
-                }
-                for (int f = 0; f < factors; f++) {
-                    float vijf = vij[f];
-                    float vjif = vji[f];
-                    ret += vijf * vjif * xi * xj;
-                }
-            }
-        }
-        if (!NumberUtils.isFinite(ret)) {
-            throw new HiveException("Detected " + ret + " in ffm_predict");
-        }
-        return ret;
-    }
-
-    @Override
-    public void close() throws IOException {
-        super.close();
-        // clean up to help GC
-        this._cachedModel = null;
-        this._probes = null;
-    }
-
-    @Override
-    public String getDisplayString(String[] args) {
-        return "ffm_predict(" + Arrays.toString(args) + ")";
-    }
-
-}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/FFMPredictionModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMPredictionModel.java b/core/src/main/java/hivemall/fm/FFMPredictionModel.java
deleted file mode 100644
index befbec9..0000000
--- a/core/src/main/java/hivemall/fm/FFMPredictionModel.java
+++ /dev/null
@@ -1,349 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.fm;
-
-import hivemall.utils.buffer.HeapBuffer;
-import hivemall.utils.codec.VariableByteCodec;
-import hivemall.utils.codec.ZigZagLEB128Codec;
-import hivemall.utils.collections.maps.Int2LongOpenHashTable;
-import hivemall.utils.collections.maps.IntOpenHashTable;
-import hivemall.utils.io.CompressionStreamFactory.CompressionAlgorithm;
-import hivemall.utils.io.IOUtils;
-import hivemall.utils.lang.ArrayUtils;
-import hivemall.utils.lang.HalfFloat;
-import hivemall.utils.lang.ObjectUtils;
-
-import java.io.DataInput;
-import java.io.DataOutput;
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.Arrays;
-
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-
-public final class FFMPredictionModel implements Externalizable {
-    private static final Log LOG = LogFactory.getLog(FFMPredictionModel.class);
-
-    private static final byte HALF_FLOAT_ENTRY = 1;
-    private static final byte W_ONLY_HALF_FLOAT_ENTRY = 2;
-    private static final byte FLOAT_ENTRY = 3;
-    private static final byte W_ONLY_FLOAT_ENTRY = 4;
-
-    /**
-     * maps feature to feature weight pointer
-     */
-    private Int2LongOpenHashTable _map;
-    private HeapBuffer _buf;
-
-    private double _w0;
-    private int _factors;
-    private int _numFeatures;
-    private int _numFields;
-
-    public FFMPredictionModel() {}// for Externalizable
-
-    public FFMPredictionModel(@Nonnull Int2LongOpenHashTable map, @Nonnull 
HeapBuffer buf,
-            double w0, int factor, int numFeatures, int numFields) {
-        this._map = map;
-        this._buf = buf;
-        this._w0 = w0;
-        this._factors = factor;
-        this._numFeatures = numFeatures;
-        this._numFields = numFields;
-    }
-
-    public int getNumFactors() {
-        return _factors;
-    }
-
-    public double getW0() {
-        return _w0;
-    }
-
-    public int getNumFeatures() {
-        return _numFeatures;
-    }
-
-    public int getNumFields() {
-        return _numFields;
-    }
-
-    public int getActualNumFeatures() {
-        return _map.size();
-    }
-
-    public long approxBytesConsumed() {
-        int size = _map.size();
-
-        // [map] size * (|state| + |key| + |entry|) 
-        long bytes = size * (1L + 4L + 4L + (4L * _factors));
-        int rest = _map.capacity() - size;
-        if (rest > 0) {
-            bytes += rest * 1L;
-        }
-        // w0, factors, numFeatures, numFields, used, size
-        bytes += (8 + 4 + 4 + 4 + 4 + 4);
-        return bytes;
-    }
-
-    @Nullable
-    private Entry getEntry(final int key) {
-        final long ptr = _map.get(key);
-        if (ptr == -1L) {
-            return null;
-        }
-        return new Entry(_buf, _factors, ptr);
-    }
-
-    public float getW(@Nonnull final Feature x) {
-        int j = x.getFeatureIndex();
-
-        Entry entry = getEntry(j);
-        if (entry == null) {
-            return 0.f;
-        }
-        return entry.getW();
-    }
-
-    /**
-     * @return true if V exists
-     */
-    public boolean getV(@Nonnull final Feature x, @Nonnull final int yField, 
@Nonnull float[] dst) {
-        int j = Feature.toIntFeature(x, yField, _numFields);
-
-        Entry entry = getEntry(j);
-        if (entry == null) {
-            return false;
-        }
-
-        entry.getV(dst);
-        if (ArrayUtils.equals(dst, 0.f)) {
-            return false; // treat as null
-        }
-        return true;
-    }
-
-    @Override
-    public void writeExternal(@Nonnull ObjectOutput out) throws IOException {
-        out.writeDouble(_w0);
-        final int factors = _factors;
-        out.writeInt(factors);
-        out.writeInt(_numFeatures);
-        out.writeInt(_numFields);
-
-        int used = _map.size();
-        out.writeInt(used);
-
-        final int[] keys = _map.getKeys();
-        final int size = keys.length;
-        out.writeInt(size);
-
-        final byte[] states = _map.getStates();
-        writeStates(states, out);
-
-        final long[] values = _map.getValues();
-
-        final HeapBuffer buf = _buf;
-        final Entry e = new Entry(buf, factors);
-        final float[] Vf = new float[factors];
-        for (int i = 0; i < size; i++) {
-            if (states[i] != IntOpenHashTable.FULL) {
-                continue;
-            }
-            ZigZagLEB128Codec.writeSignedInt(keys[i], out);
-            e.setOffset(values[i]);
-            writeEntry(e, factors, Vf, out);
-        }
-
-        // help GC
-        this._map = null;
-        this._buf = null;
-    }
-
-    private static void writeEntry(@Nonnull final Entry e, final int factors,
-            @Nonnull final float[] Vf, @Nonnull final DataOutput out) throws 
IOException {
-        final float W = e.getW();
-        e.getV(Vf);
-
-        if (ArrayUtils.almostEquals(Vf, 0.f)) {
-            if (HalfFloat.isRepresentable(W)) {
-                out.writeByte(W_ONLY_HALF_FLOAT_ENTRY);
-                out.writeShort(HalfFloat.floatToHalfFloat(W));
-            } else {
-                out.writeByte(W_ONLY_FLOAT_ENTRY);
-                out.writeFloat(W);
-            }
-        } else if (isRepresentableAsHalfFloat(W, Vf)) {
-            out.writeByte(HALF_FLOAT_ENTRY);
-            out.writeShort(HalfFloat.floatToHalfFloat(W));
-            for (int i = 0; i < factors; i++) {
-                out.writeShort(HalfFloat.floatToHalfFloat(Vf[i]));
-            }
-        } else {
-            out.writeByte(FLOAT_ENTRY);
-            out.writeFloat(W);
-            IOUtils.writeFloats(Vf, factors, out);
-        }
-    }
-
-    private static boolean isRepresentableAsHalfFloat(final float W, @Nonnull 
final float[] Vf) {
-        if (!HalfFloat.isRepresentable(W)) {
-            return false;
-        }
-        for (float V : Vf) {
-            if (!HalfFloat.isRepresentable(V)) {
-                return false;
-            }
-        }
-        return true;
-    }
-
-    @Nonnull
-    static void writeStates(@Nonnull final byte[] status, @Nonnull final 
DataOutput out)
-            throws IOException {
-        // write empty states's indexes differentially
-        final int size = status.length;
-        int cardinarity = 0;
-        for (int i = 0; i < size; i++) {
-            if (status[i] != IntOpenHashTable.FULL) {
-                cardinarity++;
-            }
-        }
-        out.writeInt(cardinarity);
-        if (cardinarity == 0) {
-            return;
-        }
-        int prev = 0;
-        for (int i = 0; i < size; i++) {
-            if (status[i] != IntOpenHashTable.FULL) {
-                int diff = i - prev;
-                assert (diff >= 0);
-                VariableByteCodec.encodeUnsignedInt(diff, out);
-                prev = i;
-            }
-        }
-    }
-
-    @Override
-    public void readExternal(@Nonnull final ObjectInput in) throws IOException,
-            ClassNotFoundException {
-        this._w0 = in.readDouble();
-        final int factors = in.readInt();
-        this._factors = factors;
-        this._numFeatures = in.readInt();
-        this._numFields = in.readInt();
-
-        final int used = in.readInt();
-        final int size = in.readInt();
-        final int[] keys = new int[size];
-        final long[] values = new long[size];
-        final byte[] states = new byte[size];
-        readStates(in, states);
-
-        final int entrySize = Entry.sizeOf(factors);
-        int numChunks = (entrySize * used) / HeapBuffer.DEFAULT_CHUNK_BYTES + 
1;
-        final HeapBuffer buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE, 
numChunks);
-        final Entry e = new Entry(buf, factors);
-        final float[] Vf = new float[factors];
-        for (int i = 0; i < size; i++) {
-            if (states[i] != IntOpenHashTable.FULL) {
-                continue;
-            }
-            keys[i] = ZigZagLEB128Codec.readSignedInt(in);
-            long ptr = buf.allocate(entrySize);
-            e.setOffset(ptr);
-            readEntry(in, factors, Vf, e);
-            values[i] = ptr;
-        }
-
-        this._map = new Int2LongOpenHashTable(keys, values, states, used);
-        this._buf = buf;
-    }
-
-    @Nonnull
-    private static void readEntry(@Nonnull final DataInput in, final int 
factors,
-            @Nonnull final float[] Vf, @Nonnull Entry dst) throws IOException {
-        final byte type = in.readByte();
-        switch (type) {
-            case HALF_FLOAT_ENTRY: {
-                float W = HalfFloat.halfFloatToFloat(in.readShort());
-                dst.setW(W);
-                for (int i = 0; i < factors; i++) {
-                    Vf[i] = HalfFloat.halfFloatToFloat(in.readShort());
-                }
-                dst.setV(Vf);
-                break;
-            }
-            case W_ONLY_HALF_FLOAT_ENTRY: {
-                float W = HalfFloat.halfFloatToFloat(in.readShort());
-                dst.setW(W);
-                break;
-            }
-            case FLOAT_ENTRY: {
-                float W = in.readFloat();
-                dst.setW(W);
-                IOUtils.readFloats(in, Vf);
-                dst.setV(Vf);
-                break;
-            }
-            case W_ONLY_FLOAT_ENTRY: {
-                float W = in.readFloat();
-                dst.setW(W);
-                break;
-            }
-            default:
-                throw new IOException("Unexpected Entry type: " + type);
-        }
-    }
-
-    @Nonnull
-    static void readStates(@Nonnull final DataInput in, @Nonnull final byte[] 
status)
-            throws IOException {
-        // read non-empty states differentially
-        final int cardinarity = in.readInt();
-        Arrays.fill(status, IntOpenHashTable.FULL);
-        int prev = 0;
-        for (int j = 0; j < cardinarity; j++) {
-            int i = VariableByteCodec.decodeUnsignedInt(in) + prev;
-            status[i] = IntOpenHashTable.FREE;
-            prev = i;
-        }
-    }
-
-    public byte[] serialize() throws IOException {
-        LOG.info("FFMPredictionModel#serialize(): " + _buf.toString());
-        return ObjectUtils.toCompressedBytes(this, CompressionAlgorithm.lzma2, 
true);
-    }
-
-    public static FFMPredictionModel deserialize(@Nonnull final byte[] 
serializedObj, final int len)
-            throws ClassNotFoundException, IOException {
-        FFMPredictionModel model = new FFMPredictionModel();
-        ObjectUtils.readCompressedObject(serializedObj, len, model, 
CompressionAlgorithm.lzma2,
-            true);
-        LOG.info("FFMPredictionModel#deserialize(): " + model._buf.toString());
-        return model;
-    }
-
-}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
index 4f445fa..2264063 100644
--- a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
+++ b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java
@@ -23,6 +23,7 @@ import hivemall.fm.Entry.FTRLEntry;
 import hivemall.fm.FMHyperParameters.FFMHyperParameters;
 import hivemall.utils.buffer.HeapBuffer;
 import hivemall.utils.collections.maps.Int2LongOpenHashTable;
+import hivemall.utils.collections.maps.Int2LongOpenHashTable.IMapIterator;
 import hivemall.utils.lang.NumberUtils;
 import hivemall.utils.math.MathUtils;
 
@@ -39,7 +40,7 @@ public final class FFMStringFeatureMapModel extends 
FieldAwareFactorizationMachi
     private final HeapBuffer _buf;
 
     // hyperparams
-    private final int _numFeatures;
+    // private final int _numFeatures;
     private final int _numFields;
 
     // FTEL
@@ -55,7 +56,7 @@ public final class FFMStringFeatureMapModel extends 
FieldAwareFactorizationMachi
         this._w0 = 0.f;
         this._map = new Int2LongOpenHashTable(DEFAULT_MAPSIZE);
         this._buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE);
-        this._numFeatures = params.numFeatures;
+        // this._numFeatures = params.numFeatures;
         this._numFields = params.numFields;
         this._alpha = params.alphaFTRL;
         this._beta = params.betaFTRL;
@@ -64,11 +65,6 @@ public final class FFMStringFeatureMapModel extends 
FieldAwareFactorizationMachi
         this._entrySize = entrySize(_factor, _useFTRL, _useAdaGrad);
     }
 
-    @Nonnull
-    FFMPredictionModel toPredictionModel() {
-        return new FFMPredictionModel(_map, _buf, _w0, _factor, _numFeatures, 
_numFields);
-    }
-
     @Override
     public int getSize() {
         return _map.size();
@@ -271,6 +267,49 @@ public final class FFMStringFeatureMapModel extends 
FieldAwareFactorizationMachi
         }
     }
 
+    @Nonnull
+    EntryIterator entries() {
+        return new EntryIterator(this);
+    }
+
+    static final class EntryIterator {
+
+        @Nonnull
+        private final IMapIterator dictItor;
+        @Nonnull
+        private final Entry entryProbe;
+
+        EntryIterator(@Nonnull FFMStringFeatureMapModel model) {
+            this.dictItor = model._map.entries();
+            this.entryProbe = new Entry(model._buf, model._factor);
+        }
+
+        @Nonnull
+        Entry getEntryProbe() {
+            return entryProbe;
+        }
+
+        boolean hasNext() {
+            return dictItor.hasNext();
+        }
+
+        int next() {
+            return dictItor.next();
+        }
+
+        int getEntryIndex() {
+            return dictItor.getKey();
+        }
+
+        @Nonnull
+        void getEntry(@Nonnull final Entry probe) {
+            long offset = dictItor.getValue();
+            probe.setOffset(offset);
+        }
+
+    }
+
+
     private static int entrySize(int factors, boolean ftrl, boolean adagrad) {
         if (ftrl) {
             return FTRLEntry.sizeOf(factors);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
index 67dbf87..d1c9e73 100644
--- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
+++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
@@ -18,25 +18,23 @@
  */
 package hivemall.fm;
 
+import hivemall.fm.FFMStringFeatureMapModel.EntryIterator;
 import hivemall.fm.FMHyperParameters.FFMHyperParameters;
 import hivemall.utils.collections.arrays.DoubleArray3D;
 import hivemall.utils.collections.lists.IntArrayList;
 import hivemall.utils.hadoop.HadoopUtils;
-import hivemall.utils.hadoop.Text3;
-import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.hadoop.HiveUtils;
 import hivemall.utils.math.MathUtils;
 
-import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Arrays;
 
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
 import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.Options;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
@@ -44,6 +42,8 @@ import 
org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.Text;
 
 /**
@@ -56,7 +56,6 @@ import org.apache.hadoop.io.Text;
         name = "train_ffm",
         value = "_FUNC_(array<string> x, double y [, const string options]) - 
Returns a prediction model")
 public final class FieldAwareFactorizationMachineUDTF extends 
FactorizationMachineUDTF {
-    private static final Log LOG = 
LogFactory.getLog(FieldAwareFactorizationMachineUDTF.class);
 
     // ----------------------------------------
     // Learning hyper-parameters/options
@@ -150,8 +149,14 @@ public final class FieldAwareFactorizationMachineUDTF 
extends FactorizationMachi
         fieldNames.add("model_id");
         
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
 
-        fieldNames.add("model");
-        
fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+        fieldNames.add("i");
+        
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+
+        fieldNames.add("Wi");
+        
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+        fieldNames.add("Vi");
+        
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
 
         return 
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
     }
@@ -267,39 +272,46 @@ public final class FieldAwareFactorizationMachineUDTF 
extends FactorizationMachi
         this._fieldList = null;
         this._sumVfX = null;
 
-        Text modelId = new Text();
-        String taskId = HadoopUtils.getUniqueTaskIdString();
-        modelId.set(taskId);
+        final int factors = _factors;
+        final IntWritable idx = new IntWritable();
+        final FloatWritable Wi = new FloatWritable(0.f);
+        final FloatWritable[] Vi = HiveUtils.newFloatArray(factors, 0.f);
+
+        final Object[] forwardObjs = new Object[4];
+        String modelId = HadoopUtils.getUniqueTaskIdString();
+        forwardObjs[0] = new Text(modelId);
+        forwardObjs[1] = idx;
+        forwardObjs[2] = Wi;
+        forwardObjs[3] = null; // Vi
+
+        // W0
+        idx.set(0);
+        Wi.set(_ffmModel.getW0());
+        forward(forwardObjs);
 
-        FFMPredictionModel predModel = _ffmModel.toPredictionModel();
-        this._ffmModel = null; // help GC
+        forwardObjs[3] = Arrays.asList(Vi);
 
-        if (LOG.isInfoEnabled()) {
-            LOG.info("Serializing a model '" + modelId + "'... Configured # 
features: "
-                    + _numFeatures + ", Configured # fields: " + _numFields
-                    + ", Actual # features: " + 
predModel.getActualNumFeatures()
-                    + ", Estimated uncompressed bytes: "
-                    + NumberUtils.prettySize(predModel.approxBytesConsumed()));
-        }
+        final EntryIterator itor = _ffmModel.entries();
+        final Entry entry = itor.getEntryProbe();
+        final float[] Vf = new float[factors];
+        while (itor.next() != -1) {
+            // set i
+            int i = itor.getEntryIndex();
+            idx.set(i);
 
-        byte[] serialized;
-        try {
-            serialized = predModel.serialize();
-            predModel = null;
-        } catch (IOException e) {
-            throw new HiveException("Failed to serialize a model", e);
-        }
+            itor.getEntry(entry);
 
-        if (LOG.isInfoEnabled()) {
-            LOG.info("Forwarding a serialized/compressed model '" + modelId + 
"' of size: "
-                    + NumberUtils.prettySize(serialized.length));
-        }
+            // set Wi
+            Wi.set(entry.getW());
 
-        Text modelObj = new Text3(serialized);
-        serialized = null;
-        Object[] forwardObjs = new Object[] {modelId, modelObj};
+            // set Vif
+            entry.getV(Vf);
+            for (int f = 0; f < factors; f++) {
+                Vi[f].set(Vf[f]);
+            }
 
-        forward(forwardObjs);
+            forward(forwardObjs);
+        }
     }
 
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java 
b/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
deleted file mode 100644
index 076387f..0000000
--- a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.fm;
-
-import hivemall.utils.buffer.HeapBuffer;
-import hivemall.utils.collections.maps.Int2LongOpenHashTable;
-
-import java.io.IOException;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class FFMPredictionModelTest {
-
-    @Test
-    public void testSerialize() throws IOException, ClassNotFoundException {
-        final int factors = 3;
-        final int entrySize = Entry.sizeOf(factors);
-
-        HeapBuffer buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE);
-        Int2LongOpenHashTable map = Int2LongOpenHashTable.newInstance();
-
-        Entry e1 = new Entry(buf, factors, buf.allocate(entrySize));
-        e1.setW(1f);
-        e1.setV(new float[] {1f, -1f, -1f});
-
-        Entry e2 = new Entry(buf, factors, buf.allocate(entrySize));
-        e2.setW(2f);
-        e2.setV(new float[] {1f, 2f, -1f});
-
-        Entry e3 = new Entry(buf, factors, buf.allocate(entrySize));
-        e3.setW(3f);
-        e3.setV(new float[] {1f, 2f, 3f});
-
-        map.put(1, e1.getOffset());
-        map.put(2, e2.getOffset());
-        map.put(3, e3.getOffset());
-
-        FFMPredictionModel expected = new FFMPredictionModel(map, buf, 0.d, 3,
-            Feature.DEFAULT_NUM_FEATURES, Feature.DEFAULT_NUM_FIELDS);
-        byte[] b = expected.serialize();
-
-        FFMPredictionModel actual = FFMPredictionModel.deserialize(b, 
b.length);
-        Assert.assertEquals(3, actual.getNumFactors());
-        Assert.assertEquals(Feature.DEFAULT_NUM_FEATURES, 
actual.getNumFeatures());
-        Assert.assertEquals(Feature.DEFAULT_NUM_FIELDS, actual.getNumFields());
-    }
-
-}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all-as-permanent.hive 
b/resources/ddl/define-all-as-permanent.hive
index feb1a08..f3065c0 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -620,7 +620,7 @@ DROP FUNCTION IF EXISTS train_ffm;
 CREATE FUNCTION train_ffm as 'hivemall.fm.FieldAwareFactorizationMachineUDTF' 
USING JAR '${hivemall_jar}';
 
 DROP FUNCTION IF EXISTS ffm_predict;
-CREATE FUNCTION ffm_predict as 'hivemall.fm.FFMPredictUDF' USING JAR 
'${hivemall_jar}';
+CREATE FUNCTION ffm_predict as 'hivemall.fm.FFMPredictGenericUDAF' USING JAR 
'${hivemall_jar}';
 
 ---------------------------
 -- Anomaly Detection ------

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/resources/ddl/define-all.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 310f9f4..305afc9 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -612,7 +612,7 @@ drop temporary function if exists train_ffm;
 create temporary function train_ffm as 
'hivemall.fm.FieldAwareFactorizationMachineUDTF';
 
 drop temporary function if exists ffm_predict;
-create temporary function ffm_predict as 'hivemall.fm.FFMPredictUDF';
+create temporary function ffm_predict as 'hivemall.fm.FFMPredictGenericUDAF';
 
 ---------------------------
 -- Anomaly Detection ------

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/resources/ddl/define-all.spark
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index 42b235b..4e18b8f 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -596,7 +596,7 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS 
train_ffm")
 sqlContext.sql("CREATE TEMPORARY FUNCTION train_ffm AS 
'hivemall.fm.FieldAwareFactorizationMachineUDTF'")
 
 sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS ffm_predict")
-sqlContext.sql("CREATE TEMPORARY FUNCTION ffm_predict AS 
'hivemall.fm.FFMPredictUDF'")
+sqlContext.sql("CREATE TEMPORARY FUNCTION ffm_predict AS 
'hivemall.fm.FFMPredictGenericUDAF'")
 
 /**
  * Anomaly Detection

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/550bb4e6/resources/ddl/define-udfs.td.hql
----------------------------------------------------------------------
diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql
index dd694e3..bc5e3db 100644
--- a/resources/ddl/define-udfs.td.hql
+++ b/resources/ddl/define-udfs.td.hql
@@ -174,6 +174,8 @@ create temporary function dimsum_mapper as 
'hivemall.knn.similarity.DIMSUMMapper
 create temporary function train_classifier as 
'hivemall.classifier.GeneralClassifierUDTF';
 create temporary function train_regressor as 
'hivemall.regression.GeneralRegressorUDTF';
 create temporary function tree_export as 'hivemall.smile.tools.TreeExportUDF';
+create temporary function train_ffm as 
'hivemall.fm.FieldAwareFactorizationMachineUDTF';
+create temporary function ffm_predict as 'hivemall.fm.FFMPredictGenericUDAF';
 
 -- NLP features
 create temporary function tokenize_ja as 'hivemall.nlp.tokenizer.KuromojiUDF';


Reply via email to