Repository: incubator-hivemall Updated Branches: refs/heads/master e8abae257 -> e9c6f3e04
Added a FFM unit test Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/e9c6f3e0 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/e9c6f3e0 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/e9c6f3e0 Branch: refs/heads/master Commit: e9c6f3e04f6875dd5befc53438d88b8b51371c96 Parents: e8abae2 Author: Makoto Yui <[email protected]> Authored: Wed Oct 18 21:15:28 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Wed Oct 18 21:15:28 2017 +0900 ---------------------------------------------------------------------- .../hivemall/fm/FactorizationMachineUDTF.java | 38 ++++--- .../FieldAwareFactorizationMachineUDTFTest.java | 113 ++++++++++++++++--- 2 files changed, 123 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c6f3e0/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java index 24210a8..5c8af32 100644 --- a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java +++ b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java @@ -18,21 +18,6 @@ */ package hivemall.fm; -import hivemall.UDTFWithOptions; -import hivemall.common.ConversionState; -import hivemall.fm.FMStringFeatureMapModel.Entry; -import hivemall.optimizer.EtaEstimator; -import hivemall.optimizer.LossFunctions; -import hivemall.optimizer.LossFunctions.LossFunction; -import hivemall.optimizer.LossFunctions.LossType; -import hivemall.utils.collections.IMapIterator; -import hivemall.utils.hadoop.HiveUtils; -import hivemall.utils.io.FileUtils; -import hivemall.utils.io.NioStatefullSegment; -import hivemall.utils.lang.NumberUtils; -import hivemall.utils.lang.SizeOf; -import hivemall.utils.math.MathUtils; - import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; @@ -63,6 +48,22 @@ import org.apache.hadoop.io.Text; import org.apache.hadoop.mapred.Counters.Counter; import org.apache.hadoop.mapred.Reporter; +import hivemall.UDTFWithOptions; +import hivemall.annotations.VisibleForTesting; +import hivemall.common.ConversionState; +import hivemall.fm.FMStringFeatureMapModel.Entry; +import hivemall.optimizer.EtaEstimator; +import hivemall.optimizer.LossFunctions; +import hivemall.optimizer.LossFunctions.LossFunction; +import hivemall.optimizer.LossFunctions.LossType; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.io.FileUtils; +import hivemall.utils.io.NioStatefullSegment; +import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.SizeOf; +import hivemall.utils.math.MathUtils; + @Description( name = "train_fm", value = "_FUNC_(array<string> x, double y [, const string options]) - Returns a prediction model") @@ -436,6 +437,13 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { this._model = null; } + @VisibleForTesting + void finalizeTraining() throws HiveException { + if (_iterations > 1) { + runTrainingIteration(_iterations); + } + } + protected void forwardModel() throws HiveException { if (_parseFeatureAsInt) { forwardAsIntFeature(_model, _factors); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c6f3e0/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java index 3b219c6..5b54b1e 100644 --- a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java +++ b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java @@ -22,6 +22,7 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.net.URL; import java.util.ArrayList; import java.util.List; import java.util.zip.GZIPInputStream; @@ -36,41 +37,116 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.junit.Assert; import org.junit.Test; +import hivemall.utils.lang.NumberUtils; + public class FieldAwareFactorizationMachineUDTFTest { private static final boolean DEBUG = false; private static final int ITERATIONS = 50; private static final int MAX_LINES = 200; + // ---------------------------------------------------- + // bigdata.tr.txt + @Test public void testSGD() throws HiveException, IOException { - runTest("Pure SGD test", "-opt sgd -classification -factors 10 -w0 -seed 43", 0.60f); + runIterations("Pure SGD test", "bigdata.tr.txt.gz", + "-opt sgd -classification -factors 10 -w0 -seed 43", 0.60f); } @Test public void testAdaGrad() throws HiveException, IOException { - runTest("AdaGrad test", "-opt adagrad -classification -factors 10 -w0 -seed 43", 0.30f); + runIterations("AdaGrad test", "bigdata.tr.txt.gz", + "-opt adagrad -classification -factors 10 -w0 -seed 43", 0.30f); } @Test public void testAdaGradNoCoeff() throws HiveException, IOException { - runTest("AdaGrad No Coeff test", + runIterations("AdaGrad No Coeff test", "bigdata.tr.txt.gz", "-opt adagrad -no_coeff -classification -factors 10 -w0 -seed 43", 0.30f); } @Test public void testFTRL() throws HiveException, IOException { - runTest("FTRL test", "-opt ftrl -classification -factors 10 -w0 -seed 43", 0.30f); + runIterations("FTRL test", "bigdata.tr.txt.gz", + "-opt ftrl -classification -factors 10 -w0 -seed 43", 0.30f); } @Test public void testFTRLNoCoeff() throws HiveException, IOException { - runTest("FTRL Coeff test", "-opt ftrl -no_coeff -classification -factors 10 -w0 -seed 43", - 0.30f); + runIterations("FTRL Coeff test", "bigdata.tr.txt.gz", + "-opt ftrl -no_coeff -classification -factors 10 -w0 -seed 43", 0.30f); } - private static void runTest(String testName, String testOptions, float lossThreshold) - throws IOException, HiveException { + // ---------------------------------------------------- + // https://github.com/myui/ml_dataset/raw/master/ffm/sample.ffm.gz + + @Test + public void testSample() throws IOException, HiveException { + run("[Sample.ffm] default option", + "https://github.com/myui/ml_dataset/raw/master/ffm/sample.ffm.gz", + "-classification -factors 2 -iters 10 -feature_hashing 20 -seed 43", 0.1f); + } + + private static void run(String testName, String testFile, String testOptions, + float lossThreshold) throws IOException, HiveException { + println(testName); + + FieldAwareFactorizationMachineUDTF udtf = new FieldAwareFactorizationMachineUDTF(); + ObjectInspector[] argOIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector), + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, testOptions)}; + + udtf.initialize(argOIs); + FieldAwareFactorizationMachineModel model = udtf.initModel(udtf._params); + Assert.assertTrue("Actual class: " + model.getClass().getName(), + model instanceof FFMStringFeatureMapModel); + + + BufferedReader data = readFile(testFile); + while (true) { + //gather features in current line + final String input = data.readLine(); + if (input == null) { + break; + } + String[] featureStrings = input.split(" "); + + double y = Double.parseDouble(featureStrings[0]); + if (y == 0) { + y = -1;//LibFFM data uses {0, 1}; Hivemall uses {-1, 1} + } + + final List<String> features = new ArrayList<String>(featureStrings.length - 1); + for (int j = 1; j < featureStrings.length; ++j) { + String fj = featureStrings[j]; + String[] splitted = fj.split(":"); + Assert.assertEquals(3, splitted.length); + String indexStr = splitted[1]; + String f = fj; + if (NumberUtils.isDigits(indexStr)) { + int index = Integer.parseInt(indexStr) + 1; // avoid 0 index + f = splitted[0] + ':' + index + ':' + splitted[2]; + } + features.add(f); + } + + udtf.process(new Object[] {features, y}); + } + udtf.finalizeTraining(); + data.close(); + + println("model size=" + udtf._model.getSize()); + + double avgLoss = udtf._cvState.getCumulativeLoss() / udtf._t; + Assert.assertTrue("Last loss was greater than expected: " + avgLoss, + avgLoss < lossThreshold); + } + + private static void runIterations(String testName, String testFile, String testOptions, + float lossThreshold) throws IOException, HiveException { println(testName); FieldAwareFactorizationMachineUDTF udtf = new FieldAwareFactorizationMachineUDTF(); @@ -88,7 +164,7 @@ public class FieldAwareFactorizationMachineUDTFTest { double loss = 0.d; double cumul = 0.d; for (int trainingIteration = 1; trainingIteration <= ITERATIONS; ++trainingIteration) { - BufferedReader data = readFile("bigdata.tr.txt.gz"); + BufferedReader data = readFile(testFile); loss = udtf._cvState.getCumulativeLoss(); int lines = 0; for (int lineNumber = 0; lineNumber < MAX_LINES; ++lineNumber, ++lines) { @@ -106,10 +182,15 @@ public class FieldAwareFactorizationMachineUDTFTest { final List<String> features = new ArrayList<String>(featureStrings.length - 1); for (int j = 1; j < featureStrings.length; ++j) { - String[] splitted = featureStrings[j].split(":"); + String fj = featureStrings[j]; + String[] splitted = fj.split(":"); Assert.assertEquals(3, splitted.length); - int index = Integer.parseInt(splitted[1]) + 1; - String f = splitted[0] + ':' + index + ':' + splitted[2]; + String indexStr = splitted[1]; + String f = fj; + if (NumberUtils.isDigits(indexStr)) { + int index = Integer.parseInt(indexStr) + 1; // avoid 0 index + f = splitted[0] + ':' + index + ':' + splitted[2]; + } features.add(f); } udtf.process(new Object[] {features, y}); @@ -125,7 +206,13 @@ public class FieldAwareFactorizationMachineUDTFTest { @Nonnull private static BufferedReader readFile(@Nonnull String fileName) throws IOException { - InputStream is = FieldAwareFactorizationMachineUDTFTest.class.getResourceAsStream(fileName); + InputStream is; + if (fileName.startsWith("http")) { + URL url = new URL(fileName); + is = url.openStream(); + } else { + is = FieldAwareFactorizationMachineUDTFTest.class.getResourceAsStream(fileName); + } if (fileName.endsWith(".gz")) { is = new GZIPInputStream(is); }
