Repository: incubator-hivemall Updated Branches: refs/heads/master 533849856 -> 4f795cb9a
[HIVEMALL-222] Introduce Gradient Clipping to avoid exploding gradient to General Classifier/Regressor ## What changes were proposed in this pull request? Avoid [exploding gradients](http://www.cs.toronto.edu/~rgrosse/courses/csc321_2017/readings/L15%20Exploding%20and%20Vanishing%20Gradients.pdf) by gradient clipping (by value) ## What type of PR is it? Improvement ## What is the Jira issue? https://issues.apache.org/jira/browse/HIVEMALL-222 ## How was this patch tested? unit tests ## Checklist (Please remove this section if not needed; check `x` for YES, blank for NO) - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit? - [ ] Did you run system tests on Hive (or Spark)? Author: Makoto Yui <[email protected]> Closes #169 from myui/clipping. Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/4f795cb9 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/4f795cb9 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/4f795cb9 Branch: refs/heads/master Commit: 4f795cb9a1ab8b9cb99664f47b2fc80552967aaa Parents: 5338498 Author: Makoto Yui <[email protected]> Authored: Wed Oct 24 17:20:56 2018 +0900 Committer: Makoto Yui <[email protected]> Committed: Wed Oct 24 17:20:56 2018 +0900 ---------------------------------------------------------------------- .../java/hivemall/GeneralLearnerBaseUDTF.java | 11 ++- .../main/java/hivemall/model/FeatureValue.java | 5 ++ .../main/java/hivemall/optimizer/Optimizer.java | 4 +- .../regression/GeneralRegressorUDTFTest.java | 82 ++++++++++++++++++- .../hivemall/regression/clipping_data.tsv.gz | Bin 0 -> 7948 bytes 5 files changed, 99 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4f795cb9/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java index 0198e77..4aad70a 100644 --- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java @@ -76,6 +76,8 @@ import org.apache.hadoop.mapred.Reporter; public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { private static final Log logger = LogFactory.getLog(GeneralLearnerBaseUDTF.class); + private static final float MAX_DLOSS = 1e+12f; + private static final float MIN_DLOSS = -1e+12f; private ListObjectInspector featureListOI; private PrimitiveObjectInspector targetOI; @@ -168,6 +170,8 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { opts.addOption("loss", "loss_function", true, getLossOptionDescription()); opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]"); + opts.addOption("iters", "iterations", true, + "The maximum number of iterations [default: 10]"); // conversion check opts.addOption("disable_cv", "disable_cvtest", false, "Whether to disable convergence check [default: OFF]"); @@ -451,11 +455,16 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { float loss = lossFunction.loss(predicted, target); cvState.incrLoss(loss); // retain cumulative loss to check convergence - final float dloss = lossFunction.dloss(predicted, target); + float dloss = lossFunction.dloss(predicted, target); if (dloss == 0.f) { optimizer.proceedStep(); return; } + if (dloss < MIN_DLOSS) { + dloss = MIN_DLOSS; + } else if (dloss > MAX_DLOSS) { + dloss = MAX_DLOSS; + } if (is_mini_batch) { accumulateUpdate(features, dloss); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4f795cb9/core/src/main/java/hivemall/model/FeatureValue.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/FeatureValue.java b/core/src/main/java/hivemall/model/FeatureValue.java index d7aecd8..209f1ed 100644 --- a/core/src/main/java/hivemall/model/FeatureValue.java +++ b/core/src/main/java/hivemall/model/FeatureValue.java @@ -177,4 +177,9 @@ public final class FeatureValue { } } + @Override + public String toString() { + return feature + ":" + value; + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4f795cb9/core/src/main/java/hivemall/optimizer/Optimizer.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java b/core/src/main/java/hivemall/optimizer/Optimizer.java index 0cbac42..587adf2 100644 --- a/core/src/main/java/hivemall/optimizer/Optimizer.java +++ b/core/src/main/java/hivemall/optimizer/Optimizer.java @@ -71,7 +71,9 @@ public interface Optimizer { protected float update(@Nonnull final IWeightValue weight, final float gradient) { float oldWeight = weight.get(); float delta = computeDelta(weight, gradient); - float newWeight = oldWeight - _eta.eta(_numStep) * _reg.regularize(oldWeight, delta); + float eta = _eta.eta(_numStep); + float reg = _reg.regularize(oldWeight, delta); + float newWeight = oldWeight - eta * reg; weight.set(newWeight); return newWeight; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4f795cb9/core/src/test/java/hivemall/regression/GeneralRegressorUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/regression/GeneralRegressorUDTFTest.java b/core/src/test/java/hivemall/regression/GeneralRegressorUDTFTest.java index 2755340..a2a8696 100644 --- a/core/src/test/java/hivemall/regression/GeneralRegressorUDTFTest.java +++ b/core/src/test/java/hivemall/regression/GeneralRegressorUDTFTest.java @@ -22,13 +22,23 @@ import static hivemall.utils.hadoop.HiveUtils.lazyInteger; import static hivemall.utils.hadoop.HiveUtils.lazyLong; import static hivemall.utils.hadoop.HiveUtils.lazyString; +import hivemall.TestUtils; +import hivemall.model.FeatureValue; +import hivemall.utils.lang.StringUtils; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; import java.util.ArrayList; import java.util.Arrays; +import java.util.Comparator; import java.util.List; +import java.util.StringTokenizer; +import java.util.zip.GZIPInputStream; import javax.annotation.Nonnull; -import hivemall.TestUtils; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.Collector; @@ -346,6 +356,76 @@ public class GeneralRegressorUDTFTest { new Object[][] {{Arrays.asList("1:-2", "2:-1"), 10.f}}); } + @Test + public void testGradientClippingSGD() throws IOException, HiveException { + String filePath = "clipping_data.tsv.gz"; + String options = "-loss squaredloss -opt SGD -reg no -eta0 0.01 -iter 1"; + + GeneralRegressorUDTF udtf = new GeneralRegressorUDTF(); + + ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector); + ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, options); + + udtf.initialize(new ObjectInspector[] {stringListOI, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, params}); + + BufferedReader reader = readFile(filePath); + String line = reader.readLine(); + for (int i = 0; line != null; i++) { + //System.out.println("> " + i); + //System.out.println(line); + + StringTokenizer tokenizer = new StringTokenizer(line, " "); + double y = Double.parseDouble(tokenizer.nextToken()); + List<String> X = new ArrayList<String>(); + while (tokenizer.hasMoreTokens()) { + String f = tokenizer.nextToken(); + X.add(f); + } + FeatureValue[] features = udtf.parseFeatures(X); + if (DEBUG) { + printLine(features, y); + } + + float yhat = udtf.predict(features); + //System.out.println(yhat); + if (Float.isNaN(yhat)) { + Assert.fail("NaN cause in line: " + i); + } + + udtf.process(new Object[] {X, y}); + + line = reader.readLine(); + } + + udtf.finalizeTraining(); + } + + private static void printLine(FeatureValue[] features, final double y) { + Arrays.sort(features, new Comparator<FeatureValue>() { + @Override + public int compare(FeatureValue o1, FeatureValue o2) { + int f1 = Integer.parseInt(o1.getFeatureAsString()); + int f2 = Integer.parseInt(o2.getFeatureAsString()); + return Integer.compare(f1, f2); + } + }); + System.out.print(y); + System.out.print(' '); + System.out.println(StringUtils.join(features, ' ')); + } + + @Nonnull + private static BufferedReader readFile(@Nonnull String fileName) throws IOException { + InputStream is = GeneralRegressorUDTFTest.class.getResourceAsStream(fileName); + if (fileName.endsWith(".gz")) { + is = new GZIPInputStream(is); + } + return new BufferedReader(new InputStreamReader(is)); + } + private static void println(String msg) { if (DEBUG) { System.out.println(msg); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4f795cb9/core/src/test/resources/hivemall/regression/clipping_data.tsv.gz ---------------------------------------------------------------------- diff --git a/core/src/test/resources/hivemall/regression/clipping_data.tsv.gz b/core/src/test/resources/hivemall/regression/clipping_data.tsv.gz new file mode 100644 index 0000000..c55a999 Binary files /dev/null and b/core/src/test/resources/hivemall/regression/clipping_data.tsv.gz differ
