Repository: incubator-hivemall
Updated Branches:
  refs/heads/master e8abae257 -> e9c6f3e04


Added an FFM unit test


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/e9c6f3e0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/e9c6f3e0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/e9c6f3e0

Branch: refs/heads/master
Commit: e9c6f3e04f6875dd5befc53438d88b8b51371c96
Parents: e8abae2
Author: Makoto Yui <[email protected]>
Authored: Wed Oct 18 21:15:28 2017 +0900
Committer: Makoto Yui <[email protected]>
Committed: Wed Oct 18 21:15:28 2017 +0900

----------------------------------------------------------------------
 .../hivemall/fm/FactorizationMachineUDTF.java   |  38 ++++---
 .../FieldAwareFactorizationMachineUDTFTest.java | 113 ++++++++++++++++---
 2 files changed, 123 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c6f3e0/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
index 24210a8..5c8af32 100644
--- a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
+++ b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
@@ -18,21 +18,6 @@
  */
 package hivemall.fm;
 
-import hivemall.UDTFWithOptions;
-import hivemall.common.ConversionState;
-import hivemall.fm.FMStringFeatureMapModel.Entry;
-import hivemall.optimizer.EtaEstimator;
-import hivemall.optimizer.LossFunctions;
-import hivemall.optimizer.LossFunctions.LossFunction;
-import hivemall.optimizer.LossFunctions.LossType;
-import hivemall.utils.collections.IMapIterator;
-import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.io.FileUtils;
-import hivemall.utils.io.NioStatefullSegment;
-import hivemall.utils.lang.NumberUtils;
-import hivemall.utils.lang.SizeOf;
-import hivemall.utils.math.MathUtils;
-
 import java.io.File;
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -63,6 +48,22 @@ import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.Counters.Counter;
 import org.apache.hadoop.mapred.Reporter;
 
