Close #98, Close #87: [HIVEMALL-108] Support option in generic predictors
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/e186a587 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/e186a587 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/e186a587 Branch: refs/heads/master Commit: e186a58767a4ae16ef673242f57a54ab0da2e81b Parents: 047f5fe Author: Takuya Kitazawa <[email protected]> Authored: Fri Jul 14 23:20:19 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Fri Jul 14 23:20:19 2017 +0900 ---------------------------------------------------------------------- .../java/hivemall/GeneralLearnerBaseUDTF.java | 521 ++++++++++++++++--- .../src/main/java/hivemall/LearnerBaseUDTF.java | 148 +----- .../src/main/java/hivemall/UDTFWithOptions.java | 2 +- .../classifier/BinaryOnlineClassifierUDTF.java | 18 +- .../classifier/GeneralClassifierUDTF.java | 3 +- .../MulticlassOnlineClassifierUDTF.java | 20 +- .../java/hivemall/common/ConversionState.java | 51 +- .../hivemall/fm/FactorizationMachineUDTF.java | 18 +- .../hivemall/mf/BPRMatrixFactorizationUDTF.java | 20 +- .../mf/OnlineMatrixFactorizationUDTF.java | 16 +- .../main/java/hivemall/model/FeatureValue.java | 20 +- .../main/java/hivemall/optimizer/Optimizer.java | 2 +- .../regression/GeneralRegressionUDTF.java | 4 +- .../hivemall/regression/RegressionBaseUDTF.java | 18 +- .../smile/regression/RegressionTree.java | 1 - .../java/hivemall/utils/hadoop/HiveUtils.java | 54 ++ .../classifier/GeneralClassifierUDTFTest.java | 174 ++++++- .../java/hivemall/model/FeatureValueTest.java | 6 +- .../hivemall/regression/AdaGradUDTFTest.java | 1 + .../regression/GeneralRegressionUDTFTest.java | 183 ++++++- 20 files changed, 904 insertions(+), 376 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java index 34c7ec9..f1bc045 100644 --- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java @@ -19,6 +19,7 @@ package hivemall; import hivemall.annotations.VisibleForTesting; +import hivemall.common.ConversionState; import hivemall.model.FeatureValue; import hivemall.model.IWeightValue; import hivemall.model.PredictionModel; @@ -31,13 +32,23 @@ import hivemall.optimizer.Optimizer; import hivemall.optimizer.OptimizerOptions; import hivemall.utils.collections.IMapIterator; import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.io.FileUtils; +import hivemall.utils.io.NIOUtils; +import hivemall.utils.io.NioStatefullSegment; import hivemall.utils.lang.FloatAccumulator; +import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.Primitives; +import hivemall.utils.lang.SizeOf; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -51,34 +62,57 @@ import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.mapred.Counters; +import org.apache.hadoop.mapred.Reporter; public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { private static final Log logger = LogFactory.getLog(GeneralLearnerBaseUDTF.class); private ListObjectInspector featureListOI; - private PrimitiveObjectInspector featureInputOI; private PrimitiveObjectInspector targetOI; - private boolean parseFeature; + private FeatureType featureType; + + // ----------------------------------------- + // hyperparameters @Nonnull private final Map<String, String> optimizerOptions; private Optimizer optimizer; private LossFunction lossFunction; + // ----------------------------------------- + private PredictionModel model; private long count; - // The accumulated delta of each weight values. + // ----------------------------------------- + // for mini-batch + + /** The accumulated delta of each weight values. */ @Nullable private transient Map<Object, FloatAccumulator> accumulated; private int sampled; - private double cumLoss; + // ----------------------------------------- + // for iterations + + @Nullable + protected transient NioStatefullSegment fileIO; + @Nullable + protected transient ByteBuffer inputBuf; + private int iterations; + protected ConversionState cvState; + + // ----------------------------------------- public GeneralLearnerBaseUDTF() { this(true); @@ -108,17 +142,13 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { throw new UDFArgumentException( "_FUNC_ takes 2 arguments: List<Int|BigInt|Text> features, float target [, constant string options]"); } - this.featureInputOI = processFeaturesOI(argOIs[0]); + this.featureListOI = HiveUtils.asListOI(argOIs[0]); + this.featureType = getFeatureType(featureListOI); this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]); processOptions(argOIs); - PrimitiveObjectInspector featureOutputOI = dense_model ? PrimitiveObjectInspectorFactory.javaIntObjectInspector - : featureInputOI; this.model = createModel(); - if (preloadedModelFile != null) { - loadPredictionModel(model, preloadedModelFile, featureOutputOI); - } try { this.optimizer = createOptimizer(optimizerOptions); @@ -128,15 +158,20 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { this.count = 0L; this.sampled = 0; - this.cumLoss = 0.d; - return getReturnOI(featureOutputOI); + return getReturnOI(getFeatureOutputOI(featureType)); } @Override protected Options getOptions() { Options opts = super.getOptions(); opts.addOption("loss", "loss_function", true, getLossOptionDescription()); + opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]"); + // conversion check + opts.addOption("disable_cv", "disable_cvtest", false, + "Whether to disable convergence check [default: OFF]"); + opts.addOption("cv_rate", "convergence_rate", true, + "Threshold to determine convergence [default: 0.005]"); OptimizerOptions.setup(opts); return opts; } @@ -146,29 +181,83 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { CommandLine cl = super.processOptions(argOIs); LossFunction lossFunction = LossFunctions.getLossFunction(getDefaultLossType()); - if (cl.hasOption("loss_function")) { - try { - lossFunction = LossFunctions.getLossFunction(cl.getOptionValue("loss_function")); - } catch (Throwable e) { - throw new UDFArgumentException(e.getMessage()); + int iterations = 10; + boolean conversionCheck = true; + double convergenceRate = 0.005d; + + if (cl != null) { + if (cl.hasOption("loss_function")) { + try { + lossFunction = LossFunctions.getLossFunction(cl.getOptionValue("loss_function")); + } catch (Throwable e) { + throw new UDFArgumentException(e.getMessage()); + } } + checkLossFunction(lossFunction); + + iterations = Primitives.parseInt(cl.getOptionValue("iterations"), iterations); + if (iterations < 1) { + throw new UDFArgumentException( + "'-iterations' must be greater than or equals to 1: " + iterations); + } + + conversionCheck = !cl.hasOption("disable_cvtest"); + convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate); } - checkLossFunction(lossFunction); + this.lossFunction = lossFunction; + this.iterations = iterations; + this.cvState = new ConversionState(conversionCheck, convergenceRate); OptimizerOptions.propcessOptions(cl, optimizerOptions); return cl; } + public enum FeatureType { + STRING, INT, LONG + } + @Nonnull - protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector arg) + private static FeatureType getFeatureType(@Nonnull ListObjectInspector featureListOI) throws UDFArgumentException { - this.featureListOI = (ListObjectInspector) arg; - ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector(); - HiveUtils.validateFeatureOI(featureRawOI); - this.parseFeature = HiveUtils.isStringOI(featureRawOI); - return HiveUtils.asPrimitiveObjectInspector(featureRawOI); + final ObjectInspector featureOI = featureListOI.getListElementObjectInspector(); + if (featureOI instanceof StringObjectInspector) { + return FeatureType.STRING; + } else if (featureOI instanceof IntObjectInspector) { + return FeatureType.INT; + } else if (featureOI instanceof LongObjectInspector) { + return FeatureType.LONG; + } else { + throw new UDFArgumentException("Feature object inspector must be one of " + + "[StringObjectInspector, IntObjectInspector, LongObjectInspector]: " + + featureOI.toString()); + } + } + + @Nonnull + protected final ObjectInspector getFeatureOutputOI(@Nonnull final FeatureType featureType) + throws UDFArgumentException { + final PrimitiveObjectInspector outputOI; + if (dense_model) { + // TODO validation + outputOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; // see DenseModel (long/string is also parsed as int) + } else { + switch (featureType) { + case STRING: + outputOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + break; + case INT: + outputOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; + break; + case LONG: + outputOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector; + break; + default: + throw new IllegalStateException("Unexpected feature type: " + featureType); + } + } + return outputOI; } @Nonnull @@ -177,8 +266,7 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); fieldNames.add("feature"); - ObjectInspector featureOI = ObjectInspectorUtils.getStandardObjectInspector(featureOutputOI); - fieldOIs.add(featureOI); + fieldOIs.add(featureOutputOI); fieldNames.add("weight"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); if (useCovariance()) { @@ -204,8 +292,100 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { checkTargetValue(target); count++; - train(featureVector, target); + + recordTrainSampleToTempFile(featureVector, target); + } + + protected void recordTrainSampleToTempFile(@Nonnull final FeatureValue[] featureVector, + final float target) throws HiveException { + if (iterations == 1) { + return; + } + + ByteBuffer buf = inputBuf; + NioStatefullSegment dst = fileIO; + + if (buf == null) { + final File file; + try { + file = File.createTempFile("hivemall_general_learner", ".sgmt"); + file.deleteOnExit(); + if (!file.canWrite()) { + throw new UDFArgumentException("Cannot write a temporary file: " + + file.getAbsolutePath()); + } + logger.info("Record training samples to a file: " + file.getAbsolutePath()); + } catch (IOException ioe) { + throw new UDFArgumentException(ioe); + } catch (Throwable e) { + throw new UDFArgumentException(e); + } + this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB + this.fileIO = dst = new NioStatefullSegment(file, false); + } + + int featureVectorBytes = 0; + for (FeatureValue f : featureVector) { + if (f == null) { + continue; + } + int featureLength = f.getFeatureAsString().length(); + + // feature as String (even if it is Text or Integer) + featureVectorBytes += SizeOf.CHAR * featureLength; + + // NIOUtils.putString() first puts the length of string before string itself + featureVectorBytes += SizeOf.INT; + + // value + featureVectorBytes += SizeOf.DOUBLE; + } + + // feature length, feature 1, feature 2, ..., feature n, target + int recordBytes = SizeOf.INT + featureVectorBytes + SizeOf.FLOAT; + int requiredBytes = SizeOf.INT + recordBytes; // need to allocate space for "recordBytes" itself + + int remain = buf.remaining(); + if (remain < requiredBytes) { + writeBuffer(buf, dst); + } + + buf.putInt(recordBytes); + buf.putInt(featureVector.length); + for (FeatureValue f : featureVector) { + writeFeatureValue(buf, f); + } + buf.putFloat(target); + } + + private static void writeFeatureValue(@Nonnull final ByteBuffer buf, + @Nonnull final FeatureValue f) { + NIOUtils.putString(f.getFeatureAsString(), buf); + buf.putDouble(f.getValue()); + } + + @Nonnull + private static FeatureValue readFeatureValue(@Nonnull final ByteBuffer buf, + @Nonnull final FeatureType featureType) { + final String featureStr = NIOUtils.getString(buf); + final Object feature; + switch (featureType) { + case STRING: + feature = featureStr; + break; + case INT: + feature = Integer.valueOf(featureStr); + break; + case LONG: + feature = Long.valueOf(featureStr); + break; + default: + throw new IllegalStateException("Unexpected feature type " + featureType + + " for feature: " + featureStr); + } + double value = buf.getDouble(); + return new FeatureValue(feature, value); } @Nullable @@ -223,10 +403,12 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { continue; } final FeatureValue fv; - if (parseFeature) { - fv = FeatureValue.parse(f); + if (featureType == FeatureType.STRING) { + String s = f.toString(); + fv = FeatureValue.parseFeatureAsString(s); } else { - Object k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector); + Object k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector, + ObjectInspectorCopyOption.JAVA); // should be Integer or Long fv = new FeatureValue(k, 1.f); } featureVector[i] = fv; @@ -234,6 +416,17 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { return featureVector; } + private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst) + throws HiveException { + srcBuf.flip(); + try { + dst.write(srcBuf); + } catch (IOException e) { + throw new HiveException("Exception causes while writing a buffer to file", e); + } + srcBuf.clear(); + } + public float predict(@Nonnull final FeatureValue[] features) { float score = 0.f; for (FeatureValue f : features) {// a += w[i] * x[i] @@ -253,8 +446,10 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { protected void update(@Nonnull final FeatureValue[] features, final float target, final float predicted) { - this.cumLoss += lossFunction.loss(predicted, target); // retain cumulative loss to check convergence - float dloss = lossFunction.dloss(predicted, target); + float loss = lossFunction.loss(predicted, target); + cvState.incrLoss(loss); // retain cumulative loss to check convergence + + final float dloss = lossFunction.dloss(predicted, target); if (is_mini_batch) { accumulateUpdate(features, dloss); @@ -318,66 +513,228 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { @Override public final void close() throws HiveException { super.close(); - if (model != null) { - if (accumulated != null) { // Update model with accumulated delta - batchUpdate(); - this.accumulated = null; - } - int numForwarded = 0; - if (useCovariance()) { - final WeightValueWithCovar probe = new WeightValueWithCovar(); - final Object[] forwardMapObj = new Object[3]; - final FloatWritable fv = new FloatWritable(); - final FloatWritable cov = new FloatWritable(); - final IMapIterator<Object, IWeightValue> itor = model.entries(); - while (itor.next() != -1) { - itor.getValue(probe); - if (!probe.isTouched()) { - continue; // skip outputting untouched weights + finalizeTraining(); + forwardModel(); + this.accumulated = null; + this.model = null; + } + + @VisibleForTesting + public void finalizeTraining() throws HiveException { + if (count == 0L) { + this.model = null; + return; + } + if (is_mini_batch) { // Update model with accumulated delta + batchUpdate(); + } + if (iterations > 1) { + runIterativeTraining(iterations); + } + } + + protected final void runIterativeTraining(@Nonnegative final int iterations) + throws HiveException { + final ByteBuffer buf = this.inputBuf; + final NioStatefullSegment dst = this.fileIO; + assert (buf != null); + assert (dst != null); + final long numTrainingExamples = count; + + final Reporter reporter = getReporter(); + final Counters.Counter iterCounter = (reporter == null) ? null : reporter.getCounter( + "hivemall.GeneralLearnerBase$Counter", "iteration"); + + try { + if (dst.getPosition() == 0L) {// run iterations w/o temporary file + if (buf.position() == 0) { + return; // no training example + } + buf.flip(); + + for (int iter = 2; iter <= iterations; iter++) { + cvState.next(); + reportProgress(reporter); + setCounterValue(iterCounter, iter); + + while (buf.remaining() > 0) { + int recordBytes = buf.getInt(); + assert (recordBytes > 0) : recordBytes; + int featureVectorLength = buf.getInt(); + final FeatureValue[] featureVector = new FeatureValue[featureVectorLength]; + for (int j = 0; j < featureVectorLength; j++) { + featureVector[j] = readFeatureValue(buf, featureType); + } + float target = buf.getFloat(); + train(featureVector, target); + } + buf.rewind(); + + if (is_mini_batch) { // Update model with accumulated delta + batchUpdate(); + } + + if (cvState.isConverged(numTrainingExamples)) { + break; } - Object k = itor.getKey(); - fv.set(probe.get()); - cov.set(probe.getCovariance()); - forwardMapObj[0] = k; - forwardMapObj[1] = fv; - forwardMapObj[2] = cov; - forward(forwardMapObj); - numForwarded++; } - } else { - final WeightValue probe = new WeightValue(); - final Object[] forwardMapObj = new Object[2]; - final FloatWritable fv = new FloatWritable(); - final IMapIterator<Object, IWeightValue> itor = model.entries(); - while (itor.next() != -1) { - itor.getValue(probe); - if (!probe.isTouched()) { - continue; // skip outputting untouched weights + logger.info("Performed " + + cvState.getCurrentIteration() + + " iterations of " + + NumberUtils.formatNumber(numTrainingExamples) + + " training examples on memory (thus " + + NumberUtils.formatNumber(numTrainingExamples + * cvState.getCurrentIteration()) + " training updates in total) "); + } else {// read training examples in the temporary file and invoke train for each example + // write training examples in buffer to a temporary file + if (buf.remaining() > 0) { + writeBuffer(buf, dst); + } + try { + dst.flush(); + } catch (IOException e) { + throw new HiveException("Failed to flush a file: " + + dst.getFile().getAbsolutePath(), e); + } + if (logger.isInfoEnabled()) { + File tmpFile = dst.getFile(); + logger.info("Wrote " + numTrainingExamples + + " records to a temporary file for iterative training: " + + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + + ")"); + } + + // run iterations + for (int iter = 2; iter <= iterations; iter++) { + cvState.next(); + setCounterValue(iterCounter, iter); + + buf.clear(); + dst.resetPosition(); + while (true) { + reportProgress(reporter); + // TODO prefetch + // writes training examples to a buffer in the temporary file + final int bytesRead; + try { + bytesRead = dst.read(buf); + } catch (IOException e) { + throw new HiveException("Failed to read a file: " + + dst.getFile().getAbsolutePath(), e); + } + if (bytesRead == 0) { // reached file EOF + break; + } + assert (bytesRead > 0) : bytesRead; + + // reads training examples from a buffer + buf.flip(); + int remain = buf.remaining(); + if (remain < SizeOf.INT) { + throw new HiveException("Illegal file format was detected"); + } + while (remain >= SizeOf.INT) { + int pos = buf.position(); + int recordBytes = buf.getInt(); + remain -= SizeOf.INT; + + if (remain < recordBytes) { + buf.position(pos); + break; + } + + int featureVectorLength = buf.getInt(); + final FeatureValue[] featureVector = new FeatureValue[featureVectorLength]; + for (int j = 0; j < featureVectorLength; j++) { + featureVector[j] = readFeatureValue(buf, featureType); + } + float target = buf.getFloat(); + train(featureVector, target); + + remain -= recordBytes; + } + buf.compact(); + } + + if (is_mini_batch) { // Update model with accumulated delta + batchUpdate(); + } + + if (cvState.isConverged(numTrainingExamples)) { + break; } - Object k = itor.getKey(); - fv.set(probe.get()); - forwardMapObj[0] = k; - forwardMapObj[1] = fv; - forward(forwardMapObj); - numForwarded++; } + logger.info("Performed " + + cvState.getCurrentIteration() + + " iterations of " + + NumberUtils.formatNumber(numTrainingExamples) + + " training examples on a secondary storage (thus " + + NumberUtils.formatNumber(numTrainingExamples + * cvState.getCurrentIteration()) + " training updates in total)"); } - long numMixed = model.getNumMixed(); - this.model = null; - logger.info("Trained a prediction model using " + count + " training examples" - + (numMixed > 0 ? "( numMixed: " + numMixed + " )" : "")); - logger.info("Forwarded the prediction model of " + numForwarded + " rows"); + } catch (Throwable e) { + throw new HiveException("Exception caused in the iterative training", e); + } finally { + // delete the temporary file and release resources + try { + dst.close(true); + } catch (IOException e) { + throw new HiveException("Failed to close a file: " + + dst.getFile().getAbsolutePath(), e); + } + this.inputBuf = null; + this.fileIO = null; } } - @VisibleForTesting - public double getCumulativeLoss() { - return cumLoss; + protected void forwardModel() throws HiveException { + int numForwarded = 0; + if (useCovariance()) { + final WeightValueWithCovar probe = new WeightValueWithCovar(); + final Object[] forwardMapObj = new Object[3]; + final FloatWritable fv = new FloatWritable(); + final FloatWritable cov = new FloatWritable(); + final IMapIterator<Object, IWeightValue> itor = model.entries(); + while (itor.next() != -1) { + itor.getValue(probe); + if (!probe.isTouched()) { + continue; // skip outputting untouched weights + } + Object k = itor.getKey(); + fv.set(probe.get()); + cov.set(probe.getCovariance()); + forwardMapObj[0] = k; + forwardMapObj[1] = fv; + forwardMapObj[2] = cov; + forward(forwardMapObj); + numForwarded++; + } + } else { + final WeightValue probe = new WeightValue(); + final Object[] forwardMapObj = new Object[2]; + final FloatWritable fv = new FloatWritable(); + final IMapIterator<Object, IWeightValue> itor = model.entries(); + while (itor.next() != -1) { + itor.getValue(probe); + if (!probe.isTouched()) { + continue; // skip outputting untouched weights + } + Object k = itor.getKey(); + fv.set(probe.get()); + forwardMapObj[0] = k; + forwardMapObj[1] = fv; + forward(forwardMapObj); + numForwarded++; + } + } + long numMixed = model.getNumMixed(); + logger.info("Trained a prediction model using " + count + " training examples" + + (numMixed > 0 ? "( numMixed: " + numMixed + " )" : "")); + logger.info("Forwarded the prediction model of " + numForwarded + " rows"); } @VisibleForTesting - public void resetCumulativeLoss() { - this.cumLoss = 0.d; + public double getCumulativeLoss() { + return (cvState == null) ? Double.NaN : cvState.getCumulativeLoss(); } - } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/LearnerBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/LearnerBaseUDTF.java b/core/src/main/java/hivemall/LearnerBaseUDTF.java index fdb22f8..b9ec668 100644 --- a/core/src/main/java/hivemall/LearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/LearnerBaseUDTF.java @@ -18,7 +18,6 @@ */ package hivemall; -import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableFloatObjectInspector; import hivemall.mix.MixMessage.MixEventName; import hivemall.mix.client.MixClient; import hivemall.model.DenseModel; @@ -29,22 +28,14 @@ import hivemall.model.PredictionModel; import hivemall.model.SpaceEfficientDenseModel; import hivemall.model.SparseModel; import hivemall.model.SynchronizedModelWrapper; -import hivemall.model.WeightValue; -import hivemall.model.WeightValue.WeightValueWithCovar; import hivemall.optimizer.DenseOptimizerFactory; import hivemall.optimizer.Optimizer; import hivemall.optimizer.SparseOptimizerFactory; -import hivemall.utils.datetime.StopWatch; -import hivemall.utils.hadoop.HadoopUtils; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.IOUtils; import hivemall.utils.lang.Preconditions; import hivemall.utils.lang.Primitives; -import java.io.BufferedReader; -import java.io.File; -import java.io.IOException; -import java.util.List; import java.util.Map; import javax.annotation.CheckForNull; @@ -57,21 +48,15 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.serde2.SerDeException; -import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.StructField; -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableFloatObjectInspector; -import org.apache.hadoop.io.Text; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; public abstract class LearnerBaseUDTF extends UDTFWithOptions { private static final Log logger = LogFactory.getLog(LearnerBaseUDTF.class); protected final boolean enableNewModel; - protected String preloadedModelFile; protected boolean dense_model; protected int model_dims; protected boolean disable_halffloat; @@ -97,7 +82,6 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { @Override protected Options getOptions() { Options opts = new Options(); - opts.addOption("loadmodel", true, "Model file name in the distributed cache"); opts.addOption("dense", "densemodel", false, "Use dense model or not"); opts.addOption("dims", "feature_dimensions", true, "The dimension of model [default: 16777216 (2^24)]"); @@ -119,7 +103,6 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { @Override protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs) throws UDFArgumentException { - String modelfile = null; boolean denseModel = false; int modelDims = -1; boolean disableHalfFloat = false; @@ -135,8 +118,6 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { String rawArgs = HiveUtils.getConstString(argOIs[2]); cl = parseOptions(rawArgs); - modelfile = cl.getOptionValue("loadmodel"); - denseModel = cl.hasOption("dense"); if (denseModel) { modelDims = Primitives.parseInt(cl.getOptionValue("dims"), 16777216); @@ -160,7 +141,6 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { ssl = cl.hasOption("ssl"); } - this.preloadedModelFile = modelfile; this.dense_model = denseModel; this.model_dims = modelDims; this.disable_halffloat = disableHalfFloat; @@ -272,124 +252,14 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { return 16384; } - protected void loadPredictionModel(PredictionModel model, String filename, - PrimitiveObjectInspector keyOI) { - final StopWatch elapsed = new StopWatch(); - final long lines; - try { - if (useCovariance()) { - lines = loadPredictionModel(model, new File(filename), keyOI, - writableFloatObjectInspector, writableFloatObjectInspector); - } else { - lines = loadPredictionModel(model, new File(filename), keyOI, - writableFloatObjectInspector); - } - } catch (IOException e) { - throw new RuntimeException("Failed to load a model: " + filename, e); - } catch (SerDeException e) { - throw new RuntimeException("Failed to load a model: " + filename, e); - } - if (model.size() > 0) { - logger.info("Loaded " + model.size() + " features from distributed cache '" + filename - + "' (" + lines + " lines) in " + elapsed); - } - } - - private static long loadPredictionModel(PredictionModel model, File file, - PrimitiveObjectInspector keyOI, WritableFloatObjectInspector valueOI) - throws IOException, SerDeException { - long count = 0L; - if (!file.exists()) { - return count; - } - if (!file.getName().endsWith(".crc")) { - if (file.isDirectory()) { - for (File f : file.listFiles()) { - count += loadPredictionModel(model, f, keyOI, valueOI); - } - } else { - LazySimpleSerDe serde = HiveUtils.getKeyValueLineSerde(keyOI, valueOI); - StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector(); - StructField keyRef = lineOI.getStructFieldRef("key"); - StructField valueRef = lineOI.getStructFieldRef("value"); - PrimitiveObjectInspector keyRefOI = (PrimitiveObjectInspector) keyRef.getFieldObjectInspector(); - FloatObjectInspector varRefOI = (FloatObjectInspector) valueRef.getFieldObjectInspector(); - - BufferedReader reader = null; - try { - reader = HadoopUtils.getBufferedReader(file); - String line; - while ((line = reader.readLine()) != null) { - count++; - Text lineText = new Text(line); - Object lineObj = serde.deserialize(lineText); - List<Object> fields = lineOI.getStructFieldsDataAsList(lineObj); - Object f0 = fields.get(0); - Object f1 = fields.get(1); - if (f0 == null || f1 == null) { - continue; // avoid the case that key or value is null - } - Object k = keyRefOI.getPrimitiveWritableObject(keyRefOI.copyObject(f0)); - float v = varRefOI.get(f1); - model.set(k, new WeightValue(v, false)); - } - } finally { - IOUtils.closeQuietly(reader); - } - } - } - return count; - } - - private static long loadPredictionModel(PredictionModel model, File file, - PrimitiveObjectInspector featureOI, WritableFloatObjectInspector weightOI, - WritableFloatObjectInspector covarOI) throws IOException, SerDeException { - long count = 0L; - if (!file.exists()) { - return count; - } - if (!file.getName().endsWith(".crc")) { - if (file.isDirectory()) { - for (File f : file.listFiles()) { - count += loadPredictionModel(model, f, featureOI, weightOI, covarOI); - } - } else { - LazySimpleSerDe serde = HiveUtils.getLineSerde(featureOI, weightOI, covarOI); - StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector(); - StructField c1ref = lineOI.getStructFieldRef("c1"); - StructField c2ref = lineOI.getStructFieldRef("c2"); - StructField c3ref = lineOI.getStructFieldRef("c3"); - PrimitiveObjectInspector c1oi = (PrimitiveObjectInspector) c1ref.getFieldObjectInspector(); - FloatObjectInspector c2oi = (FloatObjectInspector) c2ref.getFieldObjectInspector(); - FloatObjectInspector c3oi = (FloatObjectInspector) c3ref.getFieldObjectInspector(); - - BufferedReader reader = null; - try { - reader = HadoopUtils.getBufferedReader(file); - String line; - while ((line = reader.readLine()) != null) { - count++; - Text lineText = new Text(line); - Object lineObj = serde.deserialize(lineText); - List<Object> fields = lineOI.getStructFieldsDataAsList(lineObj); - Object f0 = fields.get(0); - Object f1 = fields.get(1); - Object f2 = fields.get(2); - if (f0 == null || f1 == null) { - continue; // avoid unexpected case - } - Object k = c1oi.getPrimitiveWritableObject(c1oi.copyObject(f0)); - float v = c2oi.get(f1); - float cov = (f2 == null) ? WeightValueWithCovar.DEFAULT_COVAR - : c3oi.get(f2); - model.set(k, new WeightValueWithCovar(v, cov, false)); - } - } finally { - IOUtils.closeQuietly(reader); - } - } + @Nonnull + protected ObjectInspector getFeatureOutputOI(@Nonnull PrimitiveObjectInspector featureInputOI) + throws UDFArgumentException { + if (dense_model) { + // TODO validation + return PrimitiveObjectInspectorFactory.javaIntObjectInspector; // see DenseModel } - return count; + return ObjectInspectorUtils.getStandardObjectInspector(featureInputOI); } @Override http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/UDTFWithOptions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/UDTFWithOptions.java b/core/src/main/java/hivemall/UDTFWithOptions.java index 39ab233..b09cffa 100644 --- a/core/src/main/java/hivemall/UDTFWithOptions.java +++ b/core/src/main/java/hivemall/UDTFWithOptions.java @@ -63,7 +63,7 @@ public abstract class UDTFWithOptions extends GenericUDTF { return mapredContext.getReporter(); } - protected static void reportProgress(@Nonnull Reporter reporter) { + protected static void reportProgress(@Nullable Reporter reporter) { if (reporter != null) { synchronized (reporter) { reporter.progress(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java index 2dcf521..2f4db3a 100644 --- a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java @@ -85,19 +85,15 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { processOptions(argOIs); - PrimitiveObjectInspector featureOutputOI = dense_model ? PrimitiveObjectInspectorFactory.javaIntObjectInspector - : featureInputOI; this.model = createModel(); - if (preloadedModelFile != null) { - loadPredictionModel(model, preloadedModelFile, featureOutputOI); - } - this.count = 0; this.sampled = 0; - return getReturnOI(featureOutputOI); + + return getReturnOI(getFeatureOutputOI(featureInputOI)); } - protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg) + @Nonnull + protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector arg) throws UDFArgumentException { this.featureListOI = (ListObjectInspector) arg; ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector(); @@ -106,13 +102,13 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { return HiveUtils.asPrimitiveObjectInspector(featureRawOI); } - protected StructObjectInspector getReturnOI(ObjectInspector featureRawOI) { + @Nonnull + protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector featureOutputOI) { ArrayList<String> fieldNames = new ArrayList<String>(); ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); fieldNames.add("feature"); - ObjectInspector featureOI = ObjectInspectorUtils.getStandardObjectInspector(featureRawOI); - fieldOIs.add(featureOI); + fieldOIs.add(featureOutputOI); fieldNames.add("weight"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); if (useCovariance()) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java index 8e17de1..98cdf0b 100644 --- a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java @@ -42,7 +42,8 @@ public final class GeneralClassifierUDTF extends GeneralLearnerBaseUDTF { @Override protected String getLossOptionDescription() { return "Loss function [HingeLoss (default), LogLoss, SquaredHingeLoss, ModifiedHuberLoss, or\n" - + "a regression loss: SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, HuberLoss]"; + + "a regression loss: SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, " + + "SquaredEpsilonInsensitiveLoss, HuberLoss]"; } @Override http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java index af8545c..08a040b 100644 --- a/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java @@ -103,15 +103,10 @@ public abstract class MulticlassOnlineClassifierUDTF extends LearnerBaseUDTF { processOptions(argOIs); - PrimitiveObjectInspector featureOutputOI = dense_model ? PrimitiveObjectInspectorFactory.javaIntObjectInspector - : featureInputOI; this.label2model = new HashMap<Object, PredictionModel>(64); - if (preloadedModelFile != null) { - loadPredictionModel(label2model, preloadedModelFile, labelInputOI, featureOutputOI); - } - this.count = 0; - return getReturnOI(labelInputOI, featureOutputOI); + + return getReturnOI(labelInputOI, getFeatureOutputOI(featureInputOI)); } @Override @@ -119,7 +114,8 @@ public abstract class MulticlassOnlineClassifierUDTF extends LearnerBaseUDTF { return 8192; } - protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg) + @Nonnull + protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector arg) throws UDFArgumentException { this.featureListOI = (ListObjectInspector) arg; ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector(); @@ -133,8 +129,9 @@ public abstract class MulticlassOnlineClassifierUDTF extends LearnerBaseUDTF { return HiveUtils.asPrimitiveObjectInspector(featureRawOI); } - protected StructObjectInspector getReturnOI(ObjectInspector labelRawOI, - ObjectInspector featureRawOI) { + @Nonnull + protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector labelRawOI, + @Nonnull ObjectInspector featureOutputOI) { ArrayList<String> fieldNames = new ArrayList<String>(); ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); @@ -142,8 +139,7 @@ public abstract class MulticlassOnlineClassifierUDTF extends LearnerBaseUDTF { ObjectInspector labelOI = ObjectInspectorUtils.getStandardObjectInspector(labelRawOI); fieldOIs.add(labelOI); fieldNames.add("feature"); - ObjectInspector featureOI = ObjectInspectorUtils.getStandardObjectInspector(featureRawOI); - fieldOIs.add(featureOI); + fieldOIs.add(featureOutputOI); fieldNames.add("weight"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); if (useCovariance()) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/common/ConversionState.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/common/ConversionState.java b/core/src/main/java/hivemall/common/ConversionState.java index dd20662..ff92241 100644 --- a/core/src/main/java/hivemall/common/ConversionState.java +++ b/core/src/main/java/hivemall/common/ConversionState.java @@ -25,20 +25,19 @@ public final class ConversionState { private static final Log logger = LogFactory.getLog(ConversionState.class); /** Whether to check conversion */ - protected final boolean conversionCheck; + private final boolean conversionCheck; /** Threshold to determine convergence */ - protected final double convergenceRate; + private final double convergenceRate; /** being ready to end iteration */ - protected boolean readyToFinishIterations; + private boolean readyToFinishIterations; /** The cumulative errors in the training */ - protected double totalErrors; + private double totalErrors; /** The cumulative losses in an iteration */ - protected double currLosses, prevLosses; + private double currLosses, prevLosses; - protected int curIter; - protected float curEta; + private int curIter; public ConversionState() { this(true, 0.005d); @@ -51,8 +50,7 @@ public final class ConversionState { this.totalErrors = 0.d; this.currLosses = 0.d; this.prevLosses = Double.POSITIVE_INFINITY; - this.curIter = 0; - this.curEta = Float.NaN; + this.curIter = 1; } public double getTotalErrors() { @@ -83,20 +81,16 @@ public final class ConversionState { return currLosses > prevLosses; } - public boolean isConverged(final int iter, final long obserbedTrainingExamples) { + public boolean isConverged(final long obserbedTrainingExamples) { if (conversionCheck == false) { - this.prevLosses = currLosses; - this.currLosses = 0.d; return false; } if (currLosses > prevLosses) { if (logger.isInfoEnabled()) { - logger.info("Iteration #" + iter + " currLoss `" + currLosses + "` > prevLosses `" - + prevLosses + '`'); + logger.info("Iteration #" + curIter + " currLoss `" + currLosses + + "` > prevLosses `" + prevLosses + '`'); } - this.prevLosses = currLosses; - this.currLosses = 0.d; this.readyToFinishIterations = false; return false; } @@ -105,7 +99,7 @@ public final class ConversionState { if (changeRate < convergenceRate) { if (readyToFinishIterations) { // NOTE: never be true at the first iteration where prevLosses == Double.POSITIVE_INFINITY - logger.info("Training converged at " + iter + "-th iteration. [curLosses=" + logger.info("Training converged at " + curIter + "-th iteration. [curLosses=" + currLosses + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate + ']'); return true; @@ -114,33 +108,24 @@ public final class ConversionState { } } else { if (logger.isDebugEnabled()) { - logger.debug("Iteration #" + iter + " [curLosses=" + currLosses + ", prevLosses=" - + prevLosses + ", changeRate=" + changeRate + ", #trainingExamples=" - + obserbedTrainingExamples + ']'); + logger.debug("Iteration #" + curIter + " [curLosses=" + currLosses + + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate + + ", #trainingExamples=" + obserbedTrainingExamples + ']'); } this.readyToFinishIterations = false; } - this.prevLosses = currLosses; - this.currLosses = 0.d; return false; } - public void logState(int iter, float eta) { - if (logger.isInfoEnabled()) { - logger.info("Iteration #" + iter + " [curLoss=" + currLosses + ", prevLoss=" - + prevLosses + ", eta=" + eta + ']'); - } - this.curIter = iter; - this.curEta = eta; + public void next() { + this.prevLosses = currLosses; + this.currLosses = 0.d; + this.curIter++; } public int getCurrentIteration() { return curIter; } - public float getCurrentEta() { - return curEta; - } - } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java index 36af127..65b6ba7 100644 --- a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java +++ b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java @@ -20,11 +20,11 @@ package hivemall.fm; import hivemall.UDTFWithOptions; import hivemall.common.ConversionState; +import hivemall.fm.FMStringFeatureMapModel.Entry; import hivemall.optimizer.EtaEstimator; import hivemall.optimizer.LossFunctions; import hivemall.optimizer.LossFunctions.LossFunction; import hivemall.optimizer.LossFunctions.LossType; -import hivemall.fm.FMStringFeatureMapModel.Entry; import hivemall.utils.collections.IMapIterator; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.FileUtils; @@ -539,8 +539,8 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { } inputBuf.flip(); - int iter = 2; - for (; iter <= iterations; iter++) { + for (int iter = 2; iter <= iterations; iter++) { + _cvState.next(); reportProgress(reporter); setCounterValue(iterCounter, iter); @@ -557,12 +557,12 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { ++_t; train(x, y, adaregr); } - if (_cvState.isConverged(iter, numTrainingExamples)) { + if (_cvState.isConverged(numTrainingExamples)) { break; } inputBuf.rewind(); } - LOG.info("Performed " + Math.min(iter, iterations) + " iterations of " + LOG.info("Performed " + _cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on memory (thus " + NumberUtils.formatNumber(_t) + " training updates in total) "); @@ -587,8 +587,8 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { } // run iterations - int iter = 2; - for (; iter <= iterations; iter++) { + for (int iter = 2; iter <= iterations; iter++) { + _cvState.next(); setCounterValue(iterCounter, iter); inputBuf.clear(); @@ -639,11 +639,11 @@ public class FactorizationMachineUDTF extends UDTFWithOptions { } inputBuf.compact(); } - if (_cvState.isConverged(iter, numTrainingExamples)) { + if (_cvState.isConverged(numTrainingExamples)) { break; } } - LOG.info("Performed " + Math.min(iter, iterations) + " iterations of " + LOG.info("Performed " + _cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on a secondary storage (thus " + NumberUtils.formatNumber(_t) + " training updates in total)"); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java index 56a1992..141b261 100644 --- a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java +++ b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java @@ -20,8 +20,8 @@ package hivemall.mf; import hivemall.UDTFWithOptions; import hivemall.common.ConversionState; -import hivemall.optimizer.EtaEstimator; import hivemall.mf.FactorizedModel.RankInitScheme; +import hivemall.optimizer.EtaEstimator; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.FileUtils; import hivemall.utils.io.NioFixedSegment; @@ -479,8 +479,8 @@ public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements } inputBuf.flip(); - int iter = 2; - for (; iter <= iterations; iter++) { + for (int iter = 2; iter <= iterations; iter++) { + cvState.next(); reportProgress(reporter); setCounterValue(iterCounter, iter); @@ -493,8 +493,7 @@ public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements train(u, i, j); } cvState.multiplyLoss(0.5d); - cvState.logState(iter, eta()); - if (cvState.isConverged(iter, numTrainingExamples)) { + if (cvState.isConverged(numTrainingExamples)) { break; } if (cvState.isLossIncreased()) { @@ -504,7 +503,7 @@ public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements } inputBuf.rewind(); } - LOG.info("Performed " + Math.min(iter, iterations) + " iterations of " + LOG.info("Performed " + cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on memory (thus " + NumberUtils.formatNumber(count) + " training updates in total) "); @@ -531,8 +530,8 @@ public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements } // run iterations - int iter = 2; - for (; iter <= iterations; iter++) { + for (int iter = 2; iter <= iterations; iter++) { + cvState.next(); setCounterValue(iterCounter, iter); inputBuf.clear(); @@ -569,8 +568,7 @@ public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements inputBuf.compact(); } cvState.multiplyLoss(0.5d); - cvState.logState(iter, eta()); - if (cvState.isConverged(iter, numTrainingExamples)) { + if (cvState.isConverged(numTrainingExamples)) { break; } if (cvState.isLossIncreased()) { @@ -579,7 +577,7 @@ public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements etaEstimator.update(0.5f); } } - LOG.info("Performed " + Math.min(iter, iterations) + " iterations of " + LOG.info("Performed " + cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples using a secondary storage (thus " + NumberUtils.formatNumber(count) + " training updates in total)"); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java index bfc1f19..66ec60d 100644 --- a/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java +++ b/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java @@ -477,8 +477,8 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions impl } inputBuf.flip(); - int iter = 2; - for (; iter <= iterations; iter++) { + for (int iter = 2; iter <= iterations; iter++) { + cvState.next(); reportProgress(reporter); setCounterValue(iterCounter, iter); @@ -491,12 +491,12 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions impl train(user, item, rating); } cvState.multiplyLoss(0.5d); - if (cvState.isConverged(iter, numTrainingExamples)) { + if (cvState.isConverged(numTrainingExamples)) { break; } inputBuf.rewind(); } - logger.info("Performed " + Math.min(iter, iterations) + " iterations of " + logger.info("Performed " + cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on memory (thus " + NumberUtils.formatNumber(count) + " training updates in total) "); @@ -523,8 +523,8 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions impl } // run iterations - int iter = 2; - for (; iter <= iterations; iter++) { + for (int iter = 2; iter <= iterations; iter++) { + cvState.next(); setCounterValue(iterCounter, iter); inputBuf.clear(); @@ -561,11 +561,11 @@ public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions impl inputBuf.compact(); } cvState.multiplyLoss(0.5d); - if (cvState.isConverged(iter, numTrainingExamples)) { + if (cvState.isConverged(numTrainingExamples)) { break; } } - logger.info("Performed " + Math.min(iter, iterations) + " iterations of " + logger.info("Performed " + cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples using a secondary storage (thus " + NumberUtils.formatNumber(count) + " training updates in total)"); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/model/FeatureValue.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/FeatureValue.java b/core/src/main/java/hivemall/model/FeatureValue.java index 11aa8f0..11005e9 100644 --- a/core/src/main/java/hivemall/model/FeatureValue.java +++ b/core/src/main/java/hivemall/model/FeatureValue.java @@ -28,7 +28,7 @@ import org.apache.hadoop.io.Text; public final class FeatureValue { - private/* final */Object feature; + private/* final */Object feature; // possible types: String, Text, Integer, Long private/* final */double value; public FeatureValue() {}// used for Probe @@ -108,7 +108,11 @@ public final class FeatureValue { String s1 = s.substring(0, pos); String s2 = s.substring(pos + 1); feature = mhash ? Integer.valueOf(MurmurHash3.murmurhash3(s1)) : new Text(s1); - weight = Double.parseDouble(s2); + try { + weight = Double.parseDouble(s2); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Failed to parse a feature value: " + s, e); + } } else { feature = mhash ? Integer.valueOf(MurmurHash3.murmurhash3(s)) : new Text(s); weight = 1.d; @@ -135,7 +139,11 @@ public final class FeatureValue { if (pos > 0) { feature = s.substring(0, pos); String s2 = s.substring(pos + 1); - weight = Double.parseDouble(s2); + try { + weight = Double.parseDouble(s2); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Failed to parse a feature value: " + s, e); + } } else { feature = s; weight = 1.d; @@ -157,7 +165,11 @@ public final class FeatureValue { if (pos > 0) { probe.feature = s.substring(0, pos); String s2 = s.substring(pos + 1); - probe.value = Double.parseDouble(s2); + try { + probe.value = Double.parseDouble(s2); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Failed to parse a feature value: " + s, e); + } } else { probe.feature = s; probe.value = 1.d; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/optimizer/Optimizer.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java b/core/src/main/java/hivemall/optimizer/Optimizer.java index bbd2320..0f82833 100644 --- a/core/src/main/java/hivemall/optimizer/Optimizer.java +++ b/core/src/main/java/hivemall/optimizer/Optimizer.java @@ -51,7 +51,7 @@ public interface Optimizer { @Nonnull protected final Regularization _reg; @Nonnegative - protected int _numStep = 1; + protected long _numStep = 1L; public OptimizerBase(@Nonnull Map<String, String> options) { this._eta = EtaEstimator.get(options); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java index 1bd9393..a34a6e6 100644 --- a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java +++ b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java @@ -41,8 +41,8 @@ public final class GeneralRegressionUDTF extends GeneralLearnerBaseUDTF { @Override protected String getLossOptionDescription() { - return "Loss function [default: SquaredLoss/squared, QuantileLoss/quantile, " - + "EpsilonInsensitiveLoss/epsilon_insensitive, HuberLoss/huber]"; + return "Loss function [SquaredLoss (default), QuantileLoss, EpsilonInsensitiveLoss, " + + "SquaredEpsilonInsensitiveLoss, HuberLoss]"; } @Override http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java index f8fae89..33196ab 100644 --- a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java +++ b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java @@ -89,19 +89,15 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF { processOptions(argOIs); - PrimitiveObjectInspector featureOutputOI = dense_model ? PrimitiveObjectInspectorFactory.javaIntObjectInspector - : featureInputOI; this.model = createModel(); - if (preloadedModelFile != null) { - loadPredictionModel(model, preloadedModelFile, featureOutputOI); - } - this.count = 0; this.sampled = 0; - return getReturnOI(featureOutputOI); + + return getReturnOI(getFeatureOutputOI(featureInputOI)); } - protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg) + @Nonnull + protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector arg) throws UDFArgumentException { this.featureListOI = (ListObjectInspector) arg; ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector(); @@ -110,13 +106,13 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF { return HiveUtils.asPrimitiveObjectInspector(featureRawOI); } - protected StructObjectInspector getReturnOI(ObjectInspector featureOutputOI) { + @Nonnull + protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector featureOutputOI) { ArrayList<String> fieldNames = new ArrayList<String>(); ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); fieldNames.add("feature"); - ObjectInspector featureOI = ObjectInspectorUtils.getStandardObjectInspector(featureOutputOI); - fieldOIs.add(featureOI); + fieldOIs.add(featureOutputOI); fieldNames.add("weight"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); if (useCovariance()) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/smile/regression/RegressionTree.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/regression/RegressionTree.java b/core/src/main/java/hivemall/smile/regression/RegressionTree.java index 5ec27df..38b7b83 100755 --- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java +++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java @@ -34,7 +34,6 @@ package hivemall.smile.regression; import static hivemall.smile.utils.SmileExtUtils.resolveFeatureName; -import static hivemall.smile.utils.SmileExtUtils.resolveName; import hivemall.annotations.VisibleForTesting; import hivemall.math.matrix.Matrix; import hivemall.math.matrix.ints.ColumnMajorIntMatrix; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e186a587/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index 4ed1f12..cb2b5e3 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -28,6 +28,7 @@ import static hivemall.HivemallConstants.SMALLINT_TYPE_NAME; import static hivemall.HivemallConstants.STRING_TYPE_NAME; import static hivemall.HivemallConstants.TINYINT_TYPE_NAME; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.BitSet; import java.util.Collections; @@ -46,10 +47,14 @@ import org.apache.hadoop.hive.serde2.SerDeException; import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.io.ShortWritable; +import org.apache.hadoop.hive.serde2.lazy.ByteArrayRef; import org.apache.hadoop.hive.serde2.lazy.LazyDouble; import org.apache.hadoop.hive.serde2.lazy.LazyInteger; +import org.apache.hadoop.hive.serde2.lazy.LazyLong; import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe; import org.apache.hadoop.hive.serde2.lazy.LazyString; +import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyPrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyStringObjectInspector; import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryMap; import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; @@ -775,6 +780,7 @@ public final class HiveUtils { return (ConstantObjectInspector) oi; } + @Nonnull public static PrimitiveObjectInspector asPrimitiveObjectInspector( @Nonnull final ObjectInspector oi) throws UDFArgumentException { if (oi.getCategory() != Category.PRIMITIVE) { @@ -784,6 +790,7 @@ public final class HiveUtils { return (PrimitiveObjectInspector) oi; } + @Nonnull public static StringObjectInspector asStringOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentException { if (!STRING_TYPE_NAME.equals(argOI.getTypeName())) { @@ -792,6 +799,7 @@ public final class HiveUtils { return (StringObjectInspector) argOI; } + @Nonnull public static BinaryObjectInspector asBinaryOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentException { if (!BINARY_TYPE_NAME.equals(argOI.getTypeName())) { @@ -800,6 +808,7 @@ public final class HiveUtils { return (BinaryObjectInspector) argOI; } + @Nonnull public static BooleanObjectInspector asBooleanOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentException { if (!BOOLEAN_TYPE_NAME.equals(argOI.getTypeName())) { @@ -808,6 +817,7 @@ public final class HiveUtils { return (BooleanObjectInspector) argOI; } + @Nonnull public static IntObjectInspector asIntOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentException { if (!INT_TYPE_NAME.equals(argOI.getTypeName())) { @@ -816,6 +826,7 @@ public final class HiveUtils { return (IntObjectInspector) argOI; } + @Nonnull public static LongObjectInspector asLongOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentException { if (!BIGINT_TYPE_NAME.equals(argOI.getTypeName())) { @@ -824,6 +835,7 @@ public final class HiveUtils { return (LongObjectInspector) argOI; } + @Nonnull public static DoubleObjectInspector asDoubleOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentException { if (!DOUBLE_TYPE_NAME.equals(argOI.getTypeName())) { @@ -832,6 +844,7 @@ public final class HiveUtils { return (DoubleObjectInspector) argOI; } + @Nonnull public static PrimitiveObjectInspector asIntCompatibleOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentTypeException { if (argOI.getCategory() != Category.PRIMITIVE) { @@ -857,6 +870,7 @@ public final class HiveUtils { return oi; } + @Nonnull public static PrimitiveObjectInspector asLongCompatibleOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentTypeException { if (argOI.getCategory() != Category.PRIMITIVE) { @@ -883,6 +897,7 @@ public final class HiveUtils { return oi; } + @Nonnull public static PrimitiveObjectInspector asIntegerOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentTypeException { if (argOI.getCategory() != Category.PRIMITIVE) { @@ -1037,4 +1052,43 @@ public final class HiveUtils { } return obj; } + + @Nonnull + public static LazyString lazyString(@Nonnull final String str) { + return lazyString(str, (byte) '\\'); + } + + @Nonnull + public static LazyString lazyString(@Nonnull final String str, final byte escapeChar) { + LazyStringObjectInspector oi = LazyPrimitiveObjectInspectorFactory.getLazyStringObjectInspector( + false, escapeChar); + return lazyString(str, oi); + } + + @Nonnull + public static LazyString lazyString(@Nonnull final String str, + @Nonnull final LazyStringObjectInspector oi) { + LazyString lazy = new LazyString(oi); + ByteArrayRef ref = new ByteArrayRef(); + byte[] data = str.getBytes(StandardCharsets.UTF_8); + ref.setData(data); + lazy.init(ref, 0, data.length); + return lazy; + } + + @Nonnull + public static LazyInteger lazyInteger(@Nonnull final int v) { + LazyInteger lazy = new LazyInteger( + LazyPrimitiveObjectInspectorFactory.LAZY_INT_OBJECT_INSPECTOR); + lazy.getWritableObject().set(v); + return lazy; + } + + @Nonnull + public static LazyLong lazyLong(@Nonnull final long v) { + LazyLong lazy = new LazyLong(LazyPrimitiveObjectInspectorFactory.LAZY_LONG_OBJECT_INSPECTOR); + lazy.getWritableObject().set(v); + return lazy; + } + }
