Repository: incubator-hivemall
Updated Branches:
  refs/heads/master bc06c93d8 -> 31932fd7c


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/31932fd7/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java
index 8bd8134..7ba336e 100644
--- a/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java
+++ b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java
@@ -388,6 +388,549 @@ public class GeneralClassifierUDTFTest {
             new Object[][] {{Arrays.asList("1:-2", "2:-1"), 0}});
     }
 
+    @Test
+    public void testSGD() throws IOException, HiveException {
+        String options =
+                "-loss logloss -opt sgd -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";
+        assertCumulativeLossLessThan(options, 1300);
+    }
+
+    @Test
+    public void testMomentum() throws IOException, HiveException {
+        String options =
+                "-loss logloss -opt momentum -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";
+        assertCumulativeLossLessThan(options, 1200);
+    }
+
+    @Test
+    public void testNesterov() throws IOException, HiveException {
+        String options =
+                "-loss logloss -opt nesterov -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";
+        assertCumulativeLossLessThan(options, 1100);
+    }
+
+    @Test
+    public void testAdagradL1() throws IOException, HiveException {
+        String options =
+                "-loss logloss -opt adagrad -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";
+        assertCumulativeLossLessThan(options, 1400);
+    }
+
+    @Test
+    public void testRMSprop() throws IOException, HiveException {
+        String options =
+                "-loss logloss -opt rmsprop -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";
+        assertCumulativeLossLessThan(options, 1300);
+    }
+
+    @Test
+    public void testRMSpropGraves() throws IOException, HiveException {
+        String options =
+                "-loss logloss -opt RMSpropGraves -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";
+        assertCumulativeLossLessThan(options, 1200);
+    }
+
+    @Test
+    public void testAdaDeltaL1() throws IOException, HiveException {
+        String options =
+                "-loss logloss -opt adadelta -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";
+        assertCumulativeLossLessThan(options, 1500);
+    }
+
+    @Test
+    public void testAdam() throws IOException, HiveException {
+        String options =
+                "-loss logloss -opt Adam -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";
+        assertCumulativeLossLessThan(options, 800);
+    }
+
+    @Test
+    public void testNadam() throws IOException, HiveException {
+        String options =
+                "-loss logloss -opt Nadam -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";
+        assertCumulativeLossLessThan(options, 800);
+    }
+
+    @Test
+    public void testEve() throws IOException, HiveException {
+        String options =
+                "-loss logloss -opt Eve -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";
+        assertCumulativeLossLessThan(options, 800);
+    }
+
+    @Test
+    public void testAdamAmsgrad() throws IOException, HiveException {
+        String options =
+                "-loss logloss -opt Adam -amsgrad -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";
+        assertCumulativeLossLessThan(options, 1200);
+    }
+
+    @Test
+    public void testEveAmsgrad() throws IOException, HiveException {
+        String options =
+                "-loss logloss -opt Eve -amsgrad -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";
+        assertCumulativeLossLessThan(options, 1200);
+    }
+
+    @Test
+    public void testAdamHD() throws IOException, HiveException {
+        String options =
+                "-loss logloss -opt AdamHD -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";
+        assertCumulativeLossLessThan(options, 800);
+    }
+
+    @Test
+    public void testAdamDecay() throws IOException, HiveException {
+        String options =
+                "-loss logloss -opt Adam -decay 0.001 -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";
+        assertCumulativeLossLessThan(options, 900);
+    }
+
+    @Test
+    public void testAdamInvScaleEta() throws IOException, HiveException {
+        String options =
+                "-eta inv -eta0 0.1 -loss logloss -opt Adam -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";
+        assertCumulativeLossLessThan(options, 900);
+    }
+
+    /**
+     * Trains a {@link GeneralClassifierUDTF} configured by {@code options} on the
+     * {@code adam_test_10000.tsv.gz} fixture, calls {@code finalizeTraining()}, and asserts that
+     * the cumulative loss is strictly below {@code maxLoss}.
+     *
+     * <p>Each fixture line is tokenized on spaces: the first token is a comma-separated list of
+     * features, the second token is an integer label.
+     *
+     * @param options option string passed to the UDTF as its constant third argument
+     * @param maxLoss exclusive upper bound on the cumulative loss after training
+     */
+    private void assertCumulativeLossLessThan(String options, int maxLoss)
+            throws IOException, HiveException {
+        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();
+
+        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
+        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);
+
+        udtf.initialize(new ObjectInspector[] {stringListOI,
+                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});
+
+        // try-with-resources so the fixture reader is closed even when process() throws;
+        // the per-test copies this replaces never closed it.
+        try (BufferedReader reader = readFile("adam_test_10000.tsv.gz")) {
+            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
+                StringTokenizer tokenizer = new StringTokenizer(line, " ");
+
+                String featureLine = tokenizer.nextToken();
+                List<String> X = Arrays.asList(featureLine.split(","));
+
+                String labelLine = tokenizer.nextToken();
+                Integer y = Integer.valueOf(labelLine);
+
+                udtf.process(new Object[] {X, y});
+            }
+        }
+
+        udtf.finalizeTraining();
+
+        Assert.assertTrue(
+            "CumulativeLoss is expected to be less than " + maxLoss + ": "
+                    + udtf.getCumulativeLoss(),
+            udtf.getCumulativeLoss() < maxLoss);
+    }
+
     private static void println(String msg) {
         if (DEBUG) {
             System.out.println(msg);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/31932fd7/core/src/test/java/hivemall/optimizer/OptimizerTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/optimizer/OptimizerTest.java b/core/src/test/java/hivemall/optimizer/OptimizerTest.java
index eeb880d..26ce807 100644
--- a/core/src/test/java/hivemall/optimizer/OptimizerTest.java
+++ b/core/src/test/java/hivemall/optimizer/OptimizerTest.java
@@ -18,13 +18,15 @@
  */
 package hivemall.optimizer;
 
-import org.junit.Assert;
-import org.junit.Test;
+import hivemall.optimizer.Optimizer.OptimizerBase;
 
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Random;
 
+import org.junit.Assert;
+import org.junit.Test;
+
 public final class OptimizerTest {
 
     @Test
@@ -122,7 +124,7 @@ public final class OptimizerTest {
         }
     }
 
-    private void testUpdateWeights(Optimizer optimizer, int numUpdates, int 
initSize) {
+    private void testUpdateWeights(OptimizerBase optimizer, int numUpdates, 
int initSize) {
         final float[] weights = new float[initSize * 2];
         final Random rnd = new Random();
         try {
@@ -144,8 +146,10 @@ public final class OptimizerTest {
         final String[] regTypes = new String[] {"NO", "L1", "L2", "RDA", 
"ElasticNet"};
         for (final String regType : regTypes) {
             options.put("regularization", regType);
-            testUpdateWeights(DenseOptimizerFactory.create(1024, testOptions), 
65536, 1024);
-            testUpdateWeights(SparseOptimizerFactory.create(1024, 
testOptions), 65536, 1024);
+            testUpdateWeights((OptimizerBase) 
DenseOptimizerFactory.create(1024, testOptions),
+                65536, 1024);
+            testUpdateWeights((OptimizerBase) 
SparseOptimizerFactory.create(1024, testOptions),
+                65536, 1024);
         }
     }
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/31932fd7/core/src/test/resources/hivemall/classifier/adam_test_10000.tsv.gz
----------------------------------------------------------------------
diff --git a/core/src/test/resources/hivemall/classifier/adam_test_10000.tsv.gz b/core/src/test/resources/hivemall/classifier/adam_test_10000.tsv.gz
new file mode 100644
index 0000000..666f425
Binary files /dev/null and b/core/src/test/resources/hivemall/classifier/adam_test_10000.tsv.gz differ

Reply via email to