This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git


The following commit(s) were added to refs/heads/master by this push:
     new 647c6ae  [HIVEMALL-268] Fix the default vInit, eta initialization bug in FactorizationMachines
647c6ae is described below

commit 647c6ae31ddc0fa290f540716b3026b0e042f39b
Author: Makoto Yui <[email protected]>
AuthorDate: Thu Oct 3 17:34:10 2019 +0900

    [HIVEMALL-268] Fix the default vInit, eta initialization bug in FactorizationMachines
    
    ## What changes were proposed in this pull request?
    
    Fix a bug in FactorizationMachines where the default `FMHyperParameters` constructor left `vInit` and `eta` uninitialized.
    
    ## What type of PR is it?
    
    Bug Fix
    
    ## What is the Jira issue?
    
    https://issues.apache.org/jira/browse/HIVEMALL-268
    
    ## How was this patch tested?
    
    unit tests, manual tests on EMR
    
    ## Checklist
    
    (Please remove this section if not needed; check `x` for YES, blank for NO)
    
    - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
    - [ ] Did you run system tests on Hive (or Spark)?
    
    Author: Makoto Yui <[email protected]>
    
    Closes #200 from myui/HIVEMALL-268.
---
 core/src/main/java/hivemall/fm/FMArrayModel.java   |  2 ++
 .../main/java/hivemall/fm/FMHyperParameters.java   | 23 ++++++++++++++++++----
 .../java/hivemall/fm/FMIntFeatureMapModel.java     |  2 ++
 .../java/hivemall/fm/FMStringFeatureMapModel.java  |  1 +
 .../hivemall/fm/FactorizationMachineModel.java     |  9 +++++++--
 .../fm/FieldAwareFactorizationMachineUDTFTest.java |  3 ++-
 6 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/core/src/main/java/hivemall/fm/FMArrayModel.java b/core/src/main/java/hivemall/fm/FMArrayModel.java
index 97807aa..8260188 100644
--- a/core/src/main/java/hivemall/fm/FMArrayModel.java
+++ b/core/src/main/java/hivemall/fm/FMArrayModel.java
@@ -29,7 +29,9 @@ public final class FMArrayModel extends FactorizationMachineModel {
     private final int _p;
 
     // LEARNING PARAMS
+    @Nonnull
     private final float[] _w;
+    @Nonnull
     private final float[][] _V;
 
     public FMArrayModel(@Nonnull FMHyperParameters params) {
diff --git a/core/src/main/java/hivemall/fm/FMHyperParameters.java b/core/src/main/java/hivemall/fm/FMHyperParameters.java
index edee14f..c86c5e2 100644
--- a/core/src/main/java/hivemall/fm/FMHyperParameters.java
+++ b/core/src/main/java/hivemall/fm/FMHyperParameters.java
@@ -20,6 +20,7 @@ package hivemall.fm;
 
 import hivemall.fm.FactorizationMachineModel.VInitScheme;
 import hivemall.optimizer.EtaEstimator;
+import hivemall.optimizer.EtaEstimator.InvscalingEtaEstimator;
 import hivemall.utils.lang.Primitives;
 
 import javax.annotation.Nonnull;
@@ -46,6 +47,7 @@ class FMHyperParameters {
     // V initialization
     double sigma = 0.1d;
     long seed = -1L;
+    @Nonnull
     VInitScheme vInit;
 
     // regression
@@ -53,6 +55,7 @@ class FMHyperParameters {
     double maxTarget = Double.MAX_VALUE;
 
     // learning rate
+    @Nonnull
     EtaEstimator eta;
 
     // feature hashing
@@ -75,7 +78,10 @@ class FMHyperParameters {
     int validationThreshold = 1000;
     boolean parseFeatureAsInt = false;
 
-    FMHyperParameters() {}
+    FMHyperParameters() {
+        this.vInit = instantiateVInit();
+        this.eta = new InvscalingEtaEstimator(DEFAULT_ETA0, 
EtaEstimator.DEFAULT_POWER_T);
+    }
 
     @Override
     public String toString() {
@@ -134,13 +140,22 @@ class FMHyperParameters {
     }
 
     @Nonnull
+    private VInitScheme instantiateVInit() {
+        VInitScheme vInit = getDefaultVinitScheme(classification);
+        vInit.setMaxInitValue(0.5f);
+        vInit.setInitStdDev(0.2d);
+        vInit.initRandom(factors, System.nanoTime());
+        return vInit;
+    }
+
+    @Nonnull
     private VInitScheme instantiateVInit(@Nonnull CommandLine cl, int factor, 
long seed,
             final boolean classification) {
         String vInitOpt = cl.getOptionValue("init_v");
         float maxInitValue = 
Primitives.parseFloat(cl.getOptionValue("max_init_value"), 0.5f);
         double initStdDev = 
Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.1d);
 
-        VInitScheme vInit = VInitScheme.resolve(vInitOpt, 
getDefaultVinitScheme());
+        VInitScheme vInit = VInitScheme.resolve(vInitOpt, 
getDefaultVinitScheme(classification));
         vInit.setMaxInitValue(maxInitValue);
         initStdDev = Math.max(initStdDev, 1.0d / factor);
         vInit.setInitStdDev(initStdDev);
@@ -149,7 +164,7 @@ class FMHyperParameters {
     }
 
     @Nonnull
-    protected VInitScheme getDefaultVinitScheme() {
+    protected VInitScheme getDefaultVinitScheme(boolean classification) {
         return classification ? VInitScheme.gaussian : 
VInitScheme.adjustedRandom;
     }
 
@@ -178,7 +193,7 @@ class FMHyperParameters {
         }
 
         @Nonnull
-        protected VInitScheme getDefaultVinitScheme() {
+        protected VInitScheme getDefaultVinitScheme(boolean classification) {
             return VInitScheme.random;
         }
 
diff --git a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
index 72d64c0..c0ce2a5 100644
--- a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
+++ b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
@@ -34,7 +34,9 @@ public final class FMIntFeatureMapModel extends FactorizationMachineModel {
 
     // LEARNING PARAMS
     private float _w0;
+    @Nonnull
     private final Int2FloatMap _w;
+    @Nonnull
     private final Int2ObjectMap<float[]> _V;
 
     private int _minIndex, _maxIndex;
diff --git a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java
index 84b780a..ae15598 100644
--- a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java
+++ b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java
@@ -28,6 +28,7 @@ public final class FMStringFeatureMapModel extends FactorizationMachineModel {
 
     // LEARNING PARAMS
     private float _w0;
+    @Nonnull
     private final Object2ObjectMap<String, Entry> _map;
 
     public FMStringFeatureMapModel(@Nonnull FMHyperParameters params) {
diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
index c654f32..a6d7523 100644
--- a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
+++ b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
@@ -23,6 +23,7 @@ import hivemall.utils.lang.NumberUtils;
 import hivemall.utils.math.MathUtils;
 
 import java.util.Arrays;
+import java.util.Objects;
 import java.util.Random;
 
 import javax.annotation.Nonnegative;
@@ -36,8 +37,11 @@ public abstract class FactorizationMachineModel {
     protected final boolean _classification;
     protected final int _factor;
     protected final double _sigma;
+    @Nonnull
     protected final EtaEstimator _eta;
+    @Nonnull
     protected final VInitScheme _initScheme;
+    @Nonnull
     protected final Random _rnd;
 
     // Hyperparameter for regression
@@ -47,14 +51,15 @@ public abstract class FactorizationMachineModel {
     // Regulation Variables
     protected float _lambdaW0;
     protected float _lambdaW;
+    @Nonnull
     protected final float[] _lambdaV;
 
     public FactorizationMachineModel(@Nonnull FMHyperParameters params) {
         this._classification = params.classification;
         this._factor = params.factors;
         this._sigma = params.sigma;
-        this._eta = params.eta;
-        this._initScheme = params.vInit;
+        this._eta = Objects.requireNonNull(params.eta);
+        this._initScheme = Objects.requireNonNull(params.vInit);
         this._rnd = new Random(params.seed);
 
         this._min_target = params.minTarget;
diff --git a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
index 67040a1..4270ebe 100644
--- a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
+++ b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
@@ -33,6 +33,7 @@ import java.util.zip.GZIPInputStream;
 
 import javax.annotation.Nonnull;
 
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
@@ -258,7 +259,7 @@ public class FieldAwareFactorizationMachineUDTFTest {
             cumulativeLoss > udtf._validationState.getCumulativeLoss());
     }
 
-    @Test(expected = IllegalArgumentException.class)
+    @Test(expected = UDFArgumentException.class)
     public void testUnsupportedAdaptiveRegularizationOption() throws Exception 
{
         
TestUtils.testGenericUDTFSerialization(FieldAwareFactorizationMachineUDTF.class,
             new ObjectInspector[] {

Reply via email to