Close #79: [HIVEMALL-101] Separate optimizer implementation
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/3848ea60 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/3848ea60 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/3848ea60 Branch: refs/heads/master Commit: 3848ea60a76cb4b9df7f975bd90e71cd46c77faa Parents: 1db5358 Author: Takeshi Yamamuro <[email protected]> Authored: Thu Jun 8 14:14:58 2017 +0900 Committer: Takuya Kitazawa <[email protected]> Committed: Thu Jun 15 02:48:16 2017 +0900 ---------------------------------------------------------------------- .../java/hivemall/GeneralLearnerBaseUDTF.java | 380 ++++++++++++ .../src/main/java/hivemall/LearnerBaseUDTF.java | 64 +- .../java/hivemall/annotations/InternalAPI.java | 34 ++ .../main/java/hivemall/annotations/Since.java | 29 + .../java/hivemall/anomaly/ChangeFinderUDF.java | 2 + .../anomaly/SingularSpectrumTransformUDF.java | 2 + .../hivemall/classifier/AROWClassifierUDTF.java | 2 +- .../hivemall/classifier/AdaGradRDAUDTF.java | 6 +- .../classifier/BinaryOnlineClassifierUDTF.java | 38 +- .../classifier/GeneralClassifierUDTF.java | 68 +++ .../KernelExpansionPassiveAggressiveUDTF.java | 2 +- .../classifier/PassiveAggressiveUDTF.java | 2 +- .../MulticlassOnlineClassifierUDTF.java | 8 + .../main/java/hivemall/common/EtaEstimator.java | 160 ----- .../java/hivemall/common/LossFunctions.java | 467 -------------- .../java/hivemall/fm/FMHyperParameters.java | 2 +- .../hivemall/fm/FactorizationMachineModel.java | 2 +- .../hivemall/fm/FactorizationMachineUDTF.java | 8 +- .../hivemall/mf/BPRMatrixFactorizationUDTF.java | 2 +- .../hivemall/mf/MatrixFactorizationSGDUDTF.java | 2 +- .../hivemall/model/AbstractPredictionModel.java | 10 +- .../main/java/hivemall/model/DenseModel.java | 24 +- .../main/java/hivemall/model/IWeightValue.java | 16 +- .../main/java/hivemall/model/NewDenseModel.java | 295 +++++++++ 
.../model/NewSpaceEfficientDenseModel.java | 321 ++++++++++ .../java/hivemall/model/NewSparseModel.java | 202 ++++++ .../java/hivemall/model/PredictionModel.java | 5 +- .../model/SpaceEfficientDenseModel.java | 24 +- .../main/java/hivemall/model/SparseModel.java | 26 +- .../model/SynchronizedModelWrapper.java | 25 +- .../main/java/hivemall/model/WeightValue.java | 162 ++++- .../hivemall/model/WeightValueWithClock.java | 167 ++++- .../optimizer/DenseOptimizerFactory.java | 224 +++++++ .../java/hivemall/optimizer/EtaEstimator.java | 196 ++++++ .../java/hivemall/optimizer/LossFunctions.java | 609 +++++++++++++++++++ .../main/java/hivemall/optimizer/Optimizer.java | 263 ++++++++ .../hivemall/optimizer/OptimizerOptions.java | 77 +++ .../java/hivemall/optimizer/Regularization.java | 140 +++++ .../optimizer/SparseOptimizerFactory.java | 178 ++++++ .../hivemall/regression/AROWRegressionUDTF.java | 2 +- .../java/hivemall/regression/AdaDeltaUDTF.java | 5 +- .../java/hivemall/regression/AdaGradUDTF.java | 5 +- .../regression/GeneralRegressionUDTF.java | 69 +++ .../java/hivemall/regression/LogressUDTF.java | 15 +- .../PassiveAggressiveRegressionUDTF.java | 2 +- .../hivemall/regression/RegressionBaseUDTF.java | 28 +- .../classifier/GeneralClassifierUDTFTest.java | 261 ++++++++ .../NewSpaceEfficientNewDenseModelTest.java | 60 ++ .../model/SpaceEfficientDenseModelTest.java | 60 -- .../java/hivemall/optimizer/OptimizerTest.java | 172 ++++++ .../regression/GeneralRegressionUDTFTest.java | 193 ++++++ docs/gitbook/SUMMARY.md | 45 +- docs/gitbook/binaryclass/general.md | 132 ++++ docs/gitbook/misc/prediction.md | 137 +++++ docs/gitbook/regression/general.md | 74 +++ .../java/hivemall/mix/server/MixServerTest.java | 18 +- resources/ddl/define-all-as-permanent.hive | 11 + resources/ddl/define-all.hive | 10 +- resources/ddl/define-all.spark | 6 + resources/ddl/define-udfs.td.hql | 2 + .../hivemall/mix/server/MixServerSuite.scala | 8 +- 61 files changed, 4758 insertions(+), 801 
deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java new file mode 100644 index 0000000..e798fdf --- /dev/null +++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall; + +import hivemall.annotations.VisibleForTesting; +import hivemall.model.FeatureValue; +import hivemall.model.IWeightValue; +import hivemall.model.PredictionModel; +import hivemall.model.WeightValue; +import hivemall.model.WeightValue.WeightValueWithCovar; +import hivemall.optimizer.LossFunctions; +import hivemall.optimizer.LossFunctions.LossFunction; +import hivemall.optimizer.LossFunctions.LossType; +import hivemall.optimizer.Optimizer; +import hivemall.optimizer.OptimizerOptions; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.FloatAccumulator; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; + +public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { + private static final Log logger = LogFactory.getLog(GeneralLearnerBaseUDTF.class); + + private ListObjectInspector 
featureListOI; + private PrimitiveObjectInspector featureInputOI; + private PrimitiveObjectInspector targetOI; + private boolean parseFeature; + + @Nonnull + private final Map<String, String> optimizerOptions; + private Optimizer optimizer; + private LossFunction lossFunction; + + protected PredictionModel model; + protected int count; + + // The accumulated delta of each weight values. + protected transient Map<Object, FloatAccumulator> accumulated; + protected int sampled; + + private float cumLoss; + + public GeneralLearnerBaseUDTF() { + this(true); + } + + public GeneralLearnerBaseUDTF(boolean enableNewModel) { + super(enableNewModel); + this.optimizerOptions = OptimizerOptions.create(); + } + + @Nonnull + protected abstract String getLossOptionDescription(); + + @Nonnull + protected abstract LossType getDefaultLossType(); + + protected abstract void checkLossFunction(@Nonnull LossFunction lossFunction) + throws UDFArgumentException; + + protected abstract void checkTargetValue(float target) throws UDFArgumentException; + + protected abstract void train(@Nonnull final FeatureValue[] features, final float target); + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length < 2) { + throw new UDFArgumentException( + "_FUNC_ takes 2 arguments: List<Int|BigInt|Text> features, float target [, constant string options]"); + } + this.featureInputOI = processFeaturesOI(argOIs[0]); + this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]); + + processOptions(argOIs); + + PrimitiveObjectInspector featureOutputOI = dense_model ? 
PrimitiveObjectInspectorFactory.javaIntObjectInspector + : featureInputOI; + this.model = createModel(); + if (preloadedModelFile != null) { + loadPredictionModel(model, preloadedModelFile, featureOutputOI); + } + + try { + this.optimizer = createOptimizer(optimizerOptions); + } catch (Throwable e) { + throw new UDFArgumentException(e.getMessage()); + } + + this.count = 0; + this.sampled = 0; + this.cumLoss = 0.f; + + return getReturnOI(featureOutputOI); + } + + @Override + protected Options getOptions() { + Options opts = super.getOptions(); + opts.addOption("loss", "loss_function", true, getLossOptionDescription()); + OptimizerOptions.setup(opts); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = super.processOptions(argOIs); + + LossFunction lossFunction = LossFunctions.getLossFunction(getDefaultLossType()); + if (cl.hasOption("loss_function")) { + try { + lossFunction = LossFunctions.getLossFunction(cl.getOptionValue("loss_function")); + } catch (Throwable e) { + throw new UDFArgumentException(e.getMessage()); + } + } + checkLossFunction(lossFunction); + this.lossFunction = lossFunction; + + OptimizerOptions.propcessOptions(cl, optimizerOptions); + + return cl; + } + + protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg) + throws UDFArgumentException { + this.featureListOI = (ListObjectInspector) arg; + ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector(); + HiveUtils.validateFeatureOI(featureRawOI); + this.parseFeature = HiveUtils.isStringOI(featureRawOI); + return HiveUtils.asPrimitiveObjectInspector(featureRawOI); + } + + protected StructObjectInspector getReturnOI(ObjectInspector featureOutputOI) { + ArrayList<String> fieldNames = new ArrayList<String>(); + ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + + fieldNames.add("feature"); + ObjectInspector featureOI = 
ObjectInspectorUtils.getStandardObjectInspector(featureOutputOI); + fieldOIs.add(featureOI); + fieldNames.add("weight"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + if (useCovariance()) { + fieldNames.add("covar"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + } + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @Override + public void process(Object[] args) throws HiveException { + if (is_mini_batch && accumulated == null) { + this.accumulated = new HashMap<Object, FloatAccumulator>(1024); + } + + List<?> features = (List<?>) featureListOI.getList(args[0]); + FeatureValue[] featureVector = parseFeatures(features); + if (featureVector == null) { + return; + } + float target = PrimitiveObjectInspectorUtils.getFloat(args[1], targetOI); + checkTargetValue(target); + + count++; + + train(featureVector, target); + } + + @Nullable + public final FeatureValue[] parseFeatures(@Nonnull final List<?> features) { + final int size = features.size(); + if (size == 0) { + return null; + } + + final ObjectInspector featureInspector = featureListOI.getListElementObjectInspector(); + final FeatureValue[] featureVector = new FeatureValue[size]; + for (int i = 0; i < size; i++) { + Object f = features.get(i); + if (f == null) { + continue; + } + final FeatureValue fv; + if (parseFeature) { + fv = FeatureValue.parse(f); + } else { + Object k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector); + fv = new FeatureValue(k, 1.f); + } + featureVector[i] = fv; + } + return featureVector; + } + + public float predict(@Nonnull final FeatureValue[] features) { + float score = 0.f; + for (FeatureValue f : features) {// a += w[i] * x[i] + if (f == null) { + continue; + } + final Object k = f.getFeature(); + final float v = f.getValueAsFloat(); + + float old_w = model.getWeight(k); + if (old_w != 0f) { + score += (old_w * v); + } + } + return score; + } + + protected 
void update(@Nonnull final FeatureValue[] features, final float target, + final float predicted) { + this.cumLoss += lossFunction.loss(predicted, target); // retain cumulative loss to check convergence + float dloss = lossFunction.dloss(predicted, target); + if (is_mini_batch) { + accumulateUpdate(features, dloss); + + if (sampled >= mini_batch_size) { + batchUpdate(); + } + } else { + onlineUpdate(features, dloss); + } + optimizer.proceedStep(); + } + + protected void accumulateUpdate(@Nonnull final FeatureValue[] features, final float dloss) { + for (FeatureValue f : features) { + if (f == null) { + continue; // parseFeatures() leaves null slots for null input elements; skip them as predict() does + } + Object feature = f.getFeature(); + float xi = f.getValueAsFloat(); + float weight = model.getWeight(feature); + + // compute the new weight, but do not write it to the model yet + float new_weight = optimizer.update(feature, weight, dloss * xi); + + // (w_i - eta * delta_1) + (w_i - eta * delta_2) + ... + (w_i - eta * delta_M) + FloatAccumulator acc = accumulated.get(feature); + if (acc == null) { + acc = new FloatAccumulator(new_weight); + accumulated.put(feature, acc); + } else { + acc.add(new_weight); + } + } + sampled++; + } + + protected void batchUpdate() { + if (accumulated.isEmpty()) { + this.sampled = 0; + return; + } + + for (Map.Entry<Object, FloatAccumulator> e : accumulated.entrySet()) { + Object feature = e.getKey(); + FloatAccumulator v = e.getValue(); + float new_weight = v.get(); // w_i - (eta / M) * (delta_1 + delta_2 + ... 
+ delta_M) + model.setWeight(feature, new_weight); + } + + accumulated.clear(); + this.sampled = 0; + } + + protected void onlineUpdate(@Nonnull final FeatureValue[] features, float dloss) { + for (FeatureValue f : features) { + if (f == null) { + continue; // parseFeatures() leaves null slots for null input elements; skip them as predict() does + } + Object feature = f.getFeature(); + float xi = f.getValueAsFloat(); + float weight = model.getWeight(feature); + float new_weight = optimizer.update(feature, weight, dloss * xi); + model.setWeight(feature, new_weight); + } + } + + @Override + public final void close() throws HiveException { + super.close(); + if (model != null) { + if (accumulated != null) { // Update model with accumulated delta + batchUpdate(); + this.accumulated = null; + } + int numForwarded = 0; + if (useCovariance()) { + final WeightValueWithCovar probe = new WeightValueWithCovar(); + final Object[] forwardMapObj = new Object[3]; + final FloatWritable fv = new FloatWritable(); + final FloatWritable cov = new FloatWritable(); + final IMapIterator<Object, IWeightValue> itor = model.entries(); + while (itor.next() != -1) { + itor.getValue(probe); + if (!probe.isTouched()) { + continue; // skip outputting untouched weights + } + Object k = itor.getKey(); + fv.set(probe.get()); + cov.set(probe.getCovariance()); + forwardMapObj[0] = k; + forwardMapObj[1] = fv; + forwardMapObj[2] = cov; + forward(forwardMapObj); + numForwarded++; + } + } else { + final WeightValue probe = new WeightValue(); + final Object[] forwardMapObj = new Object[2]; + final FloatWritable fv = new FloatWritable(); + final IMapIterator<Object, IWeightValue> itor = model.entries(); + while (itor.next() != -1) { + itor.getValue(probe); + if (!probe.isTouched()) { + continue; // skip outputting untouched weights + } + Object k = itor.getKey(); + fv.set(probe.get()); + forwardMapObj[0] = k; + forwardMapObj[1] = fv; + forward(forwardMapObj); + numForwarded++; + } + } + long numMixed = 
model.getNumMixed(); + this.model = null; + logger.info("Trained a prediction model using " + count + " training examples" +
(numMixed > 0 ? "( numMixed: " + numMixed + " )" : "")); + logger.info("Forwarded the prediction model of " + numForwarded + " rows"); + } + } + + @VisibleForTesting + public float getCumulativeLoss() { + return cumLoss; + } + + @VisibleForTesting + public void resetCumulativeLoss() { + this.cumLoss = 0.f; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/LearnerBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/LearnerBaseUDTF.java b/core/src/main/java/hivemall/LearnerBaseUDTF.java index 17c3ebc..bb15bb3 100644 --- a/core/src/main/java/hivemall/LearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/LearnerBaseUDTF.java @@ -22,23 +22,32 @@ import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveO import hivemall.mix.MixMessage.MixEventName; import hivemall.mix.client.MixClient; import hivemall.model.DenseModel; +import hivemall.model.NewDenseModel; +import hivemall.model.NewSpaceEfficientDenseModel; +import hivemall.model.NewSparseModel; import hivemall.model.PredictionModel; import hivemall.model.SpaceEfficientDenseModel; import hivemall.model.SparseModel; import hivemall.model.SynchronizedModelWrapper; import hivemall.model.WeightValue; import hivemall.model.WeightValue.WeightValueWithCovar; +import hivemall.optimizer.DenseOptimizerFactory; +import hivemall.optimizer.Optimizer; +import hivemall.optimizer.SparseOptimizerFactory; import hivemall.utils.datetime.StopWatch; import hivemall.utils.hadoop.HadoopUtils; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.IOUtils; +import hivemall.utils.lang.Preconditions; import hivemall.utils.lang.Primitives; import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.util.List; +import java.util.Map; +import javax.annotation.CheckForNull; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -61,6 +70,7 @@ 
import org.apache.hadoop.io.Text; public abstract class LearnerBaseUDTF extends UDTFWithOptions { private static final Log logger = LogFactory.getLog(LearnerBaseUDTF.class); + protected final boolean enableNewModel; protected String preloadedModelFile; protected boolean dense_model; protected int model_dims; @@ -73,9 +83,12 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { protected boolean mixCancel; protected boolean ssl; + @Nullable protected MixClient mixClient; - public LearnerBaseUDTF() {} + public LearnerBaseUDTF(boolean enableNewModel) { + this.enableNewModel = enableNewModel; + } protected boolean useCovariance() { return false; @@ -163,11 +176,15 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { @Nullable protected PredictionModel createModel() { - return createModel(null); + if (enableNewModel) { + return createNewModel(null); + } else { + return createOldModel(null); + } } @Nonnull - protected PredictionModel createModel(@Nullable String label) { + private final PredictionModel createOldModel(@Nullable String label) { PredictionModel model; final boolean useCovar = useCovariance(); if (dense_model) { @@ -197,6 +214,47 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { return model; } + @Nonnull + private final PredictionModel createNewModel(@Nullable String label) { + PredictionModel model; + final boolean useCovar = useCovariance(); + if (dense_model) { + if (disable_halffloat == false && model_dims > 16777216) { + logger.info("Build a space efficient dense model with " + model_dims + + " initial dimensions" + (useCovar ? " w/ covariances" : "")); + model = new NewSpaceEfficientDenseModel(model_dims, useCovar); + } else { + logger.info("Build a dense model with initial with " + model_dims + + " initial dimensions" + (useCovar ? 
" w/ covariances" : "")); + model = new NewDenseModel(model_dims, useCovar); + } + } else { + int initModelSize = getInitialModelSize(); + logger.info("Build a sparse model with initial with " + initModelSize + + " initial dimensions"); + model = new NewSparseModel(initModelSize, useCovar); + } + if (mixConnectInfo != null) { + model.configureClock(); + model = new SynchronizedModelWrapper(model); + MixClient client = configureMixClient(mixConnectInfo, label, model); + model.configureMix(client, mixCancel); + this.mixClient = client; + } + assert (model != null); + return model; + } + + @Nonnull + protected final Optimizer createOptimizer(@CheckForNull Map<String, String> options) { + Preconditions.checkNotNull(options); + if (dense_model) { + return DenseOptimizerFactory.create(model_dims, options); + } else { + return SparseOptimizerFactory.create(model_dims, options); + } + } + protected MixClient configureMixClient(String connectURIs, String label, PredictionModel model) { assert (connectURIs != null); assert (model != null); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/annotations/InternalAPI.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/annotations/InternalAPI.java b/core/src/main/java/hivemall/annotations/InternalAPI.java new file mode 100644 index 0000000..49e31aa --- /dev/null +++ b/core/src/main/java/hivemall/annotations/InternalAPI.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.annotations; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Marks methods and constructors that are internal APIs and may change in a future release. + * <p> + * Applicable only to methods and constructors. Uses {@link RetentionPolicy#CLASS} retention, so the marker is + * recorded in class files but is not required to be retained by the VM at runtime. + */ +@Retention(RetentionPolicy.CLASS) +@Target({ElementType.METHOD, ElementType.CONSTRUCTOR}) +@Documented +public @interface InternalAPI { +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/annotations/Since.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/annotations/Since.java b/core/src/main/java/hivemall/annotations/Since.java new file mode 100644 index 0000000..62c415e --- /dev/null +++ b/core/src/main/java/hivemall/annotations/Since.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.annotations; + +import java.lang.annotation.Documented; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +/** + * Records a release version string for the annotated element, e.g. {@code version = "0.5-rc.1"}; + * presumably the release in which the element was introduced (TODO confirm project convention). + * <p> + * Uses {@link RetentionPolicy#SOURCE} retention, so it is discarded by the compiler and is not present in + * class files or visible at runtime. + */ +@Documented +@Retention(RetentionPolicy.SOURCE) +public @interface Since { + /** The release version string, e.g. {@code "0.5-rc.1"}. */ + String version(); +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/anomaly/ChangeFinderUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/anomaly/ChangeFinderUDF.java b/core/src/main/java/hivemall/anomaly/ChangeFinderUDF.java index b2e1960..e786e86 100644 --- a/core/src/main/java/hivemall/anomaly/ChangeFinderUDF.java +++ b/core/src/main/java/hivemall/anomaly/ChangeFinderUDF.java @@ -19,6 +19,7 @@ package hivemall.anomaly; import hivemall.UDFWithOptions; +import hivemall.annotations.Since; import hivemall.utils.collections.DoubleRingBuffer; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.lang.Preconditions; @@ -51,6 +52,7 @@ import org.apache.hadoop.io.BooleanWritable; + " - Returns outlier/change-point scores and decisions using ChangeFinder."
+ " It will return a tuple <double outlier_score, double changepoint_score [, boolean is_anomaly [, boolean is_changepoint]]") @UDFType(deterministic = false, stateful = true) +@Since(version="0.5-rc.1") public final class ChangeFinderUDF extends UDFWithOptions { private transient Parameters _params; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java index 1fac3e7..7b1f9e3 100644 --- a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java +++ b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java @@ -19,6 +19,7 @@ package hivemall.anomaly; import hivemall.UDFWithOptions; +import hivemall.annotations.Since; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.lang.Preconditions; import hivemall.utils.lang.Primitives; @@ -58,6 +59,7 @@ import org.apache.hadoop.io.BooleanWritable; + " - Returns change-point scores and decisions using Singular Spectrum Transformation (SST)." 
+ " It will return a tuple <double changepoint_score [, boolean is_changepoint]>") @UDFType(deterministic = false, stateful = true) +@Since(version="0.5-rc.1") public final class SingularSpectrumTransformUDF extends UDFWithOptions { private transient Parameters _params; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java b/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java index 01c5554..959aaaa 100644 --- a/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java @@ -18,11 +18,11 @@ */ package hivemall.classifier; -import hivemall.common.LossFunctions; import hivemall.model.FeatureValue; import hivemall.model.IWeightValue; import hivemall.model.PredictionResult; import hivemall.model.WeightValue.WeightValueWithCovar; +import hivemall.optimizer.LossFunctions; import javax.annotation.Nonnull; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java b/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java index a3e77db..6adeaa9 100644 --- a/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java +++ b/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java @@ -18,10 +18,10 @@ */ package hivemall.classifier; -import hivemall.common.LossFunctions; import hivemall.model.FeatureValue; import hivemall.model.IWeightValue; import hivemall.model.WeightValue.WeightValueParamsF2; +import hivemall.optimizer.LossFunctions; import hivemall.utils.lang.Primitives; import javax.annotation.Nonnull; @@ -33,6 +33,10 @@ import 
org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +/** + * @deprecated Use {@link hivemall.classifier.GeneralClassifierUDTF} instead + */ +@Deprecated @Description(name = "train_adagrad_rda", value = "_FUNC_(list<string|int|bigint> features, int label [, const string options])" + " - Returns a relation consists of <string|int|bigint feature, float weight>", http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java index b0e2efd..d25f254 100644 --- a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java @@ -19,6 +19,7 @@ package hivemall.classifier; import hivemall.LearnerBaseUDTF; +import hivemall.annotations.VisibleForTesting; import hivemall.model.FeatureValue; import hivemall.model.IWeightValue; import hivemall.model.PredictionModel; @@ -27,9 +28,12 @@ import hivemall.model.WeightValue; import hivemall.model.WeightValue.WeightValueWithCovar; import hivemall.utils.collections.IMapIterator; import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.FloatAccumulator; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -58,6 +62,17 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { protected PredictionModel model; protected int count; + protected transient Map<Object, FloatAccumulator> accumulated; + protected int sampled; + + public BinaryOnlineClassifierUDTF() { + 
this(false); + } + + public BinaryOnlineClassifierUDTF(boolean enableNewModel) { + super(enableNewModel); + } + @Override public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { if (argOIs.length < 2) { @@ -78,6 +93,7 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { } this.count = 0; + this.sampled = 0; return getReturnOI(featureOutputOI); } @@ -109,6 +125,10 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { @Override public void process(Object[] args) throws HiveException { + if (is_mini_batch && accumulated == null) { + this.accumulated = new HashMap<Object, FloatAccumulator>(1024); + } + List<?> features = (List<?>) featureListOI.getList(args[0]); FeatureValue[] featureVector = parseFeatures(features); if (featureVector == null) { @@ -151,7 +171,7 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { assert (label == -1 || label == 0 || label == 1) : label; } - //@VisibleForTesting + @VisibleForTesting void train(List<?> features, int label) { FeatureValue[] featureVector = parseFeatures(features); train(featureVector, label); @@ -247,10 +267,26 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF { } } + protected void accumulateUpdate(@Nonnull final FeatureValue[] features, final float coeff) { + throw new UnsupportedOperationException(); + } + + protected void batchUpdate() { + throw new UnsupportedOperationException(); + } + + protected void onlineUpdate(@Nonnull final FeatureValue[] features, float coeff) { + throw new UnsupportedOperationException(); + } + @Override public void close() throws HiveException { super.close(); if (model != null) { + if (accumulated != null) { // Update model with accumulated delta + batchUpdate(); + this.accumulated = null; + } int numForwarded = 0; if (useCovariance()) { final WeightValueWithCovar probe = new WeightValueWithCovar(); 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java new file mode 100644 index 0000000..753a498 --- /dev/null +++ b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.classifier; + +import hivemall.GeneralLearnerBaseUDTF; +import hivemall.annotations.Since; +import hivemall.model.FeatureValue; +import hivemall.optimizer.LossFunctions.LossFunction; +import hivemall.optimizer.LossFunctions.LossType; + +import javax.annotation.Nonnull; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; + +/** + * A general classifier class that can select a loss function and an optimization function. 
+ */ +@Description(name = "train_classifier", + value = "_FUNC_(list<string|int|bigint> features, int label [, const string options])" + + " - Returns a relation consists of <string|int|bigint feature, float weight>", + extended = "Build a prediction model by a generic classifier") +@Since(version = "0.5-rc.1") +public final class GeneralClassifierUDTF extends GeneralLearnerBaseUDTF { + + @Override + protected String getLossOptionDescription() { + return "Loss function [default: HingeLoss, LogLoss, SquaredHingeLoss, ModifiedHuberLoss, " + + "SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, HuberLoss]"; + } + + @Override + protected LossType getDefaultLossType() { + return LossType.HingeLoss; + } + + @Override + protected void checkLossFunction(LossFunction lossFunction) throws UDFArgumentException {}; + + @Override + protected void checkTargetValue(float label) throws UDFArgumentException { + assert (label == -1.f || label == 0.f || label == 1.f) : label; + } + + @Override + protected void train(@Nonnull final FeatureValue[] features, final float label) { + float predicted = predict(features); + float y = label > 0.f ? 
1.f : -1.f; + update(features, y, predicted); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java b/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java index 8534231..6e6a2a0 100644 --- a/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java +++ b/core/src/main/java/hivemall/classifier/KernelExpansionPassiveAggressiveUDTF.java @@ -20,12 +20,12 @@ package hivemall.classifier; import hivemall.annotations.Experimental; import hivemall.annotations.VisibleForTesting; -import hivemall.common.LossFunctions; import hivemall.model.FeatureValue; import hivemall.model.PredictionModel; import hivemall.model.PredictionResult; import hivemall.utils.collections.maps.Int2FloatOpenHashTable; import hivemall.utils.collections.maps.Int2FloatOpenHashTable.IMapIterator; +import hivemall.optimizer.LossFunctions; import hivemall.utils.hashing.HashFunction; import hivemall.utils.lang.Preconditions; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java b/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java index 5ffda7b..e4146ce 100644 --- a/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java +++ b/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java @@ -18,9 +18,9 @@ */ package hivemall.classifier; -import hivemall.common.LossFunctions; import hivemall.model.FeatureValue; import hivemall.model.PredictionResult; +import hivemall.optimizer.LossFunctions; import javax.annotation.Nonnull; 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java index 8ae949f..af8545c 100644 --- a/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java @@ -77,6 +77,14 @@ public abstract class MulticlassOnlineClassifierUDTF extends LearnerBaseUDTF { protected Map<Object, PredictionModel> label2model; protected int count; + public MulticlassOnlineClassifierUDTF() { + this(false); + } + + public MulticlassOnlineClassifierUDTF(boolean enableNewModel) { + super(enableNewModel); + } + @Override public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { if (argOIs.length < 2) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/common/EtaEstimator.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/common/EtaEstimator.java b/core/src/main/java/hivemall/common/EtaEstimator.java deleted file mode 100644 index 0bfb9dc..0000000 --- a/core/src/main/java/hivemall/common/EtaEstimator.java +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.common; - -import hivemall.utils.lang.NumberUtils; -import hivemall.utils.lang.Primitives; - -import javax.annotation.Nonnegative; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; - -import org.apache.commons.cli.CommandLine; -import org.apache.hadoop.hive.ql.exec.UDFArgumentException; - -public abstract class EtaEstimator { - - protected final float eta0; - - public EtaEstimator(float eta0) { - this.eta0 = eta0; - } - - public float eta0() { - return eta0; - } - - public abstract float eta(long t); - - public void update(@Nonnegative float multipler) {} - - public static final class FixedEtaEstimator extends EtaEstimator { - - public FixedEtaEstimator(float eta) { - super(eta); - } - - @Override - public float eta(long t) { - return eta0; - } - - } - - public static final class SimpleEtaEstimator extends EtaEstimator { - - private final float finalEta; - private final double total_steps; - - public SimpleEtaEstimator(float eta0, long total_steps) { - super(eta0); - this.finalEta = (float) (eta0 / 2.d); - this.total_steps = total_steps; - } - - @Override - public float eta(final long t) { - if (t > total_steps) { - return finalEta; - } - return (float) (eta0 / (1.d + (t / total_steps))); - } - - } - - public static final class InvscalingEtaEstimator extends EtaEstimator { - - private final double power_t; - - public InvscalingEtaEstimator(float eta0, double power_t) { - super(eta0); - this.power_t = power_t; - } - - @Override - public float eta(final long t) { - return (float) (eta0 / Math.pow(t, 
power_t)); - } - - } - - /** - * bold driver: Gemulla et al., Large-scale matrix factorization with distributed stochastic - * gradient descent, KDD 2011. - */ - public static final class AdjustingEtaEstimator extends EtaEstimator { - - private float eta; - - public AdjustingEtaEstimator(float eta) { - super(eta); - this.eta = eta; - } - - @Override - public float eta(long t) { - return eta; - } - - @Override - public void update(@Nonnegative float multipler) { - float newEta = eta * multipler; - if (!NumberUtils.isFinite(newEta)) { - // avoid NaN or INFINITY - return; - } - this.eta = Math.min(eta0, newEta); // never be larger than eta0 - } - - } - - @Nonnull - public static EtaEstimator get(@Nullable CommandLine cl) throws UDFArgumentException { - return get(cl, 0.1f); - } - - @Nonnull - public static EtaEstimator get(@Nullable CommandLine cl, float defaultEta0) - throws UDFArgumentException { - if (cl == null) { - return new InvscalingEtaEstimator(defaultEta0, 0.1d); - } - - if (cl.hasOption("boldDriver")) { - float eta = Primitives.parseFloat(cl.getOptionValue("eta"), 0.3f); - return new AdjustingEtaEstimator(eta); - } - - String etaValue = cl.getOptionValue("eta"); - if (etaValue != null) { - float eta = Float.parseFloat(etaValue); - return new FixedEtaEstimator(eta); - } - - float eta0 = Primitives.parseFloat(cl.getOptionValue("eta0"), defaultEta0); - if (cl.hasOption("t")) { - long t = Long.parseLong(cl.getOptionValue("t")); - return new SimpleEtaEstimator(eta0, t); - } - - double power_t = Primitives.parseDouble(cl.getOptionValue("power_t"), 0.1d); - return new InvscalingEtaEstimator(eta0, power_t); - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/common/LossFunctions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/common/LossFunctions.java b/core/src/main/java/hivemall/common/LossFunctions.java deleted file mode 100644 index 
e1a0f31..0000000 --- a/core/src/main/java/hivemall/common/LossFunctions.java +++ /dev/null @@ -1,467 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.common; - -import hivemall.utils.math.MathUtils; - -/** - * @link https://github.com/JohnLangford/vowpal_wabbit/wiki/Loss-functions - */ -public final class LossFunctions { - - public enum LossType { - SquaredLoss, LogLoss, HingeLoss, SquaredHingeLoss, QuantileLoss, EpsilonInsensitiveLoss - } - - public static LossFunction getLossFunction(String type) { - if ("SquaredLoss".equalsIgnoreCase(type)) { - return new SquaredLoss(); - } else if ("LogLoss".equalsIgnoreCase(type)) { - return new LogLoss(); - } else if ("HingeLoss".equalsIgnoreCase(type)) { - return new HingeLoss(); - } else if ("SquaredHingeLoss".equalsIgnoreCase(type)) { - return new SquaredHingeLoss(); - } else if ("QuantileLoss".equalsIgnoreCase(type)) { - return new QuantileLoss(); - } else if ("EpsilonInsensitiveLoss".equalsIgnoreCase(type)) { - return new EpsilonInsensitiveLoss(); - } - throw new IllegalArgumentException("Unsupported type: " + type); - } - - public static LossFunction getLossFunction(LossType type) { - switch (type) { - case SquaredLoss: - return new 
SquaredLoss(); - case LogLoss: - return new LogLoss(); - case HingeLoss: - return new HingeLoss(); - case SquaredHingeLoss: - return new SquaredHingeLoss(); - case QuantileLoss: - return new QuantileLoss(); - case EpsilonInsensitiveLoss: - return new EpsilonInsensitiveLoss(); - default: - throw new IllegalArgumentException("Unsupported type: " + type); - } - } - - public interface LossFunction { - - /** - * Evaluate the loss function. - * - * @param p The prediction, p = w^T x - * @param y The true value (aka target) - * @return The loss evaluated at `p` and `y`. - */ - public float loss(float p, float y); - - public double loss(double p, double y); - - /** - * Evaluate the derivative of the loss function with respect to the prediction `p`. - * - * @param p The prediction, p = w^T x - * @param y The true value (aka target) - * @return The derivative of the loss function w.r.t. `p`. - */ - public float dloss(float p, float y); - - public boolean forBinaryClassification(); - - public boolean forRegression(); - - } - - public static abstract class BinaryLoss implements LossFunction { - - protected static void checkTarget(float y) { - if (!(y == 1.f || y == -1.f)) { - throw new IllegalArgumentException("target must be [+1,-1]: " + y); - } - } - - protected static void checkTarget(double y) { - if (!(y == 1.d || y == -1.d)) { - throw new IllegalArgumentException("target must be [+1,-1]: " + y); - } - } - - @Override - public boolean forBinaryClassification() { - return true; - } - - @Override - public boolean forRegression() { - return false; - } - } - - public static abstract class RegressionLoss implements LossFunction { - - @Override - public boolean forBinaryClassification() { - return false; - } - - @Override - public boolean forRegression() { - return true; - } - - } - - /** - * Squared loss for regression problems. - * - * If you're trying to minimize the mean error, use squared-loss. 
- */ - public static final class SquaredLoss extends RegressionLoss { - - @Override - public float loss(float p, float y) { - final float z = p - y; - return z * z * 0.5f; - } - - @Override - public double loss(double p, double y) { - final double z = p - y; - return z * z * 0.5d; - } - - @Override - public float dloss(float p, float y) { - return p - y; // 2 (p - y) / 2 - } - } - - /** - * Logistic regression loss for binary classification with y in {-1, 1}. - */ - public static final class LogLoss extends BinaryLoss { - - /** - * <code>logloss(p,y) = log(1+exp(-p*y))</code> - */ - @Override - public float loss(float p, float y) { - checkTarget(y); - - final float z = y * p; - if (z > 18.f) { - return (float) Math.exp(-z); - } - if (z < -18.f) { - return -z; - } - return (float) Math.log(1.d + Math.exp(-z)); - } - - @Override - public double loss(double p, double y) { - checkTarget(y); - - final double z = y * p; - if (z > 18.d) { - return Math.exp(-z); - } - if (z < -18.d) { - return -z; - } - return Math.log(1.d + Math.exp(-z)); - } - - @Override - public float dloss(float p, float y) { - checkTarget(y); - - float z = y * p; - if (z > 18.f) { - return (float) Math.exp(-z) * -y; - } - if (z < -18.f) { - return -y; - } - return -y / ((float) Math.exp(z) + 1.f); - } - } - - /** - * Hinge loss for binary classification tasks with y in {-1,1}. - */ - public static final class HingeLoss extends BinaryLoss { - - private float threshold; - - public HingeLoss() { - this(1.f); - } - - /** - * @param threshold Margin threshold. When threshold=1.0, one gets the loss used by SVM. - * When threshold=0.0, one gets the loss used by the Perceptron. - */ - public HingeLoss(float threshold) { - this.threshold = threshold; - } - - public void setThreshold(float threshold) { - this.threshold = threshold; - } - - @Override - public float loss(float p, float y) { - float loss = hingeLoss(p, y, threshold); - return (loss > 0.f) ? 
loss : 0.f; - } - - @Override - public double loss(double p, double y) { - double loss = hingeLoss(p, y, threshold); - return (loss > 0.d) ? loss : 0.d; - } - - @Override - public float dloss(float p, float y) { - float loss = hingeLoss(p, y, threshold); - return (loss > 0.f) ? -y : 0.f; - } - } - - /** - * Squared Hinge loss for binary classification tasks with y in {-1,1}. - */ - public static final class SquaredHingeLoss extends BinaryLoss { - - @Override - public float loss(float p, float y) { - return squaredHingeLoss(p, y); - } - - @Override - public double loss(double p, double y) { - return squaredHingeLoss(p, y); - } - - @Override - public float dloss(float p, float y) { - checkTarget(y); - - float d = 1 - (y * p); - return (d > 0.f) ? -2.f * d * y : 0.f; - } - - } - - /** - * Quantile loss is useful to predict rank/order and you do not mind the mean error to increase - * as long as you get the relative order correct. - * - * @link http://en.wikipedia.org/wiki/Quantile_regression - */ - public static final class QuantileLoss extends RegressionLoss { - - private float tau; - - public QuantileLoss() { - this.tau = 0.5f; - } - - public QuantileLoss(float tau) { - setTau(tau); - } - - public void setTau(float tau) { - if (tau <= 0 || tau >= 1.0) { - throw new IllegalArgumentException("tau must be in range (0, 1): " + tau); - } - this.tau = tau; - } - - @Override - public float loss(float p, float y) { - float e = y - p; - if (e > 0.f) { - return tau * e; - } else { - return -(1.f - tau) * e; - } - } - - @Override - public double loss(double p, double y) { - double e = y - p; - if (e > 0.d) { - return tau * e; - } else { - return -(1.d - tau) * e; - } - } - - @Override - public float dloss(float p, float y) { - float e = y - p; - if (e == 0.f) { - return 0.f; - } - return (e > 0.f) ? -tau : (1.f - tau); - } - - } - - /** - * Epsilon-Insensitive loss used by Support Vector Regression (SVR). 
- * <code>loss = max(0, |y - p| - epsilon)</code> - */ - public static final class EpsilonInsensitiveLoss extends RegressionLoss { - - private float epsilon; - - public EpsilonInsensitiveLoss() { - this(0.1f); - } - - public EpsilonInsensitiveLoss(float epsilon) { - this.epsilon = epsilon; - } - - public void setEpsilon(float epsilon) { - this.epsilon = epsilon; - } - - @Override - public float loss(float p, float y) { - float loss = Math.abs(y - p) - epsilon; - return (loss > 0.f) ? loss : 0.f; - } - - @Override - public double loss(double p, double y) { - double loss = Math.abs(y - p) - epsilon; - return (loss > 0.d) ? loss : 0.d; - } - - @Override - public float dloss(float p, float y) { - if ((y - p) > epsilon) {// real value > predicted value - epsilon - return -1.f; - } - if ((p - y) > epsilon) {// real value < predicted value - epsilon - return 1.f; - } - return 0.f; - } - - } - - public static float logisticLoss(final float target, final float predicted) { - if (predicted > -100.d) { - return target - (float) MathUtils.sigmoid(predicted); - } else { - return target; - } - } - - public static float logLoss(final float p, final float y) { - BinaryLoss.checkTarget(y); - - final float z = y * p; - if (z > 18.f) { - return (float) Math.exp(-z); - } - if (z < -18.f) { - return -z; - } - return (float) Math.log(1.d + Math.exp(-z)); - } - - public static double logLoss(final double p, final double y) { - BinaryLoss.checkTarget(y); - - final double z = y * p; - if (z > 18.d) { - return Math.exp(-z); - } - if (z < -18.d) { - return -z; - } - return Math.log(1.d + Math.exp(-z)); - } - - public static float squaredLoss(float p, float y) { - final float z = p - y; - return z * z * 0.5f; - } - - public static double squaredLoss(double p, double y) { - final double z = p - y; - return z * z * 0.5d; - } - - public static float hingeLoss(final float p, final float y, final float threshold) { - BinaryLoss.checkTarget(y); - - float z = y * p; - return threshold - z; - } - - 
public static double hingeLoss(final double p, final double y, final double threshold) { - BinaryLoss.checkTarget(y); - - double z = y * p; - return threshold - z; - } - - public static float hingeLoss(float p, float y) { - return hingeLoss(p, y, 1.f); - } - - public static double hingeLoss(double p, double y) { - return hingeLoss(p, y, 1.d); - } - - public static float squaredHingeLoss(final float p, final float y) { - BinaryLoss.checkTarget(y); - - float z = y * p; - float d = 1.f - z; - return (d > 0.f) ? (d * d) : 0.f; - } - - public static double squaredHingeLoss(final double p, final double y) { - BinaryLoss.checkTarget(y); - - double z = y * p; - double d = 1.d - z; - return (d > 0.d) ? d * d : 0.d; - } - - /** - * Math.abs(target - predicted) - epsilon - */ - public static float epsilonInsensitiveLoss(float predicted, float target, float epsilon) { - return Math.abs(target - predicted) - epsilon; - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/fm/FMHyperParameters.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FMHyperParameters.java b/core/src/main/java/hivemall/fm/FMHyperParameters.java index a1e4d25..accb99a 100644 --- a/core/src/main/java/hivemall/fm/FMHyperParameters.java +++ b/core/src/main/java/hivemall/fm/FMHyperParameters.java @@ -18,8 +18,8 @@ */ package hivemall.fm; -import hivemall.common.EtaEstimator; import hivemall.fm.FactorizationMachineModel.VInitScheme; +import hivemall.optimizer.EtaEstimator; import hivemall.utils.lang.Primitives; import javax.annotation.Nonnull; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/fm/FactorizationMachineModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java index 
a7fbb4e..eb26276 100644 --- a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java +++ b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java @@ -18,7 +18,7 @@ */ package hivemall.fm; -import hivemall.common.EtaEstimator; +import hivemall.optimizer.EtaEstimator; import hivemall.utils.lang.NumberUtils; import hivemall.utils.math.MathUtils; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java index 3fadc38..36af127 100644 --- a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java +++ b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java @@ -20,10 +20,10 @@ package hivemall.fm; import hivemall.UDTFWithOptions; import hivemall.common.ConversionState; -import hivemall.common.EtaEstimator; -import hivemall.common.LossFunctions; -import hivemall.common.LossFunctions.LossFunction; -import hivemall.common.LossFunctions.LossType; +import hivemall.optimizer.EtaEstimator; +import hivemall.optimizer.LossFunctions; +import hivemall.optimizer.LossFunctions.LossFunction; +import hivemall.optimizer.LossFunctions.LossType; import hivemall.fm.FMStringFeatureMapModel.Entry; import hivemall.utils.collections.IMapIterator; import hivemall.utils.hadoop.HiveUtils; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java index f28d3b0..56a1992 100644 --- a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java +++ b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java @@ 
-20,7 +20,7 @@ package hivemall.mf; import hivemall.UDTFWithOptions; import hivemall.common.ConversionState; -import hivemall.common.EtaEstimator; +import hivemall.optimizer.EtaEstimator; import hivemall.mf.FactorizedModel.RankInitScheme; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.FileUtils; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java b/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java index a95d01d..7edeaa0 100644 --- a/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java +++ b/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java @@ -18,7 +18,7 @@ */ package hivemall.mf; -import hivemall.common.EtaEstimator; +import hivemall.optimizer.EtaEstimator; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/model/AbstractPredictionModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/AbstractPredictionModel.java b/core/src/main/java/hivemall/model/AbstractPredictionModel.java index b48282b..95935d3 100644 --- a/core/src/main/java/hivemall/model/AbstractPredictionModel.java +++ b/core/src/main/java/hivemall/model/AbstractPredictionModel.java @@ -18,6 +18,7 @@ */ package hivemall.model; +import hivemall.annotations.InternalAPI; import hivemall.mix.MixedWeight; import hivemall.mix.MixedWeight.WeightWithCovar; import hivemall.mix.MixedWeight.WeightWithDelta; @@ -25,10 +26,12 @@ import hivemall.utils.collections.maps.IntOpenHashMap; import hivemall.utils.collections.maps.OpenHashMap; import javax.annotation.Nonnull; +import javax.annotation.Nullable; public abstract class 
AbstractPredictionModel implements PredictionModel { public static final byte BYTE0 = 0; + @Nullable protected ModelUpdateHandler handler; private long numMixed; @@ -50,7 +53,7 @@ public abstract class AbstractPredictionModel implements PredictionModel { } @Override - public void configureMix(ModelUpdateHandler handler, boolean cancelMixRequest) { + public void configureMix(@Nonnull ModelUpdateHandler handler, boolean cancelMixRequest) { this.handler = handler; this.cancelMixRequest = cancelMixRequest; if (cancelMixRequest) { @@ -184,9 +187,6 @@ public abstract class AbstractPredictionModel implements PredictionModel { } } - /** - * - */ @Override public void set(@Nonnull Object feature, float weight, float covar, short clock) { if (hasCovariance()) { @@ -197,8 +197,10 @@ public abstract class AbstractPredictionModel implements PredictionModel { numMixed++; } + @InternalAPI protected abstract void _set(@Nonnull Object feature, float weight, short clock); + @InternalAPI protected abstract void _set(@Nonnull Object feature, float weight, float covar, short clock); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/model/DenseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/DenseModel.java b/core/src/main/java/hivemall/model/DenseModel.java index 628b43e..db72070 100644 --- a/core/src/main/java/hivemall/model/DenseModel.java +++ b/core/src/main/java/hivemall/model/DenseModel.java @@ -147,7 +147,7 @@ public final class DenseModel extends AbstractPredictionModel { @SuppressWarnings("unchecked") @Override - public <T extends IWeightValue> T get(Object feature) { + public <T extends IWeightValue> T get(@Nonnull final Object feature) { final int i = HiveUtils.parseInt(feature); if (i >= size) { return null; @@ -170,7 +170,7 @@ public final class DenseModel extends AbstractPredictionModel { } @Override - public <T extends IWeightValue> void 
set(Object feature, T value) { + public <T extends IWeightValue> void set(@Nonnull final Object feature, @Nonnull final T value) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); float weight = value.get(); @@ -204,7 +204,7 @@ public final class DenseModel extends AbstractPredictionModel { } @Override - public void delete(@Nonnull Object feature) { + public void delete(@Nonnull final Object feature) { final int i = HiveUtils.parseInt(feature); if (i >= size) { return; @@ -226,7 +226,7 @@ public final class DenseModel extends AbstractPredictionModel { } @Override - public float getWeight(Object feature) { + public float getWeight(@Nonnull final Object feature) { int i = HiveUtils.parseInt(feature); if (i >= size) { return 0f; @@ -235,7 +235,12 @@ public final class DenseModel extends AbstractPredictionModel { } @Override - public float getCovariance(Object feature) { + public void setWeight(@Nonnull final Object feature, final float value) { + throw new UnsupportedOperationException(); + } + + @Override + public float getCovariance(@Nonnull final Object feature) { int i = HiveUtils.parseInt(feature); if (i >= size) { return 1f; @@ -244,7 +249,7 @@ public final class DenseModel extends AbstractPredictionModel { } @Override - protected void _set(Object feature, float weight, short clock) { + protected void _set(@Nonnull final Object feature, final float weight, final short clock) { int i = ((Integer) feature).intValue(); ensureCapacity(i); weights[i] = weight; @@ -253,7 +258,8 @@ public final class DenseModel extends AbstractPredictionModel { } @Override - protected void _set(Object feature, float weight, float covar, short clock) { + protected void _set(@Nonnull final Object feature, final float weight, final float covar, + final short clock) { int i = ((Integer) feature).intValue(); ensureCapacity(i); weights[i] = weight; @@ -268,7 +274,7 @@ public final class DenseModel extends AbstractPredictionModel { } @Override - public boolean contains(Object feature) 
{ + public boolean contains(@Nonnull final Object feature) { int i = HiveUtils.parseInt(feature); if (i >= size) { return false; @@ -329,7 +335,7 @@ public final class DenseModel extends AbstractPredictionModel { } @Override - public <T extends Copyable<IWeightValue>> void getValue(T probe) { + public <T extends Copyable<IWeightValue>> void getValue(@Nonnull final T probe) { float w = weights[cursor]; tmpWeight.value = w; float cov = 1.f; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/model/IWeightValue.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/IWeightValue.java b/core/src/main/java/hivemall/model/IWeightValue.java index cd25564..731c310 100644 --- a/core/src/main/java/hivemall/model/IWeightValue.java +++ b/core/src/main/java/hivemall/model/IWeightValue.java @@ -25,7 +25,7 @@ import javax.annotation.Nonnegative; public interface IWeightValue extends Copyable<IWeightValue> { public enum WeightValueType { - NoParams, ParamsF1, ParamsF2, ParamsCovar; + NoParams, ParamsF1, ParamsF2, ParamsF3, ParamsCovar; } WeightValueType getType(); @@ -44,10 +44,24 @@ public interface IWeightValue extends Copyable<IWeightValue> { float getSumOfSquaredGradients(); + void setSumOfSquaredGradients(float value); + float getSumOfSquaredDeltaX(); + void setSumOfSquaredDeltaX(float value); + float getSumOfGradients(); + void setSumOfGradients(float value); + + float getM(); + + void setM(float value); + + float getV(); + + void setV(float value); + /** * @return whether touched in training or not */ http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/model/NewDenseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/NewDenseModel.java b/core/src/main/java/hivemall/model/NewDenseModel.java new file mode 100644 index 0000000..aab3c2b --- 
/dev/null +++ b/core/src/main/java/hivemall/model/NewDenseModel.java @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.model; + +import hivemall.model.WeightValue.WeightValueWithCovar; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Copyable; +import hivemall.utils.math.MathUtils; + +import java.util.Arrays; + +import javax.annotation.Nonnull; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +public final class NewDenseModel extends AbstractPredictionModel { + private static final Log logger = LogFactory.getLog(NewDenseModel.class); + + private int size; + private float[] weights; + private float[] covars; + + // optional value for MIX + private short[] clocks; + private byte[] deltaUpdates; + + public NewDenseModel(int ndims) { + this(ndims, false); + } + + public NewDenseModel(int ndims, boolean withCovar) { + super(); + int size = ndims + 1; + this.size = size; + this.weights = new float[size]; + if (withCovar) { + float[] covars = new float[size]; + Arrays.fill(covars, 1f); + this.covars = covars; + } else { + this.covars = null; + } + this.clocks = null; + 
this.deltaUpdates = null; + } + + @Override + protected boolean isDenseModel() { + return true; + } + + @Override + public boolean hasCovariance() { + return covars != null; + } + + @Override + public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, + boolean sum_of_gradients) {} + + @Override + public void configureClock() { + if (clocks == null) { + this.clocks = new short[size]; + this.deltaUpdates = new byte[size]; + } + } + + @Override + public boolean hasClock() { + return clocks != null; + } + + @Override + public void resetDeltaUpdates(int feature) { + deltaUpdates[feature] = 0; + } + + private void ensureCapacity(final int index) { + if (index >= size) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + int oldSize = size; + logger.info("Expands internal array size from " + oldSize + " to " + newSize + " (" + + bits + " bits)"); + this.size = newSize; + this.weights = Arrays.copyOf(weights, newSize); + if (covars != null) { + this.covars = Arrays.copyOf(covars, newSize); + Arrays.fill(covars, oldSize, newSize, 1.f); + } + if (clocks != null) { + this.clocks = Arrays.copyOf(clocks, newSize); + this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize); + } + } + } + + @SuppressWarnings("unchecked") + @Override + public <T extends IWeightValue> T get(@Nonnull final Object feature) { + final int i = HiveUtils.parseInt(feature); + if (i >= size) { + return null; + } + if (covars != null) { + return (T) new WeightValueWithCovar(weights[i], covars[i]); + } else { + return (T) new WeightValue(weights[i]); + } + } + + @Override + public <T extends IWeightValue> void set(@Nonnull final Object feature, @Nonnull final T value) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + float weight = value.get(); + weights[i] = weight; + float covar = 1.f; + boolean hasCovar = value.hasCovariance(); + if (hasCovar) { + covar = value.getCovariance(); + covars[i] = covar; + } + short clock = 0; + int 
delta = 0; + if (clocks != null && value.isTouched()) { + clock = (short) (clocks[i] + 1); + clocks[i] = clock; + delta = deltaUpdates[i] + 1; + assert (delta > 0) : delta; + deltaUpdates[i] = (byte) delta; + } + + onUpdate(i, weight, covar, clock, delta, hasCovar); + } + + @Override + public void delete(@Nonnull final Object feature) { + final int i = HiveUtils.parseInt(feature); + if (i >= size) { + return; + } + weights[i] = 0.f; + if (covars != null) { + covars[i] = 1.f; + } + // avoid clock/delta + } + + @Override + public float getWeight(@Nonnull final Object feature) { + int i = HiveUtils.parseInt(feature); + if (i >= size) { + return 0f; + } + return weights[i]; + } + + @Override + public void setWeight(@Nonnull final Object feature, final float value) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weights[i] = value; + } + + @Override + public float getCovariance(@Nonnull final Object feature) { + int i = HiveUtils.parseInt(feature); + if (i >= size) { + return 1f; + } + return covars[i]; + } + + @Override + protected void _set(@Nonnull final Object feature, final float weight, final short clock) { + int i = ((Integer) feature).intValue(); + ensureCapacity(i); + weights[i] = weight; + clocks[i] = clock; + deltaUpdates[i] = 0; + } + + @Override + protected void _set(@Nonnull final Object feature, final float weight, final float covar, + final short clock) { + int i = ((Integer) feature).intValue(); + ensureCapacity(i); + weights[i] = weight; + covars[i] = covar; + clocks[i] = clock; + deltaUpdates[i] = 0; + } + + @Override + public int size() { + return size; + } + + @Override + public boolean contains(@Nonnull final Object feature) { + int i = HiveUtils.parseInt(feature); + if (i >= size) { + return false; + } + float w = weights[i]; + return w != 0.f; + } + + @SuppressWarnings("unchecked") + @Override + public <K, V extends IWeightValue> IMapIterator<K, V> entries() { + return (IMapIterator<K, V>) new Itr(); + } + + private final class 
Itr implements IMapIterator<Number, IWeightValue> { + + private int cursor; + private final WeightValueWithCovar tmpWeight; + + private Itr() { + this.cursor = -1; + this.tmpWeight = new WeightValueWithCovar(); + } + + @Override + public boolean hasNext() { + return cursor < size; + } + + @Override + public int next() { + ++cursor; + if (!hasNext()) { + return -1; + } + return cursor; + } + + @Override + public Integer getKey() { + return cursor; + } + + @Override + public IWeightValue getValue() { + if (covars == null) { + float w = weights[cursor]; + WeightValue v = new WeightValue(w); + v.setTouched(w != 0f); + return v; + } else { + float w = weights[cursor]; + float cov = covars[cursor]; + WeightValueWithCovar v = new WeightValueWithCovar(w, cov); + v.setTouched(w != 0.f || cov != 1.f); + return v; + } + } + + @Override + public <T extends Copyable<IWeightValue>> void getValue(@Nonnull final T probe) { + float w = weights[cursor]; + tmpWeight.value = w; + float cov = 1.f; + if (covars != null) { + cov = covars[cursor]; + tmpWeight.setCovariance(cov); + } + tmpWeight.setTouched(w != 0.f || cov != 1.f); + probe.copyFrom(tmpWeight); + } + + } + +}
