[HIVEMALL-201] Evaluate, fix and document FFM

## What changes were proposed in this pull request?

Applied some refactoring to #149. This PR closes #149.

## What type of PR is it?

Hot Fix, Refactoring

## What is the Jira issue?

https://issues.apache.org/jira/browse/HIVEMALL-201

## How was this patch tested?

Unit tests, manual tests

## How to use this feature?

The user guide will be published at: http://hivemall.incubator.apache.org/userguide/binaryclass/criteo_ffm.html (a short usage sketch, modeled on the unit tests added by this commit, appears at the end of this section)

## Checklist (Please remove this section if not needed; check `x` for YES, blank for NO)

- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [x] Did you run system tests on Hive (or Spark)?

Author: Takuya Kitazawa <[email protected]>
Author: Makoto Yui <[email protected]>

Closes #155 from myui/HIVEMALL-201-2.

Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/61711fbc
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/61711fbc
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/61711fbc

Branch: refs/heads/master
Commit: 61711fbc2be109a200f1773958b5a8c519f5a066
Parents: b88e9f5
Author: Takuya Kitazawa <[email protected]>
Authored: Thu Aug 23 20:05:04 2018 +0900
Committer: Makoto Yui <[email protected]>
Committed: Thu Aug 23 20:05:04 2018 +0900

----------------------------------------------------------------------
 .../java/hivemall/fm/FMHyperParameters.java     |  61 +++-
 .../hivemall/fm/FactorizationMachineModel.java  |  63 +++-
 .../hivemall/fm/FactorizationMachineUDTF.java   | 151 +++++---
 .../fm/FieldAwareFactorizationMachineModel.java |  10 +-
 .../fm/FieldAwareFactorizationMachineUDTF.java  |  32 +-
 .../ftvec/pairing/FeaturePairsUDTF.java         |  10 +-
 .../ftvec/scaling/L1NormalizationUDF.java       |   5 +
 .../ftvec/scaling/L2NormalizationUDF.java       |   5 +
 .../hivemall/mf/BPRMatrixFactorizationUDTF.java |  20 +-
 .../mf/OnlineMatrixFactorizationUDTF.java       |  20 +-
 .../hivemall/tools/mapred/RowNumberUDF.java     |   3 +-
 .../java/hivemall/utils/lang/Primitives.java    |   3 +
 .../main/java/hivemall/utils/lang/SizeOf.java   |   1 +
 .../fm/FactorizationMachineUDTFTest.java        | 135 +++++++
 .../FieldAwareFactorizationMachineUDTFTest.java | 185 ++++----
 .../ftvec/scaling/L1NormalizationUDFTest.java   |   6 +
 .../ftvec/scaling/L2NormalizationUDFTest.java   |   6 +
 docs/gitbook/SUMMARY.md                         |   3 +
 docs/gitbook/binaryclass/criteo.md              |  20 ++
 docs/gitbook/binaryclass/criteo_dataset.md      |  97 +++++
 docs/gitbook/binaryclass/criteo_ffm.md          | 356 +++++++++++++++++++
 21 files changed, 1011 insertions(+), 181 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/fm/FMHyperParameters.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FMHyperParameters.java b/core/src/main/java/hivemall/fm/FMHyperParameters.java index 0992325..edee14f 100644 --- a/core/src/main/java/hivemall/fm/FMHyperParameters.java +++ b/core/src/main/java/hivemall/fm/FMHyperParameters.java @@ -28,7 +28,8 @@ import org.apache.commons.cli.CommandLine; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; class FMHyperParameters { - private static final float DEFAULT_ETA0 = 0.05f; + protected static final float DEFAULT_ETA0 = 0.1f; + protected static final float DEFAULT_LAMBDA = 0.0001f; // ------------------------------------- // Model parameters // ------------------------------------- @@ -37,10 +38,10 @@ class FMHyperParameters { int factors = 5; // regularization - float lambda = 0.01f; - float lambdaW0 = 0.01f; - 
float lambdaW = 0.01f; - float lambdaV = 0.01f; + float lambda = DEFAULT_LAMBDA; + float lambdaW0; + float lambdaW; + float lambdaV; // V initialization double sigma = 0.1d; @@ -62,10 +63,12 @@ class FMHyperParameters { boolean l2norm; // enable by default for FFM. disabled by default for FM. - int iters = 1; + int iters = 10; boolean conversionCheck = true; double convergenceRate = 0.005d; + boolean earlyStopping = false; + // adaptive regularization boolean adaptiveRegularization = false; float validationRatio = 0.05f; @@ -89,10 +92,14 @@ class FMHyperParameters { void processOptions(@Nonnull CommandLine cl) throws UDFArgumentException { this.classification = cl.hasOption("classification"); - this.factors = Primitives.parseInt(cl.getOptionValue("factors"), factors); + if (cl.hasOption("factor")) { + this.factors = Primitives.parseInt(cl.getOptionValue("factor"), factors); + } else { + this.factors = Primitives.parseInt(cl.getOptionValue("factors"), factors); + } this.lambda = Primitives.parseFloat(cl.getOptionValue("lambda"), lambda); this.lambdaW0 = Primitives.parseFloat(cl.getOptionValue("lambda_w0"), lambda); - this.lambdaW = Primitives.parseFloat(cl.getOptionValue("lambda_w"), lambda); + this.lambdaW = Primitives.parseFloat(cl.getOptionValue("lambda_wi"), lambda); this.lambdaV = Primitives.parseFloat(cl.getOptionValue("lambda_v"), lambda); this.sigma = Primitives.parseDouble(cl.getOptionValue("sigma"), sigma); this.seed = Primitives.parseLong(cl.getOptionValue("seed"), seed); @@ -105,10 +112,15 @@ class FMHyperParameters { this.eta = EtaEstimator.get(cl, DEFAULT_ETA0); this.numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"), numFeatures); this.l2norm = cl.hasOption("enable_norm"); - this.iters = Primitives.parseInt(cl.getOptionValue("iterations"), iters); + if (cl.hasOption("iter")) { + this.iters = Primitives.parseInt(cl.getOptionValue("iter"), iters); + } else { + this.iters = Primitives.parseInt(cl.getOptionValue("iterations"), iters); + } this.conversionCheck = !cl.hasOption("disable_cvtest"); this.convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate); + this.earlyStopping = cl.hasOption("early_stopping"); this.adaptiveRegularization = cl.hasOption("adaptive_regularization"); this.validationRatio = Primitives.parseFloat(cl.getOptionValue("validation_ratio"), validationRatio); @@ -122,14 +134,13 @@ class FMHyperParameters { } @Nonnull - private static VInitScheme instantiateVInit(@Nonnull CommandLine cl, int factor, long seed, + private VInitScheme instantiateVInit(@Nonnull CommandLine cl, int factor, long seed, final boolean classification) { String vInitOpt = cl.getOptionValue("init_v"); float maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), 0.5f); double initStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.1d); - VInitScheme defaultInit = classification ? VInitScheme.gaussian : VInitScheme.random; - VInitScheme vInit = VInitScheme.resolve(vInitOpt, defaultInit); + VInitScheme vInit = VInitScheme.resolve(vInitOpt, getDefaultVinitScheme()); vInit.setMaxInitValue(maxInitValue); initStdDev = Math.max(initStdDev, 1.0d / factor); vInit.setInitStdDev(initStdDev); @@ -137,11 +148,16 @@ class FMHyperParameters { return vInit; } + @Nonnull + protected VInitScheme getDefaultVinitScheme() { + return classification ? 
VInitScheme.gaussian : VInitScheme.adjustedRandom; + } + public static final class FFMHyperParameters extends FMHyperParameters { // FFM hyper parameters boolean globalBias = false; - boolean linearCoeff = true; + boolean linearCoeff = false; // feature hashing int numFields = Feature.DEFAULT_NUM_FIELDS; @@ -152,15 +168,20 @@ class FMHyperParameters { // FTRL boolean useFTRL = false; - float alphaFTRL = 0.2f; // Learning Rate - float betaFTRL = 1.f; // Smoothing parameter for AdaGrad - float lambda1 = 0.001f; // L1 Regularization + float alphaFTRL = 0.5f; // Learning Rate + float betaFTRL = 1.0f; // Smoothing parameter for AdaGrad + float lambda1 = 0.0002f; // L1 Regularization float lambda2 = 0.0001f; // L2 Regularization FFMHyperParameters() { super(); } + @Nonnull + protected VInitScheme getDefaultVinitScheme() { + return VInitScheme.random; + } + @Override void processOptions(@Nonnull CommandLine cl) throws UDFArgumentException { super.processOptions(cl); @@ -170,7 +191,13 @@ class FMHyperParameters { } this.globalBias = cl.hasOption("global_bias"); - this.linearCoeff = !cl.hasOption("no_coeff"); + this.linearCoeff = cl.hasOption("linear_term"); + + if (cl.hasOption("enable_norm") && cl.hasOption("disable_norm")) { + throw new UDFArgumentException( + "-enable_norm and -disable_norm MUST NOT be used simultaneously"); + } + this.l2norm = !cl.hasOption("disable_norm"); // feature hashing if (numFeatures == -1) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/fm/FactorizationMachineModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java index bb97bef..c654f32 100644 --- a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java +++ b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java @@ -271,7 +271,7 @@ public abstract class FactorizationMachineModel { * sum_f_dash := \sum_{j} x_j * v'_lj, this is independent of the groups * sum_f(g) := \sum_{j \in group(g)} x_j * v_jf * sum_f_dash_f(g) := \sum_{j \in group(g)} x^2_j * v_jf * v'_jf - * := \sum_{j \in group(g)} x_j * v'_jf * x_j * v_jf + * := \sum_{j \in group(g)} x_j * v'_jf * x_j * v_jf * v_jf' := v_jf - alpha ( grad_v_jf + 2 * lambda_v_f * v_jf) * </pre> */ @@ -336,7 +336,7 @@ public abstract class FactorizationMachineModel { public void check(@Nonnull Feature[] x) throws HiveException {} public enum VInitScheme { - random /* default */, gaussian; + adjustedRandom /* default */, libffmRandom, random, gaussian; @Nonnegative float maxInitValue; @@ -346,7 +346,7 @@ public abstract class FactorizationMachineModel { @Nonnull public static VInitScheme resolve(@Nullable String opt) { - return resolve(opt, random); + return resolve(opt, adjustedRandom); } @Nonnull @@ -354,10 +354,16 @@ public abstract class FactorizationMachineModel { @Nonnull VInitScheme defaultScheme) { if (opt == null) { return defaultScheme; - } else if ("gaussian".equalsIgnoreCase(opt)) { - return gaussian; + } else if ("adjusted_random".equalsIgnoreCase(opt) + || "adjustedRandom".equalsIgnoreCase(opt)) { + return adjustedRandom; + } else if ("libffm_random".equalsIgnoreCase(opt) || "libffmRandom".equalsIgnoreCase(opt) + || "libffm".equalsIgnoreCase(opt)) { + return VInitScheme.libffmRandom; } else if ("random".equalsIgnoreCase(opt)) { return random; + } else if ("gaussian".equalsIgnoreCase(opt)) { + return gaussian; } return defaultScheme; } 
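For reference, the aliases accepted by the extended resolve() above map onto the new schemes as follows; this is a summary sketch of the code in this diff (the fill scales come from the fill methods shown below), and the sample call is illustrative only:

    // -init_v values resolved by VInitScheme.resolve(opt, defaultScheme):
    //   "adjusted_random" / "adjustedRandom"        -> adjustedRandom (fill scale: maxInitValue / k)
    //   "libffm_random" / "libffmRandom" / "libffm" -> libffmRandom   (fill scale: maxInitValue / sqrt(k))
    //   "random"                                    -> random         (fill scale: maxInitValue)
    //   "gaussian"                                  -> gaussian       (per-factor Gaussian with the given stddev)
    //   null / unrecognized                         -> defaultScheme
    VInitScheme vInit = VInitScheme.resolve("libffm"); // == VInitScheme.libffmRandom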
@@ -371,7 +377,7 @@ public abstract class FactorizationMachineModel { } public void initRandom(int factor, long seed) { - int size = (this == random) ? 1 : factor; + final int size = (this != gaussian) ? 1 : factor; this.rand = new Random[size]; for (int i = 0; i < size; i++) { rand[i] = new Random(seed + i); @@ -383,8 +389,14 @@ public abstract class FactorizationMachineModel { protected final float[] initV() { final float[] ret = new float[_factor]; switch (_initScheme) { + case adjustedRandom: + adjustedRandomFill(ret, _initScheme.rand[0], _initScheme.maxInitValue); + break; + case libffmRandom: + libffmRandomFill(ret, _initScheme.rand[0], _initScheme.maxInitValue); + break; case random: - uniformFill(ret, _initScheme.rand[0], _initScheme.maxInitValue); + randomFill(ret, _initScheme.rand[0], _initScheme.maxInitValue); break; case gaussian: gaussianFill(ret, _initScheme.rand, _initScheme.initStdDev); @@ -396,19 +408,42 @@ public abstract class FactorizationMachineModel { return ret; } - protected static final void uniformFill(final float[] a, final Random rand, - final float maxInitValue) { - final int len = a.length; - final float basev = maxInitValue / len; - for (int i = 0; i < len; i++) { + protected static final void adjustedRandomFill(@Nonnull final float[] a, + @Nonnull final Random rand, final float maxInitValue) { + final int k = a.length; + final float basev = maxInitValue / k; + for (int i = 0; i < k; i++) { float v = rand.nextFloat() * basev; a[i] = v; } } - protected static final void gaussianFill(final float[] a, final Random[] rand, + // libffm's V initialization scheme: 1/sqrt(k) + // https://github.com/guestwalk/libffm/blob/master/ffm.cpp#L287 + protected static final void libffmRandomFill(@Nonnull final float[] a, + @Nonnull final Random rand, final float maxInitValue) { + final int k = a.length; + final float basev = maxInitValue / (float) Math.sqrt(k); + for (int i = 0; i < k; i++) { + float v = rand.nextFloat() * basev; + a[i] = v; + } + } + + protected static final void randomFill(@Nonnull final float[] a, @Nonnull final Random rand, + final float maxInitValue) { + final int k = a.length; + for (int i = 0; i < k; i++) { + float v = rand.nextFloat() * maxInitValue; + a[i] = v; + } + } + + // libfm uses gaussian for initialization + // https://github.com/srendle/libfm/blob/30b9c799c41d043f31565cbf827bf41d0dc3e2ab/src/fm_core/fm_model.h#L96 + protected static final void gaussianFill(@Nonnull final float[] a, @Nonnull final Random[] rand, final double stddev) { - for (int i = 0, len = a.length; i < len; i++) { + for (int i = 0, k = a.length; i < k; i++) { float v = (float) MathUtils.gaussian(0.d, stddev, rand[i]); a[i] = v; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java index eadd451..a253729 100644 --- a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java +++ b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java @@ -18,6 +18,11 @@ */ package hivemall.fm; +import static hivemall.fm.FMHyperParameters.DEFAULT_ETA0; +import static hivemall.fm.FMHyperParameters.DEFAULT_LAMBDA; +import static hivemall.utils.lang.Primitives.FALSE_BYTE; +import static hivemall.utils.lang.Primitives.TRUE_BYTE; + import hivemall.UDTFWithOptions; import 
hivemall.annotations.VisibleForTesting; import hivemall.common.ConversionState; @@ -65,6 +70,8 @@ import org.apache.hadoop.io.Text; import org.apache.hadoop.mapred.Counters.Counter; import org.apache.hadoop.mapred.Reporter; +import com.google.common.base.Preconditions; + @Description(name = "train_fm", value = "_FUNC_(array<string> x, double y [, const string options]) - Returns a prediction model") public class FactorizationMachineUDTF extends UDTFWithOptions { @@ -89,7 +96,11 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { protected int _factors; protected boolean _parseFeatureAsInt; + protected boolean _earlyStopping; + protected ConversionState _validationState; + // adaptive regularization + protected boolean _adaptiveRegularization; @Nullable protected Random _va_rand; protected float _validationRatio; @@ -107,6 +118,10 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { * The number of training examples processed */ protected long _t; + /** + * The number of validation examples + */ + protected long _numValidations; // file IO private ByteBuffer _inputBuf; @@ -117,24 +132,28 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { Options opts = new Options(); opts.addOption("c", "classification", false, "Act as classification"); opts.addOption("seed", true, "Seed value [default: -1 (random)]"); - opts.addOption("iters", "iterations", true, "The number of iterations [default: 1]"); + opts.addOption("iters", "iterations", true, "The number of iterations [default: 10]"); + opts.addOption("iter", true, "The number of iterations [default: 10]." + + " Note this is alias of `iters` for backward compatibility"); opts.addOption("p", "num_features", true, "The size of feature dimensions [default: -1]"); opts.addOption("f", "factors", true, "The number of the latent variables [default: 5]"); + opts.addOption("k", "factor", true, + "The number of the latent variables [default: 5]" + " Alias of `-factors` option"); opts.addOption("sigma", true, "The standard deviation for initializing V [default: 0.1]"); opts.addOption("lambda0", "lambda", true, - "The initial lambda value for regularization [default: 0.01]"); + "The initial lambda value for regularization [default: " + DEFAULT_LAMBDA + "]"); opts.addOption("lambdaW0", "lambda_w0", true, - "The initial lambda value for W0 regularization [default: 0.01]"); + "The initial lambda value for W0 regularization [default: " + DEFAULT_LAMBDA + "]"); opts.addOption("lambdaWi", "lambda_wi", true, - "The initial lambda value for Wi regularization [default: 0.01]"); + "The initial lambda value for Wi regularization [default: " + DEFAULT_LAMBDA + "]"); opts.addOption("lambdaV", "lambda_v", true, - "The initial lambda value for V regularization [default: 0.01]"); + "The initial lambda value for V regularization [default: " + DEFAULT_LAMBDA + "]"); // regression opts.addOption("min", "min_target", true, "The minimum value of target variable"); opts.addOption("max", "max_target", true, "The maximum value of target variable"); // learning rates opts.addOption("eta", true, "The initial learning rate"); - opts.addOption("eta0", true, "The initial learning rate [default 0.05]"); + opts.addOption("eta0", true, "The initial learning rate [default " + DEFAULT_ETA0 + "]"); opts.addOption("t", "total_steps", true, "The total number of training examples"); opts.addOption("power_t", true, "The exponent for inverse scaling learning rate [default 0.1]"); @@ -143,19 +162,22 @@ public class FactorizationMachineUDTF extends 
UDTFWithOptions { "Whether to disable convergence check [default: OFF]"); opts.addOption("cv_rate", "convergence_rate", true, "Threshold to determine convergence [default: 0.005]"); - // adaptive regularization + // adaptive regularization and early stopping with randomly hold-out validation samples + opts.addOption("early_stopping", false, + "Stop at the iteration that achieves the best validation on partial samples [default: OFF]"); + opts.addOption("va_ratio", "validation_ratio", true, + "Ratio of training data used for validation [default: 0.05f]"); + opts.addOption("va_threshold", "validation_threshold", true, + "Threshold to start validation. " + + "At least N training examples are used before validation [default: 1000]"); if (isAdaptiveRegularizationSupported()) { opts.addOption("adareg", "adaptive_regularization", false, "Whether to enable adaptive regularization [default: OFF]"); - opts.addOption("va_ratio", "validation_ratio", true, - "Ratio of training data used for validation [default: 0.05f]"); - opts.addOption("va_threshold", "validation_threshold", true, - "Threshold to start validation. " - + "At least N training examples are used before validation [default: 1000]"); } // initialization of V - opts.addOption("init_v", true, "Initialization strategy of matrix V [random, gaussian]" - + "(default: 'random' for regression / 'gaussian' for classification)"); + opts.addOption("init_v", true, + "Initialization strategy of matrix V [adjusted_random, libffm, random, gaussian]" + + "(FM default: 'adjusted_random' for regression, 'gaussian' for classification, FFM default: random)"); opts.addOption("maxval", "max_init_value", true, "The maximum initial value in the matrix V [default: 0.5]"); opts.addOption("min_init_stddev", true, @@ -188,9 +210,12 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { this._iterations = params.iters; this._factors = params.factors; this._parseFeatureAsInt = params.parseFeatureAsInt; - if (params.adaptiveRegularization) { + this._earlyStopping = params.earlyStopping; + this._adaptiveRegularization = params.adaptiveRegularization; + if (_earlyStopping || _adaptiveRegularization) { this._va_rand = new Random(params.seed + 31L); } + this._validationState = new ConversionState(); this._validationRatio = params.validationRatio; this._validationThreshold = params.validationThreshold; this._lossFunction = params.classification ? LossFunctions.getLossFunction(LossType.LogLoss) @@ -216,6 +241,7 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { this._model = null; this._t = 0L; + this._numValidations = 0L; if (LOG.isInfoEnabled()) { LOG.info(_params); @@ -276,16 +302,23 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { return; } this._probes = x; + _model.check(x); // mostly for FMIntFeatureMapModel double y = PrimitiveObjectInspectorUtils.getDouble(args[1], _yOI); if (_classification) { y = (y > 0.d) ? 
1.d : -1.d; } - ++_t; - recordTrain(x, y); - boolean adaptiveRegularization = (_va_rand != null) && _t >= _validationThreshold; - train(x, y, adaptiveRegularization); + boolean validation = isValidationExample(); + recordTrain(x, y, validation); + train(x, y, validation); + } + + private boolean isValidationExample() { + if (_va_rand != null && _t >= _validationThreshold) { + return _va_rand.nextFloat() < _validationRatio; + } + return false; } @Nullable @@ -297,7 +330,8 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { return features; } - protected void recordTrain(@Nonnull final Feature[] x, final double y) throws HiveException { + private void recordTrain(@Nonnull final Feature[] x, final double y, final boolean validation) + throws HiveException { if (_iterations <= 1) { return; } @@ -325,7 +359,7 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { } int xBytes = Feature.requiredBytes(x); - int recordBytes = SizeOf.INT + SizeOf.DOUBLE + xBytes; + int recordBytes = SizeOf.INT + SizeOf.DOUBLE + xBytes + SizeOf.BYTE; int requiredBytes = SizeOf.INT + recordBytes; int remain = inputBuf.remaining(); if (remain < requiredBytes) { @@ -338,6 +372,12 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { f.writeTo(inputBuf); } inputBuf.putDouble(y); + if (validation) { + ++_numValidations; + inputBuf.put(TRUE_BYTE); + } else { + inputBuf.put(FALSE_BYTE); + } } private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefulSegment dst) @@ -351,20 +391,13 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { srcBuf.clear(); } - public void train(@Nonnull final Feature[] x, final double y, - final boolean adaptiveRegularization) throws HiveException { - _model.check(x); - + private void train(@Nonnull final Feature[] x, final double y, final boolean validation) + throws HiveException { try { - if (adaptiveRegularization) { - assert (_va_rand != null); - final float rnd = _va_rand.nextFloat(); - if (rnd < _validationRatio) { - trainLambda(x, y); // adaptive regularization - } else { - trainTheta(x, y); - } + if (validation) { + processValidationSample(x, y); } else { + ++_t; trainTheta(x, y); } } catch (Exception ex) { @@ -372,6 +405,18 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { } } + protected void processValidationSample(@Nonnull final Feature[] x, final double y) + throws HiveException { + if (_earlyStopping) { + double p = _model.predict(x); + double loss = _lossFunction.loss(p, y); + _validationState.incrLoss(loss); + } + if (_adaptiveRegularization) { + trainLambda(x, y); // adaptive regularization + } + } + /** * Update model parameters */ @@ -410,7 +455,7 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { * grad_lambdafg = (grad l(p,y)) * (-2 * alpha * (\sum_{l} x_l * v'_lf) * \sum_{l \in group(g)} x_l * v_lf) - \sum_{l \in group(g)} x^2_l * v_lf * v'_lf) * </pre> */ - protected void trainLambda(final Feature[] x, final double y) throws HiveException { + private void trainLambda(final Feature[] x, final double y) throws HiveException { final float eta = _etaEstimator.eta(_t); final double p = _model.predict(x); final double lossGrad = _model.dloss(p, y); @@ -534,12 +579,11 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { } protected void runTrainingIteration(int iterations) throws HiveException { - final ByteBuffer inputBuf = this._inputBuf; - final NioStatefulSegment fileIO = this._fileIO; - assert (inputBuf != null); - assert (fileIO != null); + final 
ByteBuffer inputBuf = Preconditions.checkNotNull(this._inputBuf); + final NioStatefulSegment fileIO = Preconditions.checkNotNull(this._fileIO); + final long numTrainingExamples = _t; - final boolean adaregr = _va_rand != null; + boolean lossIncreasedLastIter = false; final Reporter reporter = getReporter(); final Counter iterCounter = (reporter == null) ? null @@ -553,6 +597,7 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { inputBuf.flip(); for (int iter = 2; iter <= iterations; iter++) { + _validationState.next(); _cvState.next(); reportProgress(reporter); setCounterValue(iterCounter, iter); @@ -566,19 +611,25 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { x[j] = instantiateFeature(inputBuf); } double y = inputBuf.getDouble(); + boolean validation = (inputBuf.get() == TRUE_BYTE); + // invoke train - ++_t; - train(x, y, adaregr); + train(x, y, validation); } - if (_cvState.isConverged(numTrainingExamples)) { + // stop if validation loss is consecutively increased over recent 2 iterations + final boolean lossIncreased = _validationState.isLossIncreased(); + if ((lossIncreasedLastIter && lossIncreased) + || _cvState.isConverged(numTrainingExamples)) { break; } + lossIncreasedLastIter = lossIncreased; inputBuf.rewind(); } LOG.info("Performed " + _cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on memory (thus " + NumberUtils.formatNumber(_t) - + " training updates in total) "); + + " training updates in total), used " + _numValidations + + " validation examples"); } else {// read training examples in the temporary file and invoke train for each example // write training examples in buffer to a temporary file @@ -601,6 +652,7 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { // run iterations for (int iter = 2; iter <= iterations; iter++) { + _validationState.next(); _cvState.next(); setCounterValue(iterCounter, iter); @@ -643,23 +695,28 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { x[j] = instantiateFeature(inputBuf); } double y = inputBuf.getDouble(); + boolean validation = (inputBuf.get() == TRUE_BYTE); // invoke training - ++_t; - train(x, y, adaregr); + train(x, y, validation); remain -= recordBytes; } inputBuf.compact(); } - if (_cvState.isConverged(numTrainingExamples)) { + // stop if validation loss is consecutively increased over recent 2 iterations + final boolean lossIncreased = _validationState.isLossIncreased(); + if ((lossIncreasedLastIter && lossIncreased) + || _cvState.isConverged(numTrainingExamples)) { break; } + lossIncreasedLastIter = lossIncreased; } LOG.info("Performed " + _cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on a secondary storage (thus " - + NumberUtils.formatNumber(_t) + " training updates in total)"); + + NumberUtils.formatNumber(_t) + " training updates in total), used " + + _numValidations + " validation examples"); } } finally { // delete the temporary file and release resources http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java index c6c0fd0..6cd8fe8 100644 --- 
a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java +++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java @@ -36,13 +36,12 @@ public abstract class FieldAwareFactorizationMachineModel extends FactorizationM @Nonnull protected final FFMHyperParameters _params; - protected final float _eta0; protected final float _eps; protected final boolean _useAdaGrad; protected final boolean _useFTRL; - // FTEL + // FTRL private final float _alpha; private final float _beta; private final float _lambda1; @@ -51,11 +50,6 @@ public abstract class FieldAwareFactorizationMachineModel extends FactorizationM public FieldAwareFactorizationMachineModel(@Nonnull FFMHyperParameters params) { super(params); this._params = params; - if (params.useAdaGrad) { - this._eta0 = 1.0f; - } else { - this._eta0 = params.eta.eta0(); - } this._eps = params.eps; this._useAdaGrad = params.useAdaGrad; this._useFTRL = params.useFTRL; @@ -261,7 +255,7 @@ public abstract class FieldAwareFactorizationMachineModel extends FactorizationM if (_useAdaGrad) { double gg = theta.getSumOfSquaredGradients(f); theta.addGradient(f, grad); - return (float) (_eta0 / Math.sqrt(_eps + gg)); + return (float) (_eta.eta(t) / Math.sqrt(_eps + gg)); } else { return _eta.eta(t); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java index 610fa3d..7987086 100644 --- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java +++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java @@ -52,7 +52,7 @@ import org.apache.hadoop.io.Text; /** * Field-aware Factorization Machines. - * + * * @link https://www.csie.ntu.edu.tw/~cjlin/libffm/ * @since v0.5-rc.1 */ @@ -70,7 +70,7 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi private int _numFields; // ---------------------------------------- - private transient FFMStringFeatureMapModel _ffmModel; + protected transient FFMStringFeatureMapModel _ffmModel; private transient IntArrayList _fieldList; @Nullable @@ -85,12 +85,13 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi Options opts = super.getOptions(); opts.addOption("w0", "global_bias", false, "Whether to include global bias term w0 [default: OFF]"); - opts.addOption("disable_wi", "no_coeff", false, - "Not to include linear term [default: OFF]"); + opts.addOption("enable_wi", "linear_term", false, "Include linear term [default: OFF]"); + opts.addOption("no_norm", "disable_norm", false, "Disable instance-wise L2 normalization"); // feature hashing opts.addOption("feature_hashing", true, "The number of bits for feature hashing in range [18,31] [default: -1]. 
No feature hashing for -1."); - opts.addOption("num_fields", true, "The number of fields [default: 256]"); + opts.addOption("num_fields", true, + "The number of fields [default: " + Feature.DEFAULT_NUM_FIELDS + "]"); // optimizer opts.addOption("opt", "optimizer", true, "Gradient Descent optimizer [default: ftrl, adagrad, sgd]"); @@ -98,11 +99,11 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default: 1.0]"); // FTRL opts.addOption("alpha", "alphaFTRL", true, - "Alpha value (learning rate) of Follow-The-Regularized-Reader [default: 0.2]"); + "Alpha value (learning rate) of Follow-The-Regularized-Reader [default: 0.5]"); opts.addOption("beta", "betaFTRL", true, "Beta value (a learning smoothing parameter) of Follow-The-Regularized-Reader [default: 1.0]"); opts.addOption("l1", "lambda1", true, - "L1 regularization value of Follow-The-Regularized-Reader that controls model Sparseness [default: 0.001]"); + "L1 regularization value of Follow-The-Regularized-Reader that controls model Sparseness [default: 0.0002]"); opts.addOption("l2", "lambda2", true, "L2 regularization value of Follow-The-Regularized-Reader [default: 0.0001]"); return opts; @@ -180,13 +181,12 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi } @Override - public void train(@Nonnull final Feature[] x, final double y, - final boolean adaptiveRegularization) throws HiveException { - _ffmModel.check(x); - try { - trainTheta(x, y); - } catch (Exception ex) { - throw new HiveException("Exception caused in the " + _t + "-th call of train()", ex); + protected void processValidationSample(@Nonnull final Feature[] x, final double y) + throws HiveException { + if (_earlyStopping) { + double p = _model.predict(x); + double loss = _lossFunction.loss(p, y); + _validationState.incrLoss(loss); } } @@ -292,7 +292,7 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi forward(forwardObjs); final Entry entryW = new Entry(_ffmModel._buf, 1); - final Entry entryV = new Entry(_ffmModel._buf, _ffmModel._factor); + final Entry entryV = new Entry(_ffmModel._buf, factors); final float[] Vf = new float[factors]; for (Int2LongMap.Entry e : Fastutil.fastIterable(_ffmModel._map)) { @@ -303,7 +303,7 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi final long offset = e.getLongValue(); if (Entry.isEntryW(i)) {// set Wi entryW.setOffset(offset); - float w = entryV.getW(); + float w = entryW.getW(); if (w == 0.f) { continue; // skip w_i=0 } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java index 3f959e5..c46e470 100644 --- a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java +++ b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java @@ -55,6 +55,7 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { private RowProcessor _proc; private int _numFields; private int _numFeatures; + private boolean _l2norm; public FeaturePairsUDTF() {} @@ -69,7 +70,9 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { opts.addOption("p", "num_features", true, "The size of feature dimensions [default: -1]"); 
opts.addOption("feature_hashing", true, "The number of bits for feature hashing in range [18,31]. [default: -1] No feature hashing for -1."); - opts.addOption("num_fields", true, "The number of fields [default:1024]"); + opts.addOption("num_fields", true, + "The number of fields [default: " + Feature.DEFAULT_NUM_FIELDS + "]"); + opts.addOption("no_norm", "disable_norm", false, "Disable instance-wise L2 normalization"); return opts; } @@ -104,6 +107,7 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { throw new UDFArgumentException( "-num_fields MUST be greater than 1: " + _numFields); } + this._l2norm = !cl.hasOption("disable_norm"); } else { throw new UDFArgumentException("Unsupported option: " + cl.getArgList().get(0)); } @@ -285,6 +289,10 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { this._features = Feature.parseFFMFeatures(arg, fvOI, _features, _numFeatures, _numFields); + if (_l2norm) { + Feature.l2normalize(_features); + } + // W0 f0.set(0); forward[1] = null; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/ftvec/scaling/L1NormalizationUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/scaling/L1NormalizationUDF.java b/core/src/main/java/hivemall/ftvec/scaling/L1NormalizationUDF.java index 45ef97d..e5de329 100644 --- a/core/src/main/java/hivemall/ftvec/scaling/L1NormalizationUDF.java +++ b/core/src/main/java/hivemall/ftvec/scaling/L1NormalizationUDF.java @@ -56,6 +56,11 @@ public final class L1NormalizationUDF extends UDF { float v = Float.parseFloat(ft[1]); weights[i] = v; absoluteSum += Math.abs(v); + } else if (ftlen == 3) { + features[i] = ft[0] + ':' + ft[1]; + float v = Float.parseFloat(ft[2]); + weights[i] = v; + absoluteSum += Math.abs(v); } else { throw new HiveException("Invalid feature value representation: " + s); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDF.java b/core/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDF.java index 9cf315c..fa70f10 100644 --- a/core/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDF.java +++ b/core/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDF.java @@ -59,6 +59,11 @@ public final class L2NormalizationUDF extends UDF { float v = Float.parseFloat(ft[1]); weights[i] = v; squaredSum += (v * v); + } else if (ftlen == 3) { + features[i] = ft[0] + ':' + ft[1]; + float v = Float.parseFloat(ft[2]); + weights[i] = v; + squaredSum += (v * v); } else { throw new HiveException("Invalid feature value representation: " + s); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java index 76d52ab..23d9b63 100644 --- a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java +++ b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java @@ -137,8 +137,12 @@ public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements @Override protected Options getOptions() { Options opts = new Options(); - opts.addOption("k", "factor", true, 
"The number of latent factor [default: 10]"); - opts.addOption("iter", "iterations", true, "The number of iterations [default: 30]"); + opts.addOption("k", "factor", true, + "The number of latent factor [default: 10] Alias for `-factors`"); + opts.addOption("f", "factors", true, "The number of latent factor [default: 10]"); + opts.addOption("iters", "iterations", true, "The number of iterations [default: 30]"); + opts.addOption("iter", true, + "The number of iterations [default: 30] Alias for `-iterations"); opts.addOption("loss", "loss_function", true, "Loss function [default: lnLogistic, logistic, sigmoid]"); // initialization @@ -191,8 +195,16 @@ public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements String rawArgs = HiveUtils.getConstString(argOIs[3]); cl = parseOptions(rawArgs); - this.factor = Primitives.parseInt(cl.getOptionValue("factor"), factor); - this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), iterations); + if (cl.hasOption("factor")) { + this.factor = Primitives.parseInt(cl.getOptionValue("factor"), factor); + } else { + this.factor = Primitives.parseInt(cl.getOptionValue("factors"), factor); + } + if (cl.hasOption("iter")) { + this.iterations = Primitives.parseInt(cl.getOptionValue("iter"), iterations); + } else { + this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), iterations); + } if (iterations < 1) { throw new UDFArgumentException( "'-iterations' must be greater than or equals to 1: " + iterations); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java index 537706e..9d7e1d1 100644 --- a/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java +++ b/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java @@ -106,7 +106,9 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions @Override protected Options getOptions() { Options opts = new Options(); - opts.addOption("k", "factor", true, "The number of latent factor [default: 10]"); + opts.addOption("k", "factor", true, "The number of latent factor [default: 10] " + + " Note this is alias for `factors` option."); + opts.addOption("f", "factors", true, "The number of latent factor [default: 10]"); opts.addOption("r", "lambda", true, "The regularization factor [default: 0.03]"); opts.addOption("mu", "mean_rating", true, "The mean rating [default: 0.0]"); opts.addOption("update_mean", "update_mu", false, @@ -117,7 +119,9 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions "The maximum initial value in the rank matrix [default: 1.0]"); opts.addOption("min_init_stddev", true, "The minimum standard deviation of initial rank matrix [default: 0.1]"); - opts.addOption("iter", "iterations", true, "The number of iterations [default: 1]"); + opts.addOption("iters", "iterations", true, "The number of iterations [default: 1]"); + opts.addOption("iter", true, + "The number of iterations [default: 1] Alias for `-iterations`"); opts.addOption("disable_cv", "disable_cvtest", false, "Whether to disable convergence check [default: enabled]"); opts.addOption("cv_rate", "convergence_rate", true, @@ -138,14 +142,22 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions if (argOIs.length >= 
4) { String rawArgs = HiveUtils.getConstString(argOIs[3]); cl = parseOptions(rawArgs); - this.factor = Primitives.parseInt(cl.getOptionValue("factor"), 10); + if (cl.hasOption("factors")) { + this.factor = Primitives.parseInt(cl.getOptionValue("factors"), 10); + } else { + this.factor = Primitives.parseInt(cl.getOptionValue("factor"), 10); + } this.lambda = Primitives.parseFloat(cl.getOptionValue("lambda"), 0.03f); this.meanRating = Primitives.parseFloat(cl.getOptionValue("mu"), 0.f); this.updateMeanRating = cl.hasOption("update_mean"); rankInitOpt = cl.getOptionValue("rankinit"); maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), 1.f); initStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.1d); - this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 1); + if (cl.hasOption("iter")) { + this.iterations = Primitives.parseInt(cl.getOptionValue("iter"), 1); + } else { + this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 1); + } if (iterations < 1) { throw new UDFArgumentException( "'-iterations' must be greater than or equal to 1: " + iterations); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/tools/mapred/RowNumberUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/mapred/RowNumberUDF.java b/core/src/main/java/hivemall/tools/mapred/RowNumberUDF.java index 95c97dc..ca85cee 100644 --- a/core/src/main/java/hivemall/tools/mapred/RowNumberUDF.java +++ b/core/src/main/java/hivemall/tools/mapred/RowNumberUDF.java @@ -28,7 +28,8 @@ import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; import org.apache.hadoop.io.LongWritable; -@Description(name = "rownum", value = "_FUNC_() - Returns a generated row number `sprintf(`%d%04d`,sequence,taskId)` in long", +@Description(name = "rownum", + value = "_FUNC_() - Returns a generated row number `sprintf(`%d%04d`,sequence,taskId)` in long", extended = "SELECT rownum() as rownum, xxx from ...") @UDFType(deterministic = false, stateful = true) public final class RowNumberUDF extends UDF { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/utils/lang/Primitives.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/Primitives.java b/core/src/main/java/hivemall/utils/lang/Primitives.java index 7d43da1..ab3be9a 100644 --- a/core/src/main/java/hivemall/utils/lang/Primitives.java +++ b/core/src/main/java/hivemall/utils/lang/Primitives.java @@ -24,6 +24,9 @@ public final class Primitives { public static final int INT_BYTES = Integer.SIZE / Byte.SIZE; public static final int DOUBLE_BYTES = Double.SIZE / Byte.SIZE; + public static final Byte TRUE_BYTE = 1; + public static final Byte FALSE_BYTE = 0; + private Primitives() {} public static short parseShort(final String s, final short defaultValue) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/main/java/hivemall/utils/lang/SizeOf.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/SizeOf.java b/core/src/main/java/hivemall/utils/lang/SizeOf.java index 9e0ef4c..08cf664 100644 --- a/core/src/main/java/hivemall/utils/lang/SizeOf.java +++ b/core/src/main/java/hivemall/utils/lang/SizeOf.java @@ -29,4 +29,5 @@ public final class SizeOf { public 
static final int CHAR = Character.SIZE / Byte.SIZE; private SizeOf() {} + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java index 64da212..b6b83c5 100644 --- a/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java +++ b/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java @@ -92,6 +92,141 @@ public class FactorizationMachineUDTFTest { } @Test + public void testAdaptiveRegularization() throws HiveException, IOException { + println("Adaptive regularization test"); + + final String options = "-factors 5 -min 1 -max 5 -init_v gaussian -eta0 0.01 -seed 31 "; + + FactorizationMachineUDTF udtf = new FactorizationMachineUDTF(); + ObjectInspector[] argOIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector), + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, options)}; + + udtf.initialize(argOIs); + + BufferedReader data = readFile("5107786.txt.gz"); + List<List<String>> featureVectors = new ArrayList<>(); + List<Double> ys = new ArrayList<>(); + String line = data.readLine(); + while (line != null) { + StringTokenizer tokenizer = new StringTokenizer(line, " "); + double y = Double.parseDouble(tokenizer.nextToken()); + List<String> features = new ArrayList<String>(); + while (tokenizer.hasMoreTokens()) { + String f = tokenizer.nextToken(); + features.add(f); + } + udtf.process(new Object[] {features, y}); + featureVectors.add(features); + ys.add(y); + line = data.readLine(); + } + udtf.finalizeTraining(); + data.close(); + + double loss = udtf._cvState.getAverageLoss(featureVectors.size()); + println("Average loss without adaptive regularization: " + loss); + + // train with adaptive regularization + udtf = new FactorizationMachineUDTF(); + argOIs[2] = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + options + "-adaptive_regularization -validation_threshold 1"); + udtf.initialize(argOIs); + udtf.initModel(udtf._params); + for (int i = 0, n = featureVectors.size(); i < n; i++) { + udtf.process(new Object[] {featureVectors.get(i), ys.get(i)}); + } + udtf.finalizeTraining(); + + double loss_adareg = udtf._cvState.getAverageLoss(featureVectors.size()); + println("Average loss with adaptive regularization: " + loss_adareg); + Assert.assertTrue("Adaptive regularization should achieve lower loss", loss > loss_adareg); + } + + @Test + public void testEarlyStopping() throws HiveException, IOException { + println("Early stopping test"); + + int iters = 20; + + // train with 20 iterations + FactorizationMachineUDTF udtf = new FactorizationMachineUDTF(); + ObjectInspector[] argOIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector), + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-factors 5 -min 1 -max 5 -init_v gaussian -eta0 0.002 -seed 31 -iters " + iters + + " -early_stopping -validation_threshold 
1 -disable_cv")}; + + udtf.initialize(argOIs); + + BufferedReader data = readFile("5107786.txt.gz"); + List<List<String>> featureVectors = new ArrayList<>(); + List<Double> ys = new ArrayList<>(); + String line = data.readLine(); + while (line != null) { + StringTokenizer tokenizer = new StringTokenizer(line, " "); + double y = Double.parseDouble(tokenizer.nextToken()); + List<String> features = new ArrayList<String>(); + while (tokenizer.hasMoreTokens()) { + String f = tokenizer.nextToken(); + features.add(f); + } + udtf.process(new Object[] {features, y}); + featureVectors.add(features); + ys.add(y); + line = data.readLine(); + } + udtf.finalizeTraining(); + data.close(); + + double loss = udtf._validationState.getAverageLoss(featureVectors.size()); + Assert.assertTrue( + "Training seems to be failed because average loss is greater than 0.1: " + loss, + loss <= 0.1); + + Assert.assertNotNull("Early stopping validation has not been conducted", + udtf._validationState); + println("Performed " + udtf._validationState.getCurrentIteration() + " iterations out of " + + iters); + Assert.assertNotEquals("Early stopping did not happen", iters, + udtf._validationState.getCurrentIteration()); + + // store the best state achieved by early stopping + iters = udtf._validationState.getCurrentIteration() - 2; // best loss was at (N-2)-th iter + double cumulativeLoss = udtf._validationState.getCumulativeLoss(); + println("Cumulative loss: " + cumulativeLoss); + + // train with the number of early-stopped iterations + udtf = new FactorizationMachineUDTF(); + argOIs[2] = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-factors 5 -min 1 -max 5 -init_v gaussian -eta0 0.002 -seed 31 -iters " + iters + + " -early_stopping -validation_threshold 1 -disable_cv"); + udtf.initialize(argOIs); + udtf.initModel(udtf._params); + for (int i = 0, n = featureVectors.size(); i < n; i++) { + udtf.process(new Object[] {featureVectors.get(i), ys.get(i)}); + } + udtf.finalizeTraining(); + + println("Performed " + udtf._validationState.getCurrentIteration() + " iterations out of " + + iters); + Assert.assertEquals("Training finished earlier than expected", iters, + udtf._validationState.getCurrentIteration()); + + println("Cumulative loss: " + udtf._validationState.getCumulativeLoss()); + Assert.assertTrue("Cumulative loss should be better than " + cumulativeLoss, + cumulativeLoss > udtf._validationState.getCumulativeLoss()); + } + + @Test public void testEnableL2Norm() throws HiveException, IOException { FactorizationMachineUDTF udtf = new FactorizationMachineUDTF(); ObjectInspector[] argOIs = new ObjectInspector[] { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java index 5b7aa8f..67040a1 100644 --- a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java +++ b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java @@ -44,59 +44,61 @@ import org.junit.Test; public class FieldAwareFactorizationMachineUDTFTest { private static final boolean DEBUG = false; - private static final int ITERATIONS = 50; - private static final int MAX_LINES = 200; // ---------------------------------------------------- // 
bigdata.tr.txt @Test public void testSGD() throws HiveException, IOException { - runIterations("Pure SGD test", "bigdata.tr.txt.gz", - "-opt sgd -classification -factors 10 -w0 -seed 43", 0.60f); + run("Pure SGD test", "bigdata.tr.txt.gz", + "-opt sgd -linear_term -classification -factors 10 -w0 -eta 0.4 -iters 20 -seed 43", + 0.30f); } @Test public void testAdaGrad() throws HiveException, IOException { - runIterations("AdaGrad test", "bigdata.tr.txt.gz", - "-opt adagrad -classification -factors 10 -w0 -seed 43", 0.30f); + run("AdaGrad test", "bigdata.tr.txt.gz", + "-opt adagrad -linear_term -classification -factors 10 -w0 -eta 0.4 -iters 30 -seed 43", + 0.30f); } @Test public void testAdaGradNoCoeff() throws HiveException, IOException { - runIterations("AdaGrad No Coeff test", "bigdata.tr.txt.gz", - "-opt adagrad -no_coeff -classification -factors 10 -w0 -seed 43", 0.30f); + run("AdaGrad No Coeff test", "bigdata.tr.txt.gz", + "-opt adagrad -classification -factors 10 -w0 -eta 0.4 -iters 30 -seed 43", 0.30f); } @Test public void testFTRL() throws HiveException, IOException { - runIterations("FTRL test", "bigdata.tr.txt.gz", - "-opt ftrl -classification -factors 10 -w0 -seed 43", 0.30f); + run("FTRL test", "bigdata.tr.txt.gz", + "-opt ftrl -linear_term -classification -factors 10 -w0 -alphaFTRL 10.0 -seed 43", + 0.30f); } @Test public void testFTRLNoCoeff() throws HiveException, IOException { - runIterations("FTRL Coeff test", "bigdata.tr.txt.gz", - "-opt ftrl -no_coeff -classification -factors 10 -w0 -seed 43", 0.30f); + run("FTRL Coeff test", "bigdata.tr.txt.gz", + "-opt ftrl -classification -factors 10 -w0 -alphaFTRL 10.0 -seed 43", 0.30f); } // ---------------------------------------------------- // https://github.com/myui/ml_dataset/raw/master/ffm/sample.ffm.gz @Test - public void testSample() throws IOException, HiveException { + public void testSampleDisableNorm() throws IOException, HiveException { System.setProperty("https.protocols", "TLSv1,TLSv1.1,TLSv1.2"); run("[Sample.ffm] default option", "https://github.com/myui/ml_dataset/raw/master/ffm/sample.ffm.gz", - "-classification -factors 2 -iters 10 -feature_hashing 20 -seed 43", 0.01f); + "-disable_norm -linear_term -classification -factors 2 -feature_hashing 20 -seed 43", + 0.01f); } - // TODO @Test - public void testSampleEnableNorm() throws IOException, HiveException { + @Test + public void testSample() throws IOException, HiveException { System.setProperty("https.protocols", "TLSv1,TLSv1.1,TLSv1.2"); run("[Sample.ffm] default option", "https://github.com/myui/ml_dataset/raw/master/ffm/sample.ffm.gz", - "-classification -factors 2 -iters 10 -feature_hashing 20 -seed 43 -enable_norm", + "-linear_term -classification -factors 2 -alphaFTRL 10.0 -feature_hashing 20 -seed 43", 0.01f); } @@ -161,66 +163,112 @@ public class FieldAwareFactorizationMachineUDTFTest { avgLoss < lossThreshold); } - private static void runIterations(String testName, String testFile, String testOptions, - float lossThreshold) throws IOException, HiveException { - println(testName); + @Test + public void testEarlyStopping() throws HiveException, IOException { + println("Early stopping"); + + int iters = 20; FieldAwareFactorizationMachineUDTF udtf = new FieldAwareFactorizationMachineUDTF(); - ObjectInspector[] argOIs = - new ObjectInspector[] { - ObjectInspectorFactory.getStandardListObjectInspector( - PrimitiveObjectInspectorFactory.javaStringObjectInspector), - PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, - 
ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaStringObjectInspector, - testOptions)}; + ObjectInspector[] argOIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector), + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-opt sgd -linear_term -classification -factors 10 -w0 -eta 0.4 -iters " + iters + + " -early_stopping -validation_threshold 1 -disable_cv -seed 43")}; udtf.initialize(argOIs); - FieldAwareFactorizationMachineModel model = udtf.initModel(udtf._params); - Assert.assertTrue("Actual class: " + model.getClass().getName(), - model instanceof FFMStringFeatureMapModel); - double loss = 0.d; - double cumul = 0.d; - for (int trainingIteration = 1; trainingIteration <= ITERATIONS; ++trainingIteration) { - BufferedReader data = readFile(testFile); - loss = udtf._cvState.getCumulativeLoss(); - int lines = 0; - for (int lineNumber = 0; lineNumber < MAX_LINES; ++lineNumber, ++lines) { - //gather features in current line - final String input = data.readLine(); - if (input == null) { - break; - } - String[] featureStrings = input.split(" "); + BufferedReader data = readFile("bigdata.tr.txt.gz"); + List<List<String>> featureVectors = new ArrayList<>(); + List<Double> ys = new ArrayList<>(); + while (true) { + //gather features in current line + final String input = data.readLine(); + if (input == null) { + break; + } + String[] featureStrings = input.split(" "); - double y = Double.parseDouble(featureStrings[0]); - if (y == 0) { - y = -1;//LibFFM data uses {0, 1}; Hivemall uses {-1, 1} - } + double y = Double.parseDouble(featureStrings[0]); + if (y == 0) { + y = -1;//LibFFM data uses {0, 1}; Hivemall uses {-1, 1} + } + ys.add(y); - final List<String> features = new ArrayList<String>(featureStrings.length - 1); - for (int j = 1; j < featureStrings.length; ++j) { - String fj = featureStrings[j]; - String[] splitted = fj.split(":"); - Assert.assertEquals(3, splitted.length); - String indexStr = splitted[1]; - String f = fj; - if (NumberUtils.isDigits(indexStr)) { - int index = Integer.parseInt(indexStr) + 1; // avoid 0 index - f = splitted[0] + ':' + index + ':' + splitted[2]; - } - features.add(f); + final List<String> features = new ArrayList<String>(featureStrings.length - 1); + for (int j = 1; j < featureStrings.length; ++j) { + String fj = featureStrings[j]; + String[] splitted = fj.split(":"); + Assert.assertEquals(3, splitted.length); + String indexStr = splitted[1]; + String f = fj; + if (NumberUtils.isDigits(indexStr)) { + int index = Integer.parseInt(indexStr) + 1; // avoid 0 index + f = splitted[0] + ':' + index + ':' + splitted[2]; } - udtf.process(new Object[] {features, y}); + features.add(f); } - cumul = udtf._cvState.getCumulativeLoss(); - loss = (cumul - loss) / lines; - println(trainingIteration + " " + loss + " " + cumul / (trainingIteration * lines)); - data.close(); + featureVectors.add(features); + + udtf.process(new Object[] {features, y}); } - println("model size=" + udtf._model.getSize()); - Assert.assertTrue("Last loss was greater than expected: " + loss, loss < lossThreshold); + udtf.finalizeTraining(); + data.close(); + + double loss = udtf._validationState.getAverageLoss(featureVectors.size()); + Assert.assertTrue( + "Training seems to be failed because average loss is greater than 0.6: " + loss, + loss <= 0.6); + + 
Assert.assertNotNull("Early stopping validation has not been conducted", + udtf._validationState); + println("Performed " + udtf._validationState.getCurrentIteration() + " iterations out of " + + iters); + Assert.assertNotEquals("Early stopping did not happen", iters, + udtf._validationState.getCurrentIteration()); + + // store the best state achieved by early stopping + iters = udtf._validationState.getCurrentIteration() - 2; // best loss was at (N-2)-th iter + double cumulativeLoss = udtf._validationState.getCumulativeLoss(); + println("Cumulative loss: " + cumulativeLoss); + + // train with the number of early-stopped iterations + udtf = new FieldAwareFactorizationMachineUDTF(); + argOIs[2] = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-opt sgd -linear_term -classification -factors 10 -w0 -eta 0.4 -iters " + iters + + " -early_stopping -validation_threshold 1 -disable_cv -seed 43"); + udtf.initialize(argOIs); + udtf.initModel(udtf._params); + for (int i = 0, n = featureVectors.size(); i < n; i++) { + udtf.process(new Object[] {featureVectors.get(i), ys.get(i)}); + } + udtf.finalizeTraining(); + + println("Performed " + udtf._validationState.getCurrentIteration() + " iterations out of " + + iters); + Assert.assertEquals("Training finished earlier than expected", iters, + udtf._validationState.getCurrentIteration()); + + println("Cumulative loss: " + udtf._validationState.getCumulativeLoss()); + Assert.assertTrue("Cumulative loss should be better than " + cumulativeLoss, + cumulativeLoss > udtf._validationState.getCumulativeLoss()); + } + + @Test(expected = IllegalArgumentException.class) + public void testUnsupportedAdaptiveRegularizationOption() throws Exception { + TestUtils.testGenericUDTFSerialization(FieldAwareFactorizationMachineUDTF.class, + new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector), + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-seed 43 -adaptive_regularization")}, + new Object[][] {{Arrays.asList("0:1:-2", "1:2:-1"), 1.0}}); } @Test @@ -231,8 +279,7 @@ public class FieldAwareFactorizationMachineUDTFTest { PrimitiveObjectInspectorFactory.javaStringObjectInspector), PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaStringObjectInspector, - "-opt sgd -classification -factors 10 -w0 -seed 43")}, + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-seed 43")}, new Object[][] {{Arrays.asList("0:1:-2", "1:2:-1"), 1.0}}); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/test/java/hivemall/ftvec/scaling/L1NormalizationUDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/ftvec/scaling/L1NormalizationUDFTest.java b/core/src/test/java/hivemall/ftvec/scaling/L1NormalizationUDFTest.java index 7d997f7..bfb37fc 100644 --- a/core/src/test/java/hivemall/ftvec/scaling/L1NormalizationUDFTest.java +++ b/core/src/test/java/hivemall/ftvec/scaling/L1NormalizationUDFTest.java @@ -59,6 +59,12 @@ public class L1NormalizationUDFTest { WritableUtils.val(new String[] {"aaa:" + normalized[0], "bbb:" + normalized[1]}), udf.evaluate(WritableUtils.val(new String[] {"aaa:1.0", "bbb:-0.5"}))); + normalized = 
MathUtils.l1normalize(new float[] {1.0f, 2.0f, 3.0f}); + assertEquals( + WritableUtils.val(new String[] {"1:123:" + normalized[0], "2:456:" + normalized[1], + "3:789:" + normalized[2]}), + udf.evaluate(WritableUtils.val(new String[] {"1:123:1", "2:456:2", "3:789:3"}))); + List<Text> expected = udf.evaluate(WritableUtils.val(new String[] {"bbb:-0.5", "aaa:1.0"})); Collections.sort(expected); List<Text> actual = udf.evaluate(WritableUtils.val(new String[] {"aaa:1.0", "bbb:-0.5"})); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/core/src/test/java/hivemall/ftvec/scaling/L2NormalizationUDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/ftvec/scaling/L2NormalizationUDFTest.java b/core/src/test/java/hivemall/ftvec/scaling/L2NormalizationUDFTest.java index 30e2aba..393a9d2 100644 --- a/core/src/test/java/hivemall/ftvec/scaling/L2NormalizationUDFTest.java +++ b/core/src/test/java/hivemall/ftvec/scaling/L2NormalizationUDFTest.java @@ -59,6 +59,12 @@ public class L2NormalizationUDFTest { WritableUtils.val(new String[] {"aaa:" + 1.0f / l2norm, "bbb:" + -0.5f / l2norm}), udf.evaluate(WritableUtils.val(new String[] {"aaa:1.0", "bbb:-0.5"}))); + l2norm = MathUtils.l2norm(new float[] {1.0f, 2.0f, 3.0f}); + assertEquals( + WritableUtils.val(new String[] {"1:123:" + 1.0f / l2norm, "2:456:" + 2.0f / l2norm, + "3:789:" + 3.0f / l2norm}), + udf.evaluate(WritableUtils.val(new String[] {"1:123:1", "2:456:2", "3:789:3"}))); + List<Text> expected = udf.evaluate(WritableUtils.val(new String[] {"bbb:-0.5", "aaa:1.0"})); Collections.sort(expected); List<Text> actual = udf.evaluate(WritableUtils.val(new String[] {"aaa:1.0", "bbb:-0.5"})); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/docs/gitbook/SUMMARY.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md index 56e416f..155a221 100644 --- a/docs/gitbook/SUMMARY.md +++ b/docs/gitbook/SUMMARY.md @@ -111,6 +111,9 @@ * [Kaggle Titanic Tutorial](binaryclass/titanic_rf.md) +* [Criteo Tutorial](binaryclass/criteo.md) + * [Data preparation](binaryclass/criteo_dataset.md) + * [Field-Aware Factorization Machines](binaryclass/criteo_ffm.md) ## Part VII - Multiclass Classification http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/docs/gitbook/binaryclass/criteo.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/binaryclass/criteo.md b/docs/gitbook/binaryclass/criteo.md new file mode 100644 index 0000000..3ad5f81 --- /dev/null +++ b/docs/gitbook/binaryclass/criteo.md @@ -0,0 +1,20 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+-->
+
+This tutorial tackles the [Kaggle Display Advertising Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge).
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/61711fbc/docs/gitbook/binaryclass/criteo_dataset.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/binaryclass/criteo_dataset.md b/docs/gitbook/binaryclass/criteo_dataset.md
new file mode 100644
index 0000000..c4c12ea
--- /dev/null
+++ b/docs/gitbook/binaryclass/criteo_dataset.md
@@ -0,0 +1,97 @@
+<!--
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing,
+  software distributed under the License is distributed on an
+  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+  KIND, either express or implied.  See the License for the
+  specific language governing permissions and limitations
+  under the License.
+-->
+
+<!-- toc -->
+
+# Download data
+
+Get the dataset for the [Kaggle Display Advertising Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge) from one of the following sources:
+
+1. [Original competition data](http://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset/) (by Criteo Labs) [~20GB]
+2. [Subset of the original competition data](http://labs.criteo.com/2014/02/dataset/) (by Criteo Labs) [~30MB]
+3. [Tiny sample data](https://github.com/guestwalk/kaggle-2014-criteo) (by the winners of the competition) [~20bytes]
+
+Note that you must accept and agree to the **CRITEO LABS DATA TERM OF USE** before downloading the data.
+
+# Convert data into CSV format
+
+Here, you can use a script prepared by one of the Hivemall PPMC members: **[takuti/criteo-ffm](https://github.com/takuti/criteo-ffm)**.
+
+Clone the repository:
+
+```sh
+git clone [email protected]:takuti/criteo-ffm.git
+cd criteo-ffm
+```
+
+The script [`data.sh`](https://github.com/takuti/criteo-ffm/blob/master/data.sh) downloads the original data and converts it into CSV format:
+
+```sh
+./data.sh # downloads the original data and generates `train.csv` and `test.csv`
+ln -s train.csv tr.csv
+ln -s test.csv te.csv
+```
+
+Alternatively, since the original data is huge, it can be easier to start from the tiny sample data bundled in the repository:
+
+```sh
+ln -s train.tiny.csv tr.csv
+ln -s test.tiny.csv te.csv
+```
+
+# Create tables
+
+Load the CSV files into Hive tables as follows:
+
+```sh
+hadoop fs -put tr.csv /criteo/train
+hadoop fs -put te.csv /criteo/test
+```
+
+```sql
+CREATE DATABASE IF NOT EXISTS criteo;
+use criteo;
+```
+
+```sql
+DROP TABLE IF EXISTS train;
+CREATE EXTERNAL TABLE train (
+  id bigint,
+  label int,
+  -- quantitative features
+  i1 int,i2 int,i3 int,i4 int,i5 int,i6 int,i7 int,i8 int,i9 int,i10 int,i11 int,i12 int,i13 int,
+  -- categorical features
+  c1 string,c2 string,c3 string,c4 string,c5 string,c6 string,c7 string,c8 string,c9 string,c10 string,c11 string,c12 string,c13 string,c14 string,c15 string,c16 string,c17 string,c18 string,c19 string,c20 string,c21 string,c22 string,c23 string,c24 string,c25 string,c26 string
+) ROW FORMAT
+DELIMITED FIELDS TERMINATED BY ','
+STORED AS TEXTFILE LOCATION '/criteo/train';
+```
+
+```sql
+DROP TABLE IF EXISTS test;
+CREATE EXTERNAL TABLE test (
+  label int,
+  -- quantitative features
+  i1 int,i2 int,i3 int,i4 int,i5 int,i6 int,i7 int,i8 int,i9 int,i10 int,i11 int,i12 int,i13 int,
+  -- categorical features
+  c1 string,c2 string,c3 string,c4 string,c5 string,c6 string,c7 string,c8 string,c9 string,c10 string,c11 string,c12 string,c13 string,c14 string,c15 string,c16 string,c17 string,c18 string,c19 string,c20 string,c21 string,c22 string,c23 string,c24 string,c25 string,c26 string
+) ROW FORMAT
+DELIMITED FIELDS TERMINATED BY ','
+STORED AS TEXTFILE LOCATION '/criteo/test';
+```
\ No newline at end of file
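As background on what `data.sh` produces, the raw Kaggle files are tab-separated (a label followed by 13 integer and 26 categorical features), while the Hive tables above expect comma-separated fields, with a synthetic `id` column on the training side. The following is a minimal, hypothetical Java sketch of that kind of conversion; the file names and the id assignment are illustrative only, and the actual script may differ:

```java
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.PrintWriter;

// Hypothetical sketch of a raw-TSV-to-CSV conversion; data.sh is the
// authoritative implementation and may differ in details.
public class CriteoTsvToCsv {
    public static void main(String[] args) throws Exception {
        try (BufferedReader in = new BufferedReader(new FileReader("train.txt"));
             PrintWriter out = new PrintWriter("train.csv")) {
            long id = 0L; // synthetic row id, matching the `id bigint` column above
            String line;
            while ((line = in.readLine()) != null) {
                // limit=-1 keeps trailing empty fields (missing values) as empty strings
                String[] cols = line.split("\t", -1);
                out.println(id++ + "," + String.join(",", cols));
            }
        }
    }
}
```

Running `data.sh` remains the recommended path; the sketch only shows the shape of the transformation.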
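Once the external tables are defined, a quick row-count query confirms that the uploaded files were picked up. Assuming a HiveServer2 endpoint at `localhost:10000` and the `hive-jdbc` driver on the classpath (both assumptions, not part of this patch), a sketch:

```java
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;

// Hypothetical sanity check; host, port, and credentials depend on your cluster.
public class CriteoRowCount {
    public static void main(String[] args) throws Exception {
        try (Connection conn =
                DriverManager.getConnection("jdbc:hive2://localhost:10000/criteo", "hive", "");
             Statement stmt = conn.createStatement();
             ResultSet rs = stmt.executeQuery("SELECT COUNT(1) FROM train")) {
            if (rs.next()) {
                System.out.println("train rows: " + rs.getLong(1));
            }
        }
    }
}
```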

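Relatedly, the `FieldAwareFactorizationMachineUDTFTest` changes above normalize LibFFM-formatted records before feeding them to the UDTF: labels are mapped from LibFFM's {0, 1} to Hivemall's {-1, 1}, and numeric feature indices are incremented by one to avoid a 0 index. A self-contained sketch of that per-record conversion follows; the class and method names are illustrative, and where the test uses commons-lang's `NumberUtils.isDigits`, the sketch substitutes a plain stream check:

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative restatement of the conversion performed in
// FieldAwareFactorizationMachineUDTFTest; not part of the patch.
public class LibffmRecordConverter {

    /** LibFFM labels are {0, 1}; Hivemall expects {-1, 1}. */
    public static double convertLabel(String labelToken) {
        double y = Double.parseDouble(labelToken);
        return (y == 0.d) ? -1.d : y;
    }

    /** Shifts numeric feature indices of "field:index:value" tokens to avoid index 0. */
    public static List<String> convertFeatures(String[] tokens) {
        final List<String> features = new ArrayList<>(tokens.length - 1);
        for (int j = 1; j < tokens.length; j++) {
            String[] parts = tokens[j].split(":"); // field:index:value
            String f = tokens[j];
            if (parts.length == 3 && parts[1].chars().allMatch(Character::isDigit)) {
                int index = Integer.parseInt(parts[1]) + 1; // avoid 0 index
                f = parts[0] + ':' + index + ':' + parts[2];
            }
            features.add(f);
        }
        return features;
    }

    public static void main(String[] args) {
        String[] tokens = "1 0:0:1.0 1:5:0.5".split(" ");
        // prints: 1.0 -> [0:1:1.0, 1:6:0.5]
        System.out.println(convertLabel(tokens[0]) + " -> " + convertFeatures(tokens));
    }
}
```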