+import hivemall.UDTFWithOptions;
+import hivemall.annotations.VisibleForTesting;
+import hivemall.common.ConversionState;
+import hivemall.fm.FMStringFeatureMapModel.Entry;
+import hivemall.optimizer.EtaEstimator;
+import hivemall.optimizer.LossFunctions;
+import hivemall.optimizer.LossFunctions.LossFunction;
+import hivemall.optimizer.LossFunctions.LossType;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.io.FileUtils;
+import hivemall.utils.io.NioStatefullSegment;
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.SizeOf;
+import hivemall.utils.math.MathUtils;
+
 @Description(
         name = "train_fm",
         value = "_FUNC_(array<string> x, double y [, const string options]) - 
Returns a prediction model")
@@ -436,6 +437,13 @@ public class FactorizationMachineUDTF extends 
UDTFWithOptions {
         this._model = null;
     }
 
+    @VisibleForTesting
+    void finalizeTraining() throws HiveException {
+        if (_iterations > 1) {
+            runTrainingIteration(_iterations);
+        }
+    }
+
     protected void forwardModel() throws HiveException {
         if (_parseFeatureAsInt) {
             forwardAsIntFeature(_model, _factors);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c6f3e0/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
index 3b219c6..5b54b1e 100644
--- a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
+++ b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
@@ -22,6 +22,7 @@ import java.io.BufferedReader;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
+import java.net.URL;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.zip.GZIPInputStream;
@@ -36,41 +37,116 @@ import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
 import org.junit.Assert;
 import org.junit.Test;
 
+import hivemall.utils.lang.NumberUtils;
+
 public class FieldAwareFactorizationMachineUDTFTest {
 
     private static final boolean DEBUG = false;
     private static final int ITERATIONS = 50;
     private static final int MAX_LINES = 200;
 
+    // ----------------------------------------------------
+    // bigdata.tr.txt
+
     @Test
     public void testSGD() throws HiveException, IOException {
-        runTest("Pure SGD test", "-opt sgd -classification -factors 10 -w0 
-seed 43", 0.60f);
+        runIterations("Pure SGD test", "bigdata.tr.txt.gz",
+            "-opt sgd -classification -factors 10 -w0 -seed 43", 0.60f);
     }
 
     @Test
     public void testAdaGrad() throws HiveException, IOException {
-        runTest("AdaGrad test", "-opt adagrad -classification -factors 10 -w0 
-seed 43", 0.30f);
+        runIterations("AdaGrad test", "bigdata.tr.txt.gz",
+            "-opt adagrad -classification -factors 10 -w0 -seed 43", 0.30f);
     }
 
     @Test
     public void testAdaGradNoCoeff() throws HiveException, IOException {
-        runTest("AdaGrad No Coeff test",
+        runIterations("AdaGrad No Coeff test", "bigdata.tr.txt.gz",
             "-opt adagrad -no_coeff -classification -factors 10 -w0 -seed 43", 
0.30f);
     }
 
     @Test
     public void testFTRL() throws HiveException, IOException {
-        runTest("FTRL test", "-opt ftrl -classification -factors 10 -w0 -seed 
43", 0.30f);
+        runIterations("FTRL test", "bigdata.tr.txt.gz",
+            "-opt ftrl -classification -factors 10 -w0 -seed 43", 0.30f);
     }
 
     @Test
     public void testFTRLNoCoeff() throws HiveException, IOException {
-        runTest("FTRL Coeff test", "-opt ftrl -no_coeff -classification 
-factors 10 -w0 -seed 43",
-            0.30f);
+        runIterations("FTRL Coeff test", "bigdata.tr.txt.gz",
+            "-opt ftrl -no_coeff -classification -factors 10 -w0 -seed 43", 
0.30f);
     }
 
-    private static void runTest(String testName, String testOptions, float 
lossThreshold)
-            throws IOException, HiveException {
+    // ----------------------------------------------------
+    // https://github.com/myui/ml_dataset/raw/master/ffm/sample.ffm.gz
+
+    @Test
+    public void testSample() throws IOException, HiveException {
+        run("[Sample.ffm] default option",
+            "https://github.com/myui/ml_dataset/raw/master/ffm/sample.ffm.gz",
+            "-classification -factors 2 -iters 10 -feature_hashing 20 -seed 
43", 0.1f);
+    }
+
+    private static void run(String testName, String testFile, String 
testOptions,
+            float lossThreshold) throws IOException, HiveException {
+        println(testName);
+
+        FieldAwareFactorizationMachineUDTF udtf = new 
FieldAwareFactorizationMachineUDTF();
+        ObjectInspector[] argOIs = new ObjectInspector[] {
+                
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector),
+                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, 
testOptions)};
+
+        udtf.initialize(argOIs);
+        FieldAwareFactorizationMachineModel model = 
udtf.initModel(udtf._params);
+        Assert.assertTrue("Actual class: " + model.getClass().getName(),
+            model instanceof FFMStringFeatureMapModel);
+
+
+        BufferedReader data = readFile(testFile);
+        while (true) {
+            //gather features in current line
+            final String input = data.readLine();
+            if (input == null) {
+                break;
+            }
+            String[] featureStrings = input.split(" ");
+
+            double y = Double.parseDouble(featureStrings[0]);
+            if (y == 0) {
+                y = -1;//LibFFM data uses {0, 1}; Hivemall uses {-1, 1}
+            }
+
+            final List<String> features = new 
ArrayList<String>(featureStrings.length - 1);
+            for (int j = 1; j < featureStrings.length; ++j) {
+                String fj = featureStrings[j];
+                String[] splitted = fj.split(":");
+                Assert.assertEquals(3, splitted.length);
+                String indexStr = splitted[1];
+                String f = fj;
+                if (NumberUtils.isDigits(indexStr)) {
+                    int index = Integer.parseInt(indexStr) + 1; // avoid 0 
index
+                    f = splitted[0] + ':' + index + ':' + splitted[2];
+                }
+                features.add(f);
+            }
+
+            udtf.process(new Object[] {features, y});
+        }
+        udtf.finalizeTraining();
+        data.close();
+
+        println("model size=" + udtf._model.getSize());
+
+        double avgLoss = udtf._cvState.getCumulativeLoss() / udtf._t;
+        Assert.assertTrue("Last loss was greater than expected: " + avgLoss,
+            avgLoss < lossThreshold);
+    }
+
+    private static void runIterations(String testName, String testFile, String 
testOptions,
+            float lossThreshold) throws IOException, HiveException {
         println(testName);
 
         FieldAwareFactorizationMachineUDTF udtf = new 
FieldAwareFactorizationMachineUDTF();
@@ -88,7 +164,7 @@ public class FieldAwareFactorizationMachineUDTFTest {
         double loss = 0.d;
         double cumul = 0.d;
         for (int trainingIteration = 1; trainingIteration <= ITERATIONS; 
++trainingIteration) {
-            BufferedReader data = readFile("bigdata.tr.txt.gz");
+            BufferedReader data = readFile(testFile);
             loss = udtf._cvState.getCumulativeLoss();
             int lines = 0;
             for (int lineNumber = 0; lineNumber < MAX_LINES; ++lineNumber, 
++lines) {
@@ -106,10 +182,15 @@ public class FieldAwareFactorizationMachineUDTFTest {
 
                 final List<String> features = new 
ArrayList<String>(featureStrings.length - 1);
                 for (int j = 1; j < featureStrings.length; ++j) {
-                    String[] splitted = featureStrings[j].split(":");
+                    String fj = featureStrings[j];
+                    String[] splitted = fj.split(":");
                     Assert.assertEquals(3, splitted.length);
-                    int index = Integer.parseInt(splitted[1]) + 1;
-                    String f = splitted[0] + ':' + index + ':' + splitted[2];
+                    String indexStr = splitted[1];
+                    String f = fj;
+                    if (NumberUtils.isDigits(indexStr)) {
+                        int index = Integer.parseInt(indexStr) + 1; // avoid 0 
index
+                        f = splitted[0] + ':' + index + ':' + splitted[2];
+                    }
                     features.add(f);
                 }
                 udtf.process(new Object[] {features, y});
@@ -125,7 +206,13 @@ public class FieldAwareFactorizationMachineUDTFTest {
 
     @Nonnull
     private static BufferedReader readFile(@Nonnull String fileName) throws 
IOException {
-        InputStream is = 
FieldAwareFactorizationMachineUDTFTest.class.getResourceAsStream(fileName);
+        InputStream is;
+        if (fileName.startsWith("http")) {
+            URL url = new URL(fileName);
+            is = url.openStream();
+        } else {
+            is = 
FieldAwareFactorizationMachineUDTFTest.class.getResourceAsStream(fileName);
+        }
         if (fileName.endsWith(".gz")) {
             is = new GZIPInputStream(is);
         }

Reply via email to