Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 047f5fed4 -> e186a5876


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java 
b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java
index 1c7a90e..dba4a00 100644
--- a/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java
+++ b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java
@@ -18,6 +18,9 @@
  */
 package hivemall.classifier;
 
+import static hivemall.utils.hadoop.HiveUtils.lazyInteger;
+import static hivemall.utils.hadoop.HiveUtils.lazyLong;
+import static hivemall.utils.hadoop.HiveUtils.lazyString;
 import hivemall.utils.math.MathUtils;
 
 import java.io.BufferedReader;
@@ -35,11 +38,20 @@ import javax.annotation.Nonnull;
 
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.Collector;
+import org.apache.hadoop.hive.serde2.lazy.LazyInteger;
+import org.apache.hadoop.hive.serde2.lazy.LazyLong;
+import org.apache.hadoop.hive.serde2.lazy.LazyString;
+import 
org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyPrimitiveObjectInspectorFactory;
+import 
org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyStringObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -82,6 +94,129 @@ public class GeneralClassifierUDTFTest {
         udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params});
     }
 
+    @Test
+    public void testNoOptions() throws Exception {
+        List<String> x = Arrays.asList("1:-2", "2:-1");
+        int y = 0;
+
+        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();
+        ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
+        ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI);
+
+        udtf.initialize(new ObjectInspector[] {stringListOI, intOI}); // no options argument
+
+        udtf.process(new Object[] {x, y});
+
+        udtf.finalizeTraining();
+
+        float score = udtf.predict(udtf.parseFeatures(x));
+        int predicted = score > 0.f ? 1 : 0; // positive score => label 1
+        Assert.assertEquals(y, predicted); // reports expected/actual on failure
+    }
+
+    private <T> void testFeature(@Nonnull List<T> x, @Nonnull ObjectInspector featureOI,
+            @Nonnull Class<T> featureClass /* never read; only binds T */, @Nonnull Class<?> modelFeatureClass) throws Exception {
+        int y = 0;
+
+        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();
+        ObjectInspector valueOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
+        ListObjectInspector featureListOI = ObjectInspectorFactory.getStandardListObjectInspector(featureOI);
+
+        udtf.initialize(new ObjectInspector[] {featureListOI, valueOI});
+
+        final List<Object> modelFeatures = new ArrayList<Object>();
+        udtf.setCollector(new Collector() {
+            @Override
+            public void collect(Object input) throws HiveException {
+                Object[] forwardMapObj = (Object[]) input;
+                modelFeatures.add(forwardMapObj[0]); // column 0 of each forwarded row is checked below
+            }
+        });
+
+        udtf.process(new Object[] {x, y});
+
+        udtf.close();
+
+        Assert.assertFalse(modelFeatures.isEmpty()); // rows must have been forwarded by close()
+        for (Object modelFeature : modelFeatures) {
+            Assert.assertEquals("All model features must have same type", modelFeatureClass,
+                modelFeature.getClass());
+        }
+    }
+
+    @Test
+    public void testStringFeature() throws Exception {
+        List<String> x = Arrays.asList("1:-2", "2:-1");
+        ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+        testFeature(x, featureOI, String.class, String.class); // expect String model features
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testIllegalStringFeature() throws Exception {
+        List<String> x = Arrays.asList("1:-2jjjjj", "2:-1"); // "-2jjjjj" is not a number
+        ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+        testFeature(x, featureOI, String.class, String.class);
+    }
+
+    @Test
+    public void testLazyStringFeature() throws Exception {
+        LazyStringObjectInspector oi = LazyPrimitiveObjectInspectorFactory.getLazyStringObjectInspector(
+            false, (byte) 0);
+        List<LazyString> x = Arrays.asList(lazyString("テスト:-2", oi), lazyString("漢字:-333.0", oi),
+            lazyString("test:-1", oi)); // every element built with the list's element inspector
+        testFeature(x, oi, LazyString.class, String.class); // expect String model features
+    }
+
+    @Test
+    public void testTextFeature() throws Exception {
+        List<Text> x = Arrays.asList(new Text("1:-2"), new Text("2:-1"));
+        ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableStringObjectInspector;
+        testFeature(x, featureOI, Text.class, String.class); // expect String model features
+    }
+
+    @Test
+    public void testIntegerFeature() throws Exception {
+        List<Integer> x = Arrays.asList(111, 222);
+        ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
+        testFeature(x, featureOI, Integer.class, Integer.class); // expect Integer model features
+    }
+
+    @Test
+    public void testLazyIntegerFeature() throws Exception {
+        List<LazyInteger> x = Arrays.asList(lazyInteger(111), lazyInteger(222));
+        ObjectInspector featureOI = LazyPrimitiveObjectInspectorFactory.LAZY_INT_OBJECT_INSPECTOR;
+        testFeature(x, featureOI, LazyInteger.class, Integer.class); // expect Integer model features
+    }
+
+    @Test
+    public void testWritableIntFeature() throws Exception {
+        List<IntWritable> x = Arrays.asList(new IntWritable(111), new IntWritable(222));
+        ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
+        testFeature(x, featureOI, IntWritable.class, Integer.class); // expect Integer model features
+    }
+
+    @Test
+    public void testLongFeature() throws Exception {
+        List<Long> x = Arrays.asList(111L, 222L);
+        ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
+        testFeature(x, featureOI, Long.class, Long.class); // expect Long model features
+    }
+
+    @Test
+    public void testLazyLongFeature() throws Exception {
+        List<LazyLong> x = Arrays.asList(lazyLong(111), lazyLong(222));
+        ObjectInspector featureOI = LazyPrimitiveObjectInspectorFactory.LAZY_LONG_OBJECT_INSPECTOR;
+        testFeature(x, featureOI, LazyLong.class, Long.class); // expect Long model features
+    }
+
+    @Test
+    public void testWritableLongFeature() throws Exception {
+        List<LongWritable> x = Arrays.asList(new LongWritable(111L), new LongWritable(222L));
+        ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
+        testFeature(x, featureOI, LongWritable.class, Long.class); // expect Long model features
+    }
+
     private void run(@Nonnull String options) throws Exception {
         println(options);
 
@@ -95,8 +230,6 @@ public class GeneralClassifierUDTFTest {
 
         int[] labels = new int[] {0, 0, 0, 1, 1, 1};
 
-        int maxIter = 512;
-
         GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();
         ObjectInspector intOI = 
PrimitiveObjectInspectorFactory.javaIntObjectInspector;
         ObjectInspector stringOI = 
PrimitiveObjectInspectorFactory.javaStringObjectInspector;
@@ -106,19 +239,17 @@ public class GeneralClassifierUDTFTest {
 
         udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params});
 
-        double cumLossPrev = Double.MAX_VALUE;
-        double cumLoss = 0.d;
-        int it = 0;
-        while ((it < maxIter) && (Math.abs(cumLoss - cumLossPrev) > 1e-3f)) {
-            cumLossPrev = cumLoss;
-            udtf.resetCumulativeLoss();
-            for (int i = 0, size = samplesList.size(); i < size; i++) {
-                udtf.process(new Object[] {samplesList.get(i), labels[i]});
-            }
-            cumLoss = udtf.getCumulativeLoss();
-            println("Iter: " + ++it + ", Cumulative loss: " + cumLoss);
+        for (int i = 0, size = samplesList.size(); i < size; i++) {
+            udtf.process(new Object[] {samplesList.get(i), labels[i]});
         }
-        Assert.assertTrue(cumLoss / samplesList.size() < 0.5d);
+
+        udtf.finalizeTraining();
+
+        double cumLoss = udtf.getCumulativeLoss();
+        println("Cumulative loss: " + cumLoss);
+        double normalizedLoss = cumLoss / samplesList.size();
+        Assert.assertTrue("cumLoss: " + cumLoss + ", normalizedLoss: " + 
normalizedLoss
+                + "\noptions: " + options, normalizedLoss < 0.5d);
 
         int numTests = 0;
         int numCorrect = 0;
@@ -157,7 +288,8 @@ public class GeneralClassifierUDTFTest {
                 }
 
                 for (String loss : lossFunctions) {
-                    String options = "-opt " + opt + " -reg " + reg + " -loss 
" + loss;
+                    String options = "-opt " + opt + " -reg " + reg + " -loss 
" + loss
+                            + " -cv_rate 0.001 -iter 512";
 
                     // sparse
                     run(options);
@@ -178,15 +310,13 @@ public class GeneralClassifierUDTFTest {
     @SuppressWarnings("unchecked")
     @Test
     public void testNews20() throws IOException, ParseException, HiveException 
{
-        int nIter = 10;
-
         GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();
         ObjectInspector intOI = 
PrimitiveObjectInspectorFactory.javaIntObjectInspector;
         ObjectInspector stringOI = 
PrimitiveObjectInspectorFactory.javaStringObjectInspector;
         ListObjectInspector stringListOI = 
ObjectInspectorFactory.getStandardListObjectInspector(stringOI);
         ObjectInspector params = 
ObjectInspectorUtils.getConstantObjectInspector(
             PrimitiveObjectInspectorFactory.javaStringObjectInspector,
-            "-opt SGD -loss logloss -reg L2 -lambda 0.1");
+            "-opt SGD -loss logloss -reg L2 -lambda 0.1 -cv_rate 0.005");
 
         udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params});
 
@@ -213,13 +343,7 @@ public class GeneralClassifierUDTFTest {
         news20.close();
 
         // perform SGD iterations
-        for (int it = 1; it < nIter; it++) {
-            for (int i = 0, size = wordsList.size(); i < size; i++) {
-                words = wordsList.get(i);
-                int label = labels.get(i);
-                udtf.process(new Object[] {words, label});
-            }
-        }
+        udtf.finalizeTraining();
 
         int numTests = 0;
         int numCorrect = 0;

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/test/java/hivemall/model/FeatureValueTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/model/FeatureValueTest.java 
b/core/src/test/java/hivemall/model/FeatureValueTest.java
index 2b6c832..598e13a 100644
--- a/core/src/test/java/hivemall/model/FeatureValueTest.java
+++ b/core/src/test/java/hivemall/model/FeatureValueTest.java
@@ -50,12 +50,12 @@ public class FeatureValueTest {
     }
 
     @Test(expected = IllegalArgumentException.class)
-    public void testParseExpectingIllegalArgumentException() {
+    public void testParseExpectingIllegalArgumentException1() { // empty value after ':'
         FeatureValue.parse("ad_url:");
     }
 
-    @Test(expected = NumberFormatException.class)
-    public void testParseExpectingNumberFormatException() {
+    @Test(expected = IllegalArgumentException.class)
+    public void testParseExpectingIllegalArgumentException2() { // non-numeric value after ':'
         FeatureValue.parse("ad_url:xxxxx");
     }
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/test/java/hivemall/regression/AdaGradUDTFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/regression/AdaGradUDTFTest.java 
b/core/src/test/java/hivemall/regression/AdaGradUDTFTest.java
index e7a0a89..fa7e28a 100644
--- a/core/src/test/java/hivemall/regression/AdaGradUDTFTest.java
+++ b/core/src/test/java/hivemall/regression/AdaGradUDTFTest.java
@@ -30,6 +30,7 @@ import org.junit.Test;
 
 public class AdaGradUDTFTest {
 
+    @SuppressWarnings("deprecation")
     @Test
     public void testInitialize() throws UDFArgumentException {
         AdaGradUDTF udtf = new AdaGradUDTF();

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java 
b/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java
index cfe9651..f352b89 100644
--- a/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java
+++ b/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java
@@ -18,6 +18,10 @@
  */
 package hivemall.regression;
 
+import static hivemall.utils.hadoop.HiveUtils.lazyInteger;
+import static hivemall.utils.hadoop.HiveUtils.lazyLong;
+import static hivemall.utils.hadoop.HiveUtils.lazyString;
+
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
@@ -25,11 +29,21 @@ import java.util.List;
 import javax.annotation.Nonnull;
 
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.Collector;
+import org.apache.hadoop.hive.serde2.lazy.LazyInteger;
+import org.apache.hadoop.hive.serde2.lazy.LazyLong;
+import org.apache.hadoop.hive.serde2.lazy.LazyString;
+import 
org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyPrimitiveObjectInspectorFactory;
+import 
org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyStringObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -84,6 +98,128 @@ public class GeneralRegressionUDTFTest {
         udtf.initialize(new ObjectInspector[] {stringListOI, floatOI, params});
     }
 
+    @Test
+    public void testNoOptions() throws Exception {
+        List<String> x = Arrays.asList("1:-2", "2:-1");
+        float y = 0.f;
+
+        GeneralRegressionUDTF udtf = new GeneralRegressionUDTF();
+        ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
+        ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI);
+
+        udtf.initialize(new ObjectInspector[] {stringListOI, floatOI}); // no options argument
+
+        udtf.process(new Object[] {x, y});
+
+        udtf.finalizeTraining();
+
+        float predicted = udtf.predict(udtf.parseFeatures(x));
+        Assert.assertEquals(y, predicted, 1E-5);
+    }
+
+    private <T> void testFeature(@Nonnull List<T> x, @Nonnull ObjectInspector featureOI,
+            @Nonnull Class<T> featureClass /* never read; only binds T */, @Nonnull Class<?> modelFeatureClass) throws Exception {
+        float y = 0.f;
+
+        GeneralRegressionUDTF udtf = new GeneralRegressionUDTF();
+        ObjectInspector valueOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
+        ListObjectInspector featureListOI = ObjectInspectorFactory.getStandardListObjectInspector(featureOI);
+
+        udtf.initialize(new ObjectInspector[] {featureListOI, valueOI});
+
+        final List<Object> modelFeatures = new ArrayList<Object>();
+        udtf.setCollector(new Collector() {
+            @Override
+            public void collect(Object input) throws HiveException {
+                Object[] forwardMapObj = (Object[]) input;
+                modelFeatures.add(forwardMapObj[0]); // column 0 of each forwarded row is checked below
+            }
+        });
+
+        udtf.process(new Object[] {x, y});
+
+        udtf.close();
+
+        Assert.assertFalse(modelFeatures.isEmpty()); // rows must have been forwarded by close()
+        for (Object modelFeature : modelFeatures) {
+            Assert.assertEquals("All model features must have same type", modelFeatureClass,
+                modelFeature.getClass());
+        }
+    }
+
+    @Test
+    public void testLazyStringFeature() throws Exception {
+        LazyStringObjectInspector oi = LazyPrimitiveObjectInspectorFactory.getLazyStringObjectInspector(
+            false, (byte) 0);
+        List<LazyString> x = Arrays.asList(lazyString("テスト:-2", oi), lazyString("漢字:-333.0", oi),
+            lazyString("test:-1", oi)); // every element built with the list's element inspector
+        testFeature(x, oi, LazyString.class, String.class); // expect String model features
+    }
+
+    @Test
+    public void testStringFeature() throws Exception {
+        List<String> x = Arrays.asList("1:-2", "2:-1");
+        ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+        testFeature(x, featureOI, String.class, String.class); // expect String model features
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testIllegalStringFeature() throws Exception {
+        List<String> x = Arrays.asList("1:-2jjjj", "2:-1"); // "-2jjjj" is not a number
+        ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+        testFeature(x, featureOI, String.class, String.class);
+    }
+
+    @Test
+    public void testTextFeature() throws Exception {
+        List<Text> x = Arrays.asList(new Text("1:-2"), new Text("2:-1"));
+        ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableStringObjectInspector;
+        testFeature(x, featureOI, Text.class, String.class); // expect String model features
+    }
+
+    @Test
+    public void testIntegerFeature() throws Exception {
+        List<Integer> x = Arrays.asList(111, 222);
+        ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
+        testFeature(x, featureOI, Integer.class, Integer.class); // expect Integer model features
+    }
+
+    @Test
+    public void testLazyIntegerFeature() throws Exception {
+        List<LazyInteger> x = Arrays.asList(lazyInteger(111), lazyInteger(222));
+        ObjectInspector featureOI = LazyPrimitiveObjectInspectorFactory.LAZY_INT_OBJECT_INSPECTOR;
+        testFeature(x, featureOI, LazyInteger.class, Integer.class); // expect Integer model features
+    }
+
+    @Test
+    public void testWritableIntFeature() throws Exception {
+        List<IntWritable> x = Arrays.asList(new IntWritable(111), new IntWritable(222));
+        ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
+        testFeature(x, featureOI, IntWritable.class, Integer.class); // expect Integer model features
+    }
+
+    @Test
+    public void testLongFeature() throws Exception {
+        List<Long> x = Arrays.asList(111L, 222L);
+        ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
+        testFeature(x, featureOI, Long.class, Long.class); // expect Long model features
+    }
+
+    @Test
+    public void testLazyLongFeature() throws Exception {
+        List<LazyLong> x = Arrays.asList(lazyLong(111), lazyLong(222));
+        ObjectInspector featureOI = LazyPrimitiveObjectInspectorFactory.LAZY_LONG_OBJECT_INSPECTOR;
+        testFeature(x, featureOI, LazyLong.class, Long.class); // expect Long model features
+    }
+
+    @Test
+    public void testWritableLongFeature() throws Exception {
+        List<LongWritable> x = Arrays.asList(new LongWritable(111L), new LongWritable(222L));
+        ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
+        testFeature(x, featureOI, LongWritable.class, Long.class); // expect Long model features
+    }
+
     private void run(@Nonnull String options) throws Exception {
         println(options);
 
@@ -108,9 +244,6 @@ public class GeneralRegressionUDTFTest {
             x2 += x2Step;
         }
 
-        int numTrain = (int) (numSamples * 0.8);
-        int maxIter = 512;
-
         GeneralRegressionUDTF udtf = new GeneralRegressionUDTF();
         ObjectInspector floatOI = 
PrimitiveObjectInspectorFactory.javaFloatObjectInspector;
         ObjectInspector stringOI = 
PrimitiveObjectInspectorFactory.javaStringObjectInspector;
@@ -120,23 +253,29 @@ public class GeneralRegressionUDTFTest {
 
         udtf.initialize(new ObjectInspector[] {stringListOI, floatOI, params});
 
-        double cumLossPrev = Double.MAX_VALUE;
-        double cumLoss = 0.d;
-        int it = 0;
-        while ((it < maxIter) && (Math.abs(cumLoss - cumLossPrev) > 1e-3f)) {
-            cumLossPrev = cumLoss;
-            udtf.resetCumulativeLoss();
-            for (int i = 0; i < numTrain; i++) {
-                udtf.process(new Object[] {samplesList.get(i), (Float) 
ys.get(i)});
-            }
-            cumLoss = udtf.getCumulativeLoss();
-            println("Iter: " + ++it + ", Cumulative loss: " + cumLoss);
+        float accum = 0.f;
+        for (int i = 0; i < numSamples; i++) {
+            float y = ys.get(i).floatValue();
+            float predicted = 
udtf.predict(udtf.parseFeatures(samplesList.get(i)));
+            accum += Math.abs(y - predicted);
         }
-        Assert.assertTrue(cumLoss / numTrain < 0.1d);
+        float maeInit = accum / numSamples;
+        println("Mean absolute error before training: " + maeInit);
 
-        float accum = 0.f;
+        for (int i = 0; i < numSamples; i++) {
+            udtf.process(new Object[] {samplesList.get(i), (Float) ys.get(i)});
+        }
 
-        for (int i = numTrain; i < numSamples; i++) {
+        udtf.finalizeTraining();
+
+        double cumLoss = udtf.getCumulativeLoss();
+        println("Cumulative loss: " + cumLoss);
+        double normalizedLoss = cumLoss / numSamples;
+        Assert.assertTrue("cumLoss: " + cumLoss + ", normalizedLoss: " + 
normalizedLoss
+                + "\noptions: " + options, normalizedLoss < 0.1d);
+
+        accum = 0.f;
+        for (int i = 0; i < numSamples; i++) {
             float y = ys.get(i).floatValue();
 
             float predicted = 
udtf.predict(udtf.parseFeatures(samplesList.get(i)));
@@ -144,10 +283,10 @@ public class GeneralRegressionUDTFTest {
 
             accum += Math.abs(y - predicted);
         }
-
-        float err = accum / (numSamples - numTrain);
-        println("Mean absolute error: " + err);
-        Assert.assertTrue(err < 0.2f);
+        float mae = accum / numSamples;
+        println("Mean absolute error after training: " + mae);
+        Assert.assertTrue("accum: " + accum + ", mae (init):" + maeInit + ", 
mae:" + mae
+                + "\noptions: " + options, mae < maeInit);
     }
 
     @Test
@@ -165,7 +304,7 @@ public class GeneralRegressionUDTFTest {
 
                 for (String loss : lossFunctions) {
                     String options = "-opt " + opt + " -reg " + reg + " -loss 
" + loss
-                            + " -lambda 1e-6 -eta0 1e-1";
+                            + " -iter 512";
 
                     // sparse
                     run(options);

Reply via email to