This is an automated email from the ASF dual-hosted git repository.
myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push:
new 647c6ae [HIVEMALL-268] Fix the default vInit, eta initialization bug
in FactorizationMachines
647c6ae is described below
commit 647c6ae31ddc0fa290f540716b3026b0e042f39b
Author: Makoto Yui <[email protected]>
AuthorDate: Thu Oct 3 17:34:10 2019 +0900
[HIVEMALL-268] Fix the default vInit, eta initialization bug in
FactorizationMachines
## What changes were proposed in this pull request?
Fix a bug in FactorizationMachines where the default `FMHyperParameters` constructor left the `vInit` and `eta` hyperparameters uninitialized (null)
## What type of PR is it?
Bug Fix
## What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-268
## How was this patch tested?
unit tests, manual tests on EMR
## Checklist
(Please remove this section if not needed; check `x` for YES, blank for NO)
- [x] Did you apply the source code formatter (i.e., `./bin/format_code.sh`)
to your commit?
- [ ] Did you run system tests on Hive (or Spark)?
Author: Makoto Yui <[email protected]>
Closes #200 from myui/HIVEMALL-268.
---
core/src/main/java/hivemall/fm/FMArrayModel.java | 2 ++
.../main/java/hivemall/fm/FMHyperParameters.java | 23 ++++++++++++++++++----
.../java/hivemall/fm/FMIntFeatureMapModel.java | 2 ++
.../java/hivemall/fm/FMStringFeatureMapModel.java | 1 +
.../hivemall/fm/FactorizationMachineModel.java | 9 +++++++--
.../fm/FieldAwareFactorizationMachineUDTFTest.java | 3 ++-
6 files changed, 33 insertions(+), 7 deletions(-)
diff --git a/core/src/main/java/hivemall/fm/FMArrayModel.java
b/core/src/main/java/hivemall/fm/FMArrayModel.java
index 97807aa..8260188 100644
--- a/core/src/main/java/hivemall/fm/FMArrayModel.java
+++ b/core/src/main/java/hivemall/fm/FMArrayModel.java
@@ -29,7 +29,9 @@ public final class FMArrayModel extends
FactorizationMachineModel {
private final int _p;
// LEARNING PARAMS
+ @Nonnull
private final float[] _w;
+ @Nonnull
private final float[][] _V;
public FMArrayModel(@Nonnull FMHyperParameters params) {
diff --git a/core/src/main/java/hivemall/fm/FMHyperParameters.java
b/core/src/main/java/hivemall/fm/FMHyperParameters.java
index edee14f..c86c5e2 100644
--- a/core/src/main/java/hivemall/fm/FMHyperParameters.java
+++ b/core/src/main/java/hivemall/fm/FMHyperParameters.java
@@ -20,6 +20,7 @@ package hivemall.fm;
import hivemall.fm.FactorizationMachineModel.VInitScheme;
import hivemall.optimizer.EtaEstimator;
+import hivemall.optimizer.EtaEstimator.InvscalingEtaEstimator;
import hivemall.utils.lang.Primitives;
import javax.annotation.Nonnull;
@@ -46,6 +47,7 @@ class FMHyperParameters {
// V initialization
double sigma = 0.1d;
long seed = -1L;
+ @Nonnull
VInitScheme vInit;
// regression
@@ -53,6 +55,7 @@ class FMHyperParameters {
double maxTarget = Double.MAX_VALUE;
// learning rate
+ @Nonnull
EtaEstimator eta;
// feature hashing
@@ -75,7 +78,10 @@ class FMHyperParameters {
int validationThreshold = 1000;
boolean parseFeatureAsInt = false;
- FMHyperParameters() {}
+ FMHyperParameters() {
+ this.vInit = instantiateVInit();
+ this.eta = new InvscalingEtaEstimator(DEFAULT_ETA0,
EtaEstimator.DEFAULT_POWER_T);
+ }
@Override
public String toString() {
@@ -134,13 +140,22 @@ class FMHyperParameters {
}
@Nonnull
+ private VInitScheme instantiateVInit() {
+ VInitScheme vInit = getDefaultVinitScheme(classification);
+ vInit.setMaxInitValue(0.5f);
+ vInit.setInitStdDev(0.2d);
+ vInit.initRandom(factors, System.nanoTime());
+ return vInit;
+ }
+
+ @Nonnull
private VInitScheme instantiateVInit(@Nonnull CommandLine cl, int factor,
long seed,
final boolean classification) {
String vInitOpt = cl.getOptionValue("init_v");
float maxInitValue =
Primitives.parseFloat(cl.getOptionValue("max_init_value"), 0.5f);
double initStdDev =
Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.1d);
- VInitScheme vInit = VInitScheme.resolve(vInitOpt,
getDefaultVinitScheme());
+ VInitScheme vInit = VInitScheme.resolve(vInitOpt,
getDefaultVinitScheme(classification));
vInit.setMaxInitValue(maxInitValue);
initStdDev = Math.max(initStdDev, 1.0d / factor);
vInit.setInitStdDev(initStdDev);
@@ -149,7 +164,7 @@ class FMHyperParameters {
}
@Nonnull
- protected VInitScheme getDefaultVinitScheme() {
+ protected VInitScheme getDefaultVinitScheme(boolean classification) {
return classification ? VInitScheme.gaussian :
VInitScheme.adjustedRandom;
}
@@ -178,7 +193,7 @@ class FMHyperParameters {
}
@Nonnull
- protected VInitScheme getDefaultVinitScheme() {
+ protected VInitScheme getDefaultVinitScheme(boolean classification) {
return VInitScheme.random;
}
diff --git a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
index 72d64c0..c0ce2a5 100644
--- a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
+++ b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
@@ -34,7 +34,9 @@ public final class FMIntFeatureMapModel extends
FactorizationMachineModel {
// LEARNING PARAMS
private float _w0;
+ @Nonnull
private final Int2FloatMap _w;
+ @Nonnull
private final Int2ObjectMap<float[]> _V;
private int _minIndex, _maxIndex;
diff --git a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java
b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java
index 84b780a..ae15598 100644
--- a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java
+++ b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java
@@ -28,6 +28,7 @@ public final class FMStringFeatureMapModel extends
FactorizationMachineModel {
// LEARNING PARAMS
private float _w0;
+ @Nonnull
private final Object2ObjectMap<String, Entry> _map;
public FMStringFeatureMapModel(@Nonnull FMHyperParameters params) {
diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
index c654f32..a6d7523 100644
--- a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
+++ b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
@@ -23,6 +23,7 @@ import hivemall.utils.lang.NumberUtils;
import hivemall.utils.math.MathUtils;
import java.util.Arrays;
+import java.util.Objects;
import java.util.Random;
import javax.annotation.Nonnegative;
@@ -36,8 +37,11 @@ public abstract class FactorizationMachineModel {
protected final boolean _classification;
protected final int _factor;
protected final double _sigma;
+ @Nonnull
protected final EtaEstimator _eta;
+ @Nonnull
protected final VInitScheme _initScheme;
+ @Nonnull
protected final Random _rnd;
// Hyperparameter for regression
@@ -47,14 +51,15 @@ public abstract class FactorizationMachineModel {
// Regulation Variables
protected float _lambdaW0;
protected float _lambdaW;
+ @Nonnull
protected final float[] _lambdaV;
public FactorizationMachineModel(@Nonnull FMHyperParameters params) {
this._classification = params.classification;
this._factor = params.factors;
this._sigma = params.sigma;
- this._eta = params.eta;
- this._initScheme = params.vInit;
+ this._eta = Objects.requireNonNull(params.eta);
+ this._initScheme = Objects.requireNonNull(params.vInit);
this._rnd = new Random(params.seed);
this._min_target = params.minTarget;
diff --git
a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
index 67040a1..4270ebe 100644
--- a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
+++ b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
@@ -33,6 +33,7 @@ import java.util.zip.GZIPInputStream;
import javax.annotation.Nonnull;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
@@ -258,7 +259,7 @@ public class FieldAwareFactorizationMachineUDTFTest {
cumulativeLoss > udtf._validationState.getCumulativeLoss());
}
- @Test(expected = IllegalArgumentException.class)
+ @Test(expected = UDFArgumentException.class)
public void testUnsupportedAdaptiveRegularizationOption() throws Exception
{
TestUtils.testGenericUDTFSerialization(FieldAwareFactorizationMachineUDTF.class,
new ObjectInspector[] {