Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 533849856 -> 4f795cb9a


[HIVEMALL-222] Introduce gradient clipping to General Classifier/Regressor to avoid exploding gradients

## What changes were proposed in this pull request?

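Avoid [exploding gradients](http://www.cs.toronto.edu/~rgrosse/courses/csc321_2017/readings/L15%20Exploding%20and%20Vanishing%20Gradients.pdf) by clipping gradients by value. The loss derivative (`dloss`) is clamped to the range [-1e+12, 1e+12] before it is used for the weight update.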

## What type of PR is it?

Improvement

## What is the Jira issue?

https://issues.apache.org/jira/browse/HIVEMALL-222

## How was this patch tested?

Unit tests (added `GeneralRegressorUDTFTest#testGradientClippingSGD`).

## Checklist

(Please remove this section if not needed; check `x` for YES, blank for NO)

- [x] Did you apply the source code formatter, i.e., `./bin/format_code.sh`, to your commit?
- [ ] Did you run system tests on Hive (or Spark)?

Author: Makoto Yui <[email protected]>

Closes #169 from myui/clipping.


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/4f795cb9
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/4f795cb9
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/4f795cb9

Branch: refs/heads/master
Commit: 4f795cb9a1ab8b9cb99664f47b2fc80552967aaa
Parents: 5338498
Author: Makoto Yui <[email protected]>
Authored: Wed Oct 24 17:20:56 2018 +0900
Committer: Makoto Yui <[email protected]>
Committed: Wed Oct 24 17:20:56 2018 +0900

----------------------------------------------------------------------
 .../java/hivemall/GeneralLearnerBaseUDTF.java   |  11 ++-
 .../main/java/hivemall/model/FeatureValue.java  |   5 ++
 .../main/java/hivemall/optimizer/Optimizer.java |   4 +-
 .../regression/GeneralRegressorUDTFTest.java    |  82 ++++++++++++++++++-
 .../hivemall/regression/clipping_data.tsv.gz    | Bin 0 -> 7948 bytes
 5 files changed, 99 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4f795cb9/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java 
b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
index 0198e77..4aad70a 100644
--- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
+++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java
@@ -76,6 +76,8 @@ import org.apache.hadoop.mapred.Reporter;
 
 public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
     private static final Log logger = 
LogFactory.getLog(GeneralLearnerBaseUDTF.class);
+    private static final float MAX_DLOSS = 1e+12f;
+    private static final float MIN_DLOSS = -1e+12f;
 
     private ListObjectInspector featureListOI;
     private PrimitiveObjectInspector targetOI;
@@ -168,6 +170,8 @@ public abstract class GeneralLearnerBaseUDTF extends 
LearnerBaseUDTF {
         opts.addOption("loss", "loss_function", true, 
getLossOptionDescription());
         opts.addOption("iter", "iterations", true,
             "The maximum number of iterations [default: 10]");
+        opts.addOption("iters", "iterations", true,
+            "The maximum number of iterations [default: 10]");
         // conversion check
         opts.addOption("disable_cv", "disable_cvtest", false,
             "Whether to disable convergence check [default: OFF]");
@@ -451,11 +455,16 @@ public abstract class GeneralLearnerBaseUDTF extends 
LearnerBaseUDTF {
         float loss = lossFunction.loss(predicted, target);
         cvState.incrLoss(loss); // retain cumulative loss to check convergence
 
-        final float dloss = lossFunction.dloss(predicted, target);
+        float dloss = lossFunction.dloss(predicted, target);
         if (dloss == 0.f) {
             optimizer.proceedStep();
             return;
         }
+        if (dloss < MIN_DLOSS) {
+            dloss = MIN_DLOSS;
+        } else if (dloss > MAX_DLOSS) {
+            dloss = MAX_DLOSS;
+        }
 
         if (is_mini_batch) {
             accumulateUpdate(features, dloss);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4f795cb9/core/src/main/java/hivemall/model/FeatureValue.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/FeatureValue.java 
b/core/src/main/java/hivemall/model/FeatureValue.java
index d7aecd8..209f1ed 100644
--- a/core/src/main/java/hivemall/model/FeatureValue.java
+++ b/core/src/main/java/hivemall/model/FeatureValue.java
@@ -177,4 +177,9 @@ public final class FeatureValue {
         }
     }
 
+    @Override
+    public String toString() {
+        return feature + ":" + value;
+    }
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4f795cb9/core/src/main/java/hivemall/optimizer/Optimizer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java 
b/core/src/main/java/hivemall/optimizer/Optimizer.java
index 0cbac42..587adf2 100644
--- a/core/src/main/java/hivemall/optimizer/Optimizer.java
+++ b/core/src/main/java/hivemall/optimizer/Optimizer.java
@@ -71,7 +71,9 @@ public interface Optimizer {
         protected float update(@Nonnull final IWeightValue weight, final float 
gradient) {
             float oldWeight = weight.get();
             float delta = computeDelta(weight, gradient);
-            float newWeight = oldWeight - _eta.eta(_numStep) * 
_reg.regularize(oldWeight, delta);
+            float eta = _eta.eta(_numStep);
+            float reg = _reg.regularize(oldWeight, delta);
+            float newWeight = oldWeight - eta * reg;
             weight.set(newWeight);
             return newWeight;
         }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4f795cb9/core/src/test/java/hivemall/regression/GeneralRegressorUDTFTest.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/hivemall/regression/GeneralRegressorUDTFTest.java 
b/core/src/test/java/hivemall/regression/GeneralRegressorUDTFTest.java
index 2755340..a2a8696 100644
--- a/core/src/test/java/hivemall/regression/GeneralRegressorUDTFTest.java
+++ b/core/src/test/java/hivemall/regression/GeneralRegressorUDTFTest.java
@@ -22,13 +22,23 @@ import static hivemall.utils.hadoop.HiveUtils.lazyInteger;
 import static hivemall.utils.hadoop.HiveUtils.lazyLong;
 import static hivemall.utils.hadoop.HiveUtils.lazyString;
 
+import hivemall.TestUtils;
+import hivemall.model.FeatureValue;
+import hivemall.utils.lang.StringUtils;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Comparator;
 import java.util.List;
+import java.util.StringTokenizer;
+import java.util.zip.GZIPInputStream;
 
 import javax.annotation.Nonnull;
 
-import hivemall.TestUtils;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.udf.generic.Collector;
@@ -346,6 +356,76 @@ public class GeneralRegressorUDTFTest {
             new Object[][] {{Arrays.asList("1:-2", "2:-1"), 10.f}});
     }
 
+    @Test
+    public void testGradientClippingSGD() throws IOException, HiveException {
+        String filePath = "clipping_data.tsv.gz";
+        String options = "-loss squaredloss -opt SGD -reg no -eta0 0.01 -iter 
1";
+
+        GeneralRegressorUDTF udtf = new GeneralRegressorUDTF();
+
+        ListObjectInspector stringListOI = 
ObjectInspectorFactory.getStandardListObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
+        ObjectInspector params = 
ObjectInspectorUtils.getConstantObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector, 
options);
+
+        udtf.initialize(new ObjectInspector[] {stringListOI,
+                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, 
params});
+
+        BufferedReader reader = readFile(filePath);
+        String line = reader.readLine();
+        for (int i = 0; line != null; i++) {
+            //System.out.println("> " + i);
+            //System.out.println(line);
+
+            StringTokenizer tokenizer = new StringTokenizer(line, " ");
+            double y = Double.parseDouble(tokenizer.nextToken());
+            List<String> X = new ArrayList<String>();
+            while (tokenizer.hasMoreTokens()) {
+                String f = tokenizer.nextToken();
+                X.add(f);
+            }
+            FeatureValue[] features = udtf.parseFeatures(X);
+            if (DEBUG) {
+                printLine(features, y);
+            }
+
+            float yhat = udtf.predict(features);
+            //System.out.println(yhat);
+            if (Float.isNaN(yhat)) {
+                Assert.fail("NaN cause in line: " + i);
+            }
+
+            udtf.process(new Object[] {X, y});
+
+            line = reader.readLine();
+        }
+
+        udtf.finalizeTraining();
+    }
+
+    private static void printLine(FeatureValue[] features, final double y) {
+        Arrays.sort(features, new Comparator<FeatureValue>() {
+            @Override
+            public int compare(FeatureValue o1, FeatureValue o2) {
+                int f1 = Integer.parseInt(o1.getFeatureAsString());
+                int f2 = Integer.parseInt(o2.getFeatureAsString());
+                return Integer.compare(f1, f2);
+            }
+        });
+        System.out.print(y);
+        System.out.print(' ');
+        System.out.println(StringUtils.join(features, ' '));
+    }
+
+    @Nonnull
+    private static BufferedReader readFile(@Nonnull String fileName) throws 
IOException {
+        InputStream is = 
GeneralRegressorUDTFTest.class.getResourceAsStream(fileName);
+        if (fileName.endsWith(".gz")) {
+            is = new GZIPInputStream(is);
+        }
+        return new BufferedReader(new InputStreamReader(is));
+    }
+
     private static void println(String msg) {
         if (DEBUG) {
             System.out.println(msg);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4f795cb9/core/src/test/resources/hivemall/regression/clipping_data.tsv.gz
----------------------------------------------------------------------
diff --git a/core/src/test/resources/hivemall/regression/clipping_data.tsv.gz 
b/core/src/test/resources/hivemall/regression/clipping_data.tsv.gz
new file mode 100644
index 0000000..c55a999
Binary files /dev/null and 
b/core/src/test/resources/hivemall/regression/clipping_data.tsv.gz differ

Reply via email to