http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java index 76bead8..730d0f4 100644 --- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java +++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java @@ -22,9 +22,11 @@ import hivemall.fm.FMHyperParameters.FFMHyperParameters; import hivemall.utils.collections.arrays.DoubleArray3D; import hivemall.utils.collections.lists.IntArrayList; import hivemall.utils.lang.NumberUtils; +import hivemall.utils.math.MathUtils; import java.util.Arrays; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -34,19 +36,33 @@ public abstract class FieldAwareFactorizationMachineModel extends FactorizationM @Nonnull protected final FFMHyperParameters _params; - protected final float _eta0_V; + protected final float _eta0; protected final float _eps; protected final boolean _useAdaGrad; protected final boolean _useFTRL; + // FTRL (Follow-The-Regularized-Leader) hyper-parameters + private final float _alpha; + private final float _beta; + private final float _lambda1; + private final float _lambda2; + public FieldAwareFactorizationMachineModel(@Nonnull FFMHyperParameters params) { super(params); this._params = params; - this._eta0_V = params.eta0_V; + if (params.useAdaGrad) { + this._eta0 = 1.0f; + } else { + this._eta0 = params.eta.eta0(); + } this._eps = params.eps; this._useAdaGrad = params.useAdaGrad; this._useFTRL = params.useFTRL; + this._alpha = params.alphaFTRL; + this._beta = params.betaFTRL; + this._lambda1 = params.lambda1; + this._lambda2 = params.lamdda2; } public abstract float getV(@Nonnull Feature x, @Nonnull int yField, int f); @@ -100,31 +116,152 @@ public abstract
class FieldAwareFactorizationMachineModel extends FactorizationM return ret; } + void updateWi(final double dloss, @Nonnull final Feature x, final long t) { + if (_useFTRL) { + updateWi_FTRL(dloss, x); + return; + } + + final double Xi = x.getValue(); + float gradWi = (float) (dloss * Xi); + + final Entry theta = getEntryW(x); + float wi = theta.getW(); + + final float eta = eta(theta, t, gradWi); + float nextWi = wi - eta * (gradWi + 2.f * _lambdaW * wi); + if (!NumberUtils.isFinite(nextWi)) { + throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature() + + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss + + ", eta=" + eta + ", t=" + t); + } + if (MathUtils.closeToZero(nextWi, 1E-9f)) { + removeEntry(theta); + return; + } + theta.setW(nextWi); + } + + /** + * Update Wi using Follow-the-Regularized-Leader + */ + private void updateWi_FTRL(final double dloss, @Nonnull final Feature x) { + final double Xi = x.getValue(); + float gradWi = (float) (dloss * Xi); + + final Entry theta = getEntryW(x); + + final float z = theta.updateZ(gradWi, _alpha); + final double n = theta.updateN(gradWi); + + if (Math.abs(z) <= _lambda1) { + removeEntry(theta); + return; + } + + final float nextWi = (float) ((MathUtils.sign(z) * _lambda1 - z) / ((_beta + Math.sqrt(n)) + / _alpha + _lambda2)); + if (!NumberUtils.isFinite(nextWi)) { + throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature() + + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + theta.getW() + + ", dloss=" + dloss + ", n=" + n + ", z=" + z); + } + if (MathUtils.closeToZero(nextWi, 1E-9f)) { + removeEntry(theta); + return; + } + theta.setW(nextWi); + } + + protected abstract void removeEntry(@Nonnull final Entry entry); + void updateV(final double dloss, @Nonnull final Feature x, @Nonnull final int yField, final int f, final double sumViX, long t) { + if (_useFTRL) { + updateV_FTRL(dloss, x, yField, f, sumViX); + return; + } + + final Entry
theta = getEntryV(x, yField); + if (theta == null) { + return; + } + final double Xi = x.getValue(); final double h = Xi * sumViX; final float gradV = (float) (dloss * h); final float lambdaVf = getLambdaV(f); - final Entry theta = getEntry(x, yField); final float currentV = theta.getV(f); - final float eta = etaV(theta, t, gradV); + final float eta = eta(theta, f, t, gradV); final float nextV = currentV - eta * (gradV + 2.f * lambdaVf * currentV); if (!NumberUtils.isFinite(nextV)) { throw new IllegalStateException("Got " + nextV + " for next V" + f + '[' + x.getFeatureIndex() + "]\n" + "Xi=" + Xi + ", Vif=" + currentV + ", h=" + h + ", gradV=" + gradV + ", lambdaVf=" + lambdaVf + ", dloss=" + dloss - + ", sumViX=" + sumViX); + + ", sumViX=" + sumViX + ", t=" + t); + } + if (MathUtils.closeToZero(nextV, 1E-9f)) { + theta.setV(f, 0.f); + if (theta.removable()) { // Whether other factors are zero filled or not? Remove if zero filled + removeEntry(theta); + } + return; + } + theta.setV(f, nextV); + } + + private void updateV_FTRL(final double dloss, @Nonnull final Feature x, + @Nonnull final int yField, final int f, final double sumViX) { + final Entry theta = getEntryV(x, yField); + if (theta == null) { + return; + } + + final double Xi = x.getValue(); + final double h = Xi * sumViX; + final float gradV = (float) (dloss * h); + + float oldV = theta.getV(f); + final float z = theta.updateZ(f, oldV, gradV, _alpha); + final double n = theta.updateN(f, gradV); + + if (Math.abs(z) <= _lambda1) { + theta.setV(f, 0.f); + if (theta.removable()) { // Whether other factors are zero filled or not?
Remove if zero filled + removeEntry(theta); + } + return; + } + + final float nextV = (float) ((MathUtils.sign(z) * _lambda1 - z) / ((_beta + Math.sqrt(n)) + / _alpha + _lambda2)); + if (!NumberUtils.isFinite(nextV)) { + throw new IllegalStateException("Got " + nextV + " for next V" + f + '[' + + x.getFeatureIndex() + "]\n" + "Xi=" + Xi + ", Vif=" + theta.getV(f) + ", h=" + + h + ", gradV=" + gradV + ", dloss=" + dloss + ", sumViX=" + sumViX + ", n=" + + n + ", z=" + z); + } + if (MathUtils.closeToZero(nextV, 1E-9f)) { + theta.setV(f, 0.f); + if (theta.removable()) { // Whether other factors are zero filled or not? Remove if zero filled + removeEntry(theta); + } + return; } theta.setV(f, nextV); } - protected final float etaV(@Nonnull final Entry theta, final long t, final float grad) { + protected final float eta(@Nonnull final Entry theta, final long t, final float grad) { + return eta(theta, 0, t, grad); + } + + protected final float eta(@Nonnull final Entry theta, @Nonnegative final int f, final long t, + final float grad) { if (_useAdaGrad) { - double gg = theta.getSumOfSquaredGradientsV(); - theta.addGradientV(grad); - return (float) (_eta0_V / Math.sqrt(_eps + gg)); + double gg = theta.getSumOfSquaredGradients(f); + theta.addGradient(f, grad); + return (float) (_eta0 / Math.sqrt(_eps + gg)); } else { return _eta.eta(t); } @@ -187,10 +324,10 @@ public abstract class FieldAwareFactorizationMachineModel extends FactorizationM } @Nonnull - protected abstract Entry getEntry(@Nonnull Feature x); + protected abstract Entry getEntryW(@Nonnull Feature x); - @Nonnull - protected abstract Entry getEntry(@Nonnull Feature x, @Nonnull int yField); + @Nullable + protected abstract Entry getEntryV(@Nonnull Feature x, @Nonnull int yField); @Override protected final String varDump(@Nonnull final Feature[] x) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java index 67dbf87..56d9dc2 100644 --- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java +++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java @@ -18,17 +18,18 @@ */ package hivemall.fm; +import hivemall.fm.FFMStringFeatureMapModel.EntryIterator; import hivemall.fm.FMHyperParameters.FFMHyperParameters; import hivemall.utils.collections.arrays.DoubleArray3D; import hivemall.utils.collections.lists.IntArrayList; import hivemall.utils.hadoop.HadoopUtils; -import hivemall.utils.hadoop.Text3; -import hivemall.utils.lang.NumberUtils; +import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.math.MathUtils; -import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -44,6 +45,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; /** @@ -60,8 +63,6 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi // ---------------------------------------- // Learning hyper-parameters/options - private boolean _FTRL; - private boolean _globalBias; private boolean _linearCoeff; @@ -87,26 +88,25 @@ public final class
FieldAwareFactorizationMachineUDTF extends FactorizationMachi opts.addOption("disable_wi", "no_coeff", false, "Not to include linear term [default: OFF]"); // feature hashing opts.addOption("feature_hashing", true, - "The number of bits for feature hashing in range [18,31] [default:21]"); - opts.addOption("num_fields", true, "The number of fields [default:1024]"); + "The number of bits for feature hashing in range [18,31] [default: -1]. No feature hashing for -1."); + opts.addOption("num_fields", true, "The number of fields [default: 256]"); + // optimizer + opts.addOption("opt", "optimizer", true, + "Gradient Descent optimizer [default: ftrl, adagrad, sgd]"); // adagrad - opts.addOption("disable_adagrad", false, - "Whether to use AdaGrad for tuning learning rate [default: ON]"); - opts.addOption("eta0_V", true, "The initial learning rate for V [default 1.0]"); - opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1.0]"); + opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default: 1.0]"); // FTRL - opts.addOption("disable_ftrl", false, - "Whether not to use Follow-The-Regularized-Reader [default: OFF]"); opts.addOption("alpha", "alphaFTRL", true, - "Alpha value (learning rate) of Follow-The-Regularized-Reader [default 0.1]"); + "Alpha value (learning rate) of Follow-The-Regularized-Leader [default: 0.2]"); opts.addOption("beta", "betaFTRL", true, - "Beta value (a learning smoothing parameter) of Follow-The-Regularized-Reader [default 1.0]"); + "Beta value (a learning smoothing parameter) of Follow-The-Regularized-Leader [default: 1.0]"); opts.addOption( + "l1", "lambda1", true, - "L1 regularization value of Follow-The-Regularized-Reader that controls model Sparseness [default 0.1]"); - opts.addOption("lambda2", true, - "L2 regularization value of Follow-The-Regularized-Reader [default 0.01]"); + "L1 regularization value of Follow-The-Regularized-Leader that controls model Sparseness [default: 0.001]"); +
opts.addOption("l2", "lambda2", true, + "L2 regularization value of Follow-The-Regularized-Leader [default: 0.0001]"); return opts; } @@ -125,7 +125,6 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi CommandLine cl = super.processOptions(argOIs); FFMHyperParameters params = (FFMHyperParameters) _params; - this._FTRL = params.useFTRL; this._globalBias = params.globalBias; this._linearCoeff = params.linearCoeff; this._numFeatures = params.numFeatures; @@ -150,8 +149,14 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi fieldNames.add("model_id"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); - fieldNames.add("model"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + fieldNames.add("i"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + + fieldNames.add("Wi"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + fieldNames.add("Vi"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector)); return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } @@ -184,20 +189,19 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi @Override protected void trainTheta(@Nonnull final Feature[] x, final double y) throws HiveException { - final float eta_t = _etaEstimator.eta(_t); - final double p = _ffmModel.predict(x); final double lossGrad = _ffmModel.dloss(p, y); double loss = _lossFunction.loss(p, y); _cvState.incrLoss(loss); - if (MathUtils.closeToZero(lossGrad)) { + if (MathUtils.closeToZero(lossGrad, 1E-9d)) { return; } // w0 update if (_globalBias) { + float eta_t = _etaEstimator.eta(_t); _ffmModel.updateW0(lossGrad, eta_t); } @@ -210,14 +214,16 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi if (x_i.value == 0.f) {
continue; } - boolean useV = updateWi(lossGrad, x_i, eta_t); // wi update - if (useV == false) { - continue; + if (_linearCoeff) { + _ffmModel.updateWi(lossGrad, x_i, _t);// wi update } for (int fieldIndex = 0, size = fieldList.size(); fieldIndex < size; fieldIndex++) { final int yField = fieldList.get(fieldIndex); for (int f = 0, k = _factors; f < k; f++) { - double sumViX = sumVfX.get(i, fieldIndex, f); + final double sumViX = sumVfX.get(i, fieldIndex, f); + if (MathUtils.closeToZero(sumViX)) {// grad will be 0 => skip it + continue; + } _ffmModel.updateV(lossGrad, x_i, yField, f, sumViX, _t); } } @@ -229,18 +235,6 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi fieldList.clear(); } - private boolean updateWi(double lossGrad, @Nonnull Feature xi, float eta) { - if (!_linearCoeff) { - return true; - } - if (_FTRL) { - return _ffmModel.updateWiFTRL(lossGrad, xi, eta); - } else { - _ffmModel.updateWi(lossGrad, xi, eta); - return true; - } - } - @Nonnull private IntArrayList getFieldList(@Nonnull final Feature[] x) { for (Feature e : x) { @@ -257,7 +251,16 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi @Override public void close() throws HiveException { + if (LOG.isInfoEnabled()) { + LOG.info(_ffmModel.getStatistics()); + } + + _ffmModel.disableInitV(); // trick to avoid re-instantiating removed (zero-filled) entry of V super.close(); + + if (LOG.isInfoEnabled()) { + LOG.info(_ffmModel.getStatistics()); + } this._ffmModel = null; } @@ -267,39 +270,54 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi this._fieldList = null; this._sumVfX = null; - Text modelId = new Text(); - String taskId = HadoopUtils.getUniqueTaskIdString(); - modelId.set(taskId); - - FFMPredictionModel predModel = _ffmModel.toPredictionModel(); - this._ffmModel = null; // help GC - - if (LOG.isInfoEnabled()) { - LOG.info("Serializing a model '" + modelId + "'...
Configured # features: " - + _numFeatures + ", Configured # fields: " + _numFields - + ", Actual # features: " + predModel.getActualNumFeatures() - + ", Estimated uncompressed bytes: " - + NumberUtils.prettySize(predModel.approxBytesConsumed())); - } + final int factors = _factors; + final IntWritable idx = new IntWritable(); + final FloatWritable Wi = new FloatWritable(0.f); + final FloatWritable[] Vi = HiveUtils.newFloatArray(factors, 0.f); + final List<FloatWritable> ViObj = Arrays.asList(Vi); + + final Object[] forwardObjs = new Object[4]; + String modelId = HadoopUtils.getUniqueTaskIdString(); + forwardObjs[0] = new Text(modelId); + forwardObjs[1] = idx; + forwardObjs[2] = Wi; + forwardObjs[3] = null; // Vi + + // W0 + idx.set(0); + Wi.set(_ffmModel.getW0()); + forward(forwardObjs); - byte[] serialized; - try { - serialized = predModel.serialize(); - predModel = null; - } catch (IOException e) { - throw new HiveException("Failed to serialize a model", e); - } + final EntryIterator itor = _ffmModel.entries(); + final Entry entryW = itor.getEntryProbeW(); + final Entry entryV = itor.getEntryProbeV(); + final float[] Vf = new float[factors]; + while (itor.next()) { + // set i + int i = itor.getEntryIndex(); + idx.set(i); + + if (Entry.isEntryW(i)) {// set Wi + itor.getEntry(entryW); + float w = entryW.getW(); // read from the W probe populated just above + if (w == 0.f) { + continue; // skip w_i=0 + } + Wi.set(w); + forwardObjs[2] = Wi; + forwardObjs[3] = null; + } else {// set Vif + itor.getEntry(entryV); + entryV.getV(Vf); + for (int f = 0; f < factors; f++) { + Vi[f].set(Vf[f]); + } + forwardObjs[2] = null; + forwardObjs[3] = ViObj; + } - if (LOG.isInfoEnabled()) { - LOG.info("Forwarding a serialized/compressed model '" + modelId + "' of size: " - + NumberUtils.prettySize(serialized.length)); + forward(forwardObjs); } - - Text modelObj = new Text3(serialized); - serialized = null; - Object[] forwardObjs = new Object[] {modelId, modelObj}; - - forward(forwardObjs); } }
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/fm/IntFeature.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/fm/IntFeature.java b/core/src/main/java/hivemall/fm/IntFeature.java index 2052f7e..64a4daa 100644 --- a/core/src/main/java/hivemall/fm/IntFeature.java +++ b/core/src/main/java/hivemall/fm/IntFeature.java @@ -20,19 +20,21 @@ package hivemall.fm; import java.nio.ByteBuffer; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; public final class IntFeature extends Feature { + @Nonnegative private int index; /** -1 if not defined */ private short field; - public IntFeature(int index, double value) { + public IntFeature(@Nonnegative int index, double value) { this(index, (short) -1, value); } - public IntFeature(int index, short field, double value) { + public IntFeature(@Nonnegative int index, short field, double value) { super(value); this.field = field; this.index = index; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java index 6aebd64..3ec6ad7 100644 --- a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java +++ b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java @@ -19,15 +19,18 @@ package hivemall.ftvec.pairing; import hivemall.UDTFWithOptions; +import hivemall.fm.Feature; import hivemall.model.FeatureValue; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.hashing.HashFunction; import hivemall.utils.lang.Preconditions; +import hivemall.utils.lang.Primitives; import java.util.ArrayList; import java.util.List; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import 
org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; @@ -50,6 +53,8 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { private Type _type; private RowProcessor _proc; + private int _numFields; + private int _numFeatures; public FeaturePairsUDTF() {} @@ -57,9 +62,14 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { protected Options getOptions() { Options opts = new Options(); opts.addOption("kpa", false, - "Generate feature pairs for Kernel-Expansion Passive Aggressive [default:true]"); + "Generate feature pairs for Kernel-Expansion Passive Aggressive [default:false]"); opts.addOption("ffm", false, "Generate feature pairs for Field-aware Factorization Machines [default:false]"); + // feature hashing + opts.addOption("p", "num_features", true, "The size of feature dimensions [default: -1]"); + opts.addOption("feature_hashing", true, + "The number of bits for feature hashing in range [18,31]. [default: -1] No feature hashing for -1."); + opts.addOption("num_fields", true, "The number of fields [default:1024]"); return opts; } @@ -70,13 +80,30 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { String args = HiveUtils.getConstString(argOIs[1]); cl = parseOptions(args); - Preconditions.checkArgument(cl.getOptions().length == 1, UDFArgumentException.class, - "Only one option can be specified: " + cl.getArgList()); + Preconditions.checkArgument(cl.getOptions().length <= 3, UDFArgumentException.class, + "Too many options were specified: " + cl.getArgList()); if (cl.hasOption("kpa")) { this._type = Type.kpa; } else if (cl.hasOption("ffm")) { this._type = Type.ffm; + this._numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"), -1); + if (_numFeatures == -1) { + int featureBits = Primitives.parseInt(cl.getOptionValue("feature_hashing"), -1); + if (featureBits != -1) { + if (featureBits < 18 || featureBits > 31) { + throw new UDFArgumentException( + "-feature_hashing MUST be in range [18,31]: " + 
featureBits); + } + this._numFeatures = 1 << featureBits; + } + } + this._numFields = Primitives.parseInt(cl.getOptionValue("num_fields"), + Feature.DEFAULT_NUM_FIELDS); + if (_numFields <= 1) { + throw new UDFArgumentException("-num_fields MUST be greater than 1: " + + _numFields); + } } else { throw new UDFArgumentException("Unsupported option: " + cl.getArgList().get(0)); } @@ -113,8 +140,16 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { break; } case ffm: { - throw new UDFArgumentException("-ffm is not supported yet"); - //break; + this._proc = new FFMProcessor(fvOI); + fieldNames.add("i"); // <ei, jField> index + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("j"); // <ej, iField> index + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("xi"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("xj"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + break; } default: throw new UDFArgumentException("Illegal condition: " + _type); @@ -144,26 +179,7 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { this.fvOI = fvOI; } - void process(@Nonnull Object arg) throws HiveException { - final int size = fvOI.getListLength(arg); - if (size == 0) { - return; - } - - final List<FeatureValue> features = new ArrayList<FeatureValue>(size); - for (int i = 0; i < size; i++) { - Object f = fvOI.getListElement(arg, i); - if (f == null) { - continue; - } - FeatureValue fv = FeatureValue.parse(f, true); - features.add(fv); - } - - process(features); - } - - abstract void process(@Nonnull List<FeatureValue> features) throws HiveException; + abstract void process(@Nonnull Object arg) throws HiveException; } @@ -186,7 +202,22 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { } @Override - void process(@Nonnull List<FeatureValue> features) throws HiveException { + void 
process(@Nonnull Object arg) throws HiveException { + final int size = fvOI.getListLength(arg); + if (size == 0) { + return; + } + + final List<FeatureValue> features = new ArrayList<FeatureValue>(size); + for (int i = 0; i < size; i++) { + Object f = fvOI.getListElement(arg, i); + if (f == null) { + continue; + } + FeatureValue fv = FeatureValue.parse(f, true); + features.add(fv); + } + forward[0] = f0; f0.set(0); forward[1] = null; @@ -222,6 +253,78 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { } } + final class FFMProcessor extends RowProcessor { + + @Nonnull + private final IntWritable f0, f1; + @Nonnull + private final DoubleWritable f2, f3; + @Nonnull + private final Writable[] forward; + + @Nullable + private transient Feature[] _features; + + FFMProcessor(@Nonnull ListObjectInspector fvOI) { + super(fvOI); + this.f0 = new IntWritable(); + this.f1 = new IntWritable(); + this.f2 = new DoubleWritable(); + this.f3 = new DoubleWritable(); + this.forward = new Writable[] {f0, null, null, null}; + this._features = null; + } + + @Override + void process(@Nonnull Object arg) throws HiveException { + final int size = fvOI.getListLength(arg); + if (size == 0) { + return; + } + + this._features = Feature.parseFFMFeatures(arg, fvOI, _features, _numFeatures, + _numFields); + + // W0 + f0.set(0); + forward[1] = null; + forward[2] = null; + forward[3] = null; + forward(forward); + + forward[2] = f2; + final Feature[] features = _features; + for (int i = 0, len = features.length; i < len; i++) { + Feature ei = features[i]; + + // Wi + f0.set(Feature.toIntFeature(ei)); + forward[1] = null; + f2.set(ei.getValue()); + forward[3] = null; + forward(forward); + + forward[1] = f1; + forward[3] = f3; + final int iField = ei.getField(); + for (int j = i + 1; j < len; j++) { + Feature ej = features[j]; + double xj = ej.getValue(); + int jField = ej.getField(); + + int ifj = Feature.toIntFeature(ei, jField, _numFields); + int jfi = Feature.toIntFeature(ej, iField, 
_numFields); + + // Vifj, Vjfi + f0.set(ifj); + f1.set(jfi); + // `f2` is consistently set to `xi` + f3.set(xj); + forward(forward); + } + } + } + } @Override public void close() throws HiveException { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java b/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java index 5e9f797..cdba00b 100644 --- a/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java +++ b/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java @@ -19,8 +19,8 @@ package hivemall.ftvec.ranking; import hivemall.utils.collections.lists.IntArrayList; -import hivemall.utils.collections.maps.IntOpenHashMap; -import hivemall.utils.collections.maps.IntOpenHashMap.IMapIterator; +import hivemall.utils.collections.maps.IntOpenHashTable; +import hivemall.utils.collections.maps.IntOpenHashTable.IMapIterator; import java.util.BitSet; @@ -30,13 +30,13 @@ import javax.annotation.Nullable; public class PositiveOnlyFeedback { @Nonnull - protected final IntOpenHashMap<IntArrayList> rows; + protected final IntOpenHashTable<IntArrayList> rows; protected int maxItemId; protected int totalFeedbacks; public PositiveOnlyFeedback(int maxItemId) { - this.rows = new IntOpenHashMap<IntArrayList>(1024); + this.rows = new IntOpenHashTable<IntArrayList>(1024); this.maxItemId = maxItemId; this.totalFeedbacks = 0; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java b/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java new file mode 100644 index 0000000..53b998c --- 
/dev/null +++ b/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.ftvec.trans; + +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Preconditions; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import javax.annotation.Nonnull; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +@Description(name = "add_field_indicies", value = "_FUNC_(array<string> features) " + + "- Returns arrays of string that field indicies (<field>:<feature>)* are argumented") +@UDFType(deterministic = true, stateful = false) +public final class 
AddFieldIndicesUDF extends GenericUDF { + + private ListObjectInspector listOI; + + @Override + public ObjectInspector initialize(@Nonnull ObjectInspector[] argOIs) + throws UDFArgumentException { + if (argOIs.length != 1) { + throw new UDFArgumentException("Expected a single argument: " + argOIs.length); + } + + this.listOI = HiveUtils.asListOI(argOIs[0]); + if (!HiveUtils.isStringOI(listOI.getListElementObjectInspector())) { + throw new UDFArgumentException("Expected array<string> but got " + argOIs[0]); + } + + return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector); + } + + @Override + public List<String> evaluate(@Nonnull DeferredObject[] args) throws HiveException { + Preconditions.checkArgument(args.length == 1); + + final String[] features = HiveUtils.asStringArray(args[0], listOI); + if (features == null) { + return null; + } + + final List<String> argumented = new ArrayList<>(features.length); + for (int i = 0; i < features.length; i++) { + final String f = features[i]; + if (f == null) { + continue; + } + argumented.add((i + 1) + ":" + f); + } + + return argumented; + } + + @Override + public String getDisplayString(String[] args) { + return "add_field_indicies( " + Arrays.toString(args) + " )"; + } + + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/trans/CategoricalFeaturesUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/trans/CategoricalFeaturesUDF.java b/core/src/main/java/hivemall/ftvec/trans/CategoricalFeaturesUDF.java index 98617bd..4722efd 100644 --- a/core/src/main/java/hivemall/ftvec/trans/CategoricalFeaturesUDF.java +++ b/core/src/main/java/hivemall/ftvec/trans/CategoricalFeaturesUDF.java @@ -18,6 +18,7 @@ */ package hivemall.ftvec.trans; +import hivemall.UDFWithOptions; import hivemall.utils.hadoop.HiveUtils; import java.util.ArrayList; @@ 
-26,26 +27,55 @@ import java.util.List; import javax.annotation.Nonnull; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; -import org.apache.hadoop.io.Text; -@Description(name = "categorical_features", - value = "_FUNC_(array<string> featureNames, ...) - Returns a feature vector array<string>") +@Description( + name = "categorical_features", + value = "_FUNC_(array<string> featureNames, feature1, feature2, .. 
[, const string options])" + + " - Returns a feature vector array<string>") @UDFType(deterministic = true, stateful = false) -public final class CategoricalFeaturesUDF extends GenericUDF { +public final class CategoricalFeaturesUDF extends UDFWithOptions { - private String[] featureNames; - private PrimitiveObjectInspector[] inputOIs; - private List<Text> result; + private String[] _featureNames; + private PrimitiveObjectInspector[] _inputOIs; + private List<String> _result; + + private boolean _emitNull = false; + private boolean _forceValue = false; + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("no_elim", "no_elimination", false, + "Wheather to emit NULL and value [default: false]"); + opts.addOption("emit_null", false, "Wheather to emit NULL [default: false]"); + opts.addOption("force_value", false, "Wheather to force emit value [default: false]"); + return opts; + } + + @Override + protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException { + CommandLine cl = parseOptions(optionValue); + if (cl.hasOption("no_elim")) { + this._emitNull = true; + this._forceValue = true; + } else { + this._emitNull = cl.hasOption("emit_null"); + this._forceValue = cl.hasOption("force_value"); + } + return cl; + } @Override public ObjectInspector initialize(@Nonnull final ObjectInspector[] argOIs) @@ -55,54 +85,91 @@ public final class CategoricalFeaturesUDF extends GenericUDF { throw new UDFArgumentException("argOIs.length must be greater that or equals to 2: " + numArgOIs); } - this.featureNames = HiveUtils.getConstStringArray(argOIs[0]); - if (featureNames == null) { + + this._featureNames = HiveUtils.getConstStringArray(argOIs[0]); + if (_featureNames == null) { throw new UDFArgumentException("#featureNames should not be null"); } - int numFeatureNames = featureNames.length; + int numFeatureNames = _featureNames.length; if (numFeatureNames < 1) { throw new 
UDFArgumentException("#featureNames must be greater than or equals to 1: " + numFeatureNames); } - int numFeatures = numArgOIs - 1; + for (String featureName : _featureNames) { + if (featureName == null) { + throw new UDFArgumentException("featureName should not be null: " + + Arrays.toString(_featureNames)); + } else if (featureName.indexOf(':') != -1) { + throw new UDFArgumentException("featureName should not include colon: " + + featureName); + } + } + + final int numFeatures; + final int lastArgIndex = numArgOIs - 1; + if (lastArgIndex > numFeatureNames) { + if (lastArgIndex == (numFeatureNames + 1) + && HiveUtils.isConstString(argOIs[lastArgIndex])) { + String optionValue = HiveUtils.getConstString(argOIs[lastArgIndex]); + processOptions(optionValue); + numFeatures = numArgOIs - 2; + } else { + throw new UDFArgumentException( + "Unexpected arguments for _FUNC_" + + "(const array<string> featureNames, feature1, feature2, .. [, const string options])"); + } + } else { + numFeatures = lastArgIndex; + } if (numFeatureNames != numFeatures) { - throw new UDFArgumentException("#featureNames '" + numFeatureNames - + "' != #arguments '" + numFeatures + "'"); + throw new UDFArgumentLengthException("#featureNames '" + numFeatureNames + + "' != #features '" + numFeatures + "'"); } - this.inputOIs = new PrimitiveObjectInspector[numFeatures]; + this._inputOIs = new PrimitiveObjectInspector[numFeatures]; for (int i = 0; i < numFeatures; i++) { ObjectInspector oi = argOIs[i + 1]; - inputOIs[i] = HiveUtils.asPrimitiveObjectInspector(oi); + _inputOIs[i] = HiveUtils.asPrimitiveObjectInspector(oi); } - this.result = new ArrayList<Text>(numFeatures); + this._result = new ArrayList<String>(numFeatures); - return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector); } @Override - public List<Text> 
evaluate(@Nonnull final DeferredObject[] arguments) throws HiveException { - result.clear(); + public List<String> evaluate(@Nonnull final DeferredObject[] arguments) throws HiveException { + _result.clear(); - final int size = arguments.length - 1; + final int size = _featureNames.length; for (int i = 0; i < size; i++) { Object argument = arguments[i + 1].get(); if (argument == null) { + if (_emitNull) { + _result.add(null); + } continue; } - PrimitiveObjectInspector oi = inputOIs[i]; + PrimitiveObjectInspector oi = _inputOIs[i]; String s = PrimitiveObjectInspectorUtils.getString(argument, oi); if (s.isEmpty()) { + if (_emitNull) { + _result.add(null); + } continue; } - // categorical feature representation - String featureName = featureNames[i]; - Text f = new Text(featureName + '#' + s); - result.add(f); + // categorical feature representation + final String f; + if (_forceValue) { + f = _featureNames[i] + '#' + s + ":1"; + } else { + f = _featureNames[i] + '#' + s; + } + _result.add(f); + } - return result; + return _result; } @Override http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java b/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java index c98ffda..eead738 100644 --- a/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java +++ b/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java @@ -23,6 +23,7 @@ import hivemall.fm.Feature; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.hashing.MurmurHash3; import hivemall.utils.lang.Primitives; +import hivemall.utils.lang.StringUtils; import java.util.ArrayList; import java.util.Arrays; @@ -59,6 +60,7 @@ public final class FFMFeaturesUDF extends UDFWithOptions { private boolean _mhash = true; private int _numFeatures = Feature.DEFAULT_NUM_FEATURES; private int 
_numFields = Feature.DEFAULT_NUM_FIELDS; + private boolean _emitIndicies = false; @Override protected Options getOptions() { @@ -66,9 +68,11 @@ public final class FFMFeaturesUDF extends UDFWithOptions { opts.addOption("no_hash", "disable_feature_hashing", false, "Wheather to disable feature hashing [default: false]"); // feature hashing + opts.addOption("p", "num_features", true, "The size of feature dimensions [default: -1]"); opts.addOption("hash", "feature_hashing", true, "The number of bits for feature hashing in range [18,31] [default:21]"); opts.addOption("fields", "num_fields", true, "The number of fields [default:1024]"); + opts.addOption("emit_indicies", false, "Emit indicies for fields [default: false]"); return opts; } @@ -77,19 +81,27 @@ public final class FFMFeaturesUDF extends UDFWithOptions { CommandLine cl = parseOptions(optionValue); // feature hashing - int hashbits = Primitives.parseInt(cl.getOptionValue("feature_hashing"), - Feature.DEFAULT_FEATURE_BITS); - if (hashbits < 18 || hashbits > 31) { - throw new UDFArgumentException("-feature_hashing MUST be in range [18,31]: " + hashbits); + int numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"), -1); + if (numFeatures == -1) { + int hashbits = Primitives.parseInt(cl.getOptionValue("feature_hashing"), + Feature.DEFAULT_FEATURE_BITS); + if (hashbits < 18 || hashbits > 31) { + throw new UDFArgumentException("-feature_hashing MUST be in range [18,31]: " + + hashbits); + } + numFeatures = 1 << hashbits; } - int numFeatures = 1 << hashbits; + this._numFeatures = numFeatures; + int numFields = Primitives.parseInt(cl.getOptionValue("num_fields"), Feature.DEFAULT_NUM_FIELDS); if (numFields <= 1) { throw new UDFArgumentException("-num_fields MUST be greater than 1: " + numFields); } - this._numFeatures = numFeatures; this._numFields = numFields; + + this._emitIndicies = cl.hasOption("emit_indicies"); + return cl; } @@ -111,7 +123,10 @@ public final class FFMFeaturesUDF extends UDFWithOptions 
{ + numFeatureNames); } for (String featureName : _featureNames) { - if (featureName.indexOf(':') != -1) { + if (featureName == null) { + throw new UDFArgumentException("featureName should not be null: " + + Arrays.toString(_featureNames)); + } else if (featureName.indexOf(':') != -1) { throw new UDFArgumentException("featureName should not include colon: " + featureName); } @@ -174,18 +189,20 @@ public final class FFMFeaturesUDF extends UDFWithOptions { // categorical feature representation final String fv; if (_mhash) { - int field = MurmurHash3.murmurhash3(_featureNames[i], _numFields); + int field = _emitIndicies ? i : MurmurHash3.murmurhash3(_featureNames[i], + _numFields); // +NUM_FIELD to avoid conflict to quantitative features int index = MurmurHash3.murmurhash3(feature, _numFeatures) + _numFields; fv = builder.append(field).append(':').append(index).append(":1").toString(); - builder.setLength(0); + StringUtils.clear(builder); } else { - fv = builder.append(featureName) - .append(':') - .append(feature) - .append(":1") - .toString(); - builder.setLength(0); + if (_emitIndicies) { + builder.append(i); + } else { + builder.append(featureName); + } + fv = builder.append(':').append(feature).append(":1").toString(); + StringUtils.clear(builder); } _result.add(new Text(fv)); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/trans/QuantifiedFeaturesUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/trans/QuantifiedFeaturesUDTF.java b/core/src/main/java/hivemall/ftvec/trans/QuantifiedFeaturesUDTF.java index 2886996..846be97 100644 --- a/core/src/main/java/hivemall/ftvec/trans/QuantifiedFeaturesUDTF.java +++ b/core/src/main/java/hivemall/ftvec/trans/QuantifiedFeaturesUDTF.java @@ -23,6 +23,7 @@ import hivemall.utils.lang.Identifier; import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import 
org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; @@ -39,7 +40,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn @Description( name = "quantified_features", - value = "_FUNC_(boolean output, col1, col2, ...) - Returns an identified features in a dence array<double>") + value = "_FUNC_(boolean output, col1, col2, ...) - Returns an identified features in a dense array<double>") public final class QuantifiedFeaturesUDTF extends GenericUDTF { private BooleanObjectInspector boolOI; @@ -76,8 +77,8 @@ public final class QuantifiedFeaturesUDTF extends GenericUDTF { } } - ArrayList<String> fieldNames = new ArrayList<String>(outputSize); - ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(outputSize); + List<String> fieldNames = new ArrayList<String>(outputSize); + List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(outputSize); fieldNames.add("features"); fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/trans/QuantitativeFeaturesUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/trans/QuantitativeFeaturesUDF.java b/core/src/main/java/hivemall/ftvec/trans/QuantitativeFeaturesUDF.java index 43f837f..38e35e2 100644 --- a/core/src/main/java/hivemall/ftvec/trans/QuantitativeFeaturesUDF.java +++ b/core/src/main/java/hivemall/ftvec/trans/QuantitativeFeaturesUDF.java @@ -18,6 +18,7 @@ */ package hivemall.ftvec.trans; +import hivemall.UDFWithOptions; import hivemall.utils.hadoop.HiveUtils; import java.util.ArrayList; @@ -26,11 +27,13 @@ import java.util.List; import javax.annotation.Nonnull; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; import 
org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; @@ -39,14 +42,32 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.io.Text; -@Description(name = "quantitative_features", - value = "_FUNC_(array<string> featureNames, ...) - Returns a feature vector array<string>") +@Description( + name = "quantitative_features", + value = "_FUNC_(array<string> featureNames, feature1, feature2, .. 
[, const string options])" + + " - Returns a feature vector array<string>") @UDFType(deterministic = true, stateful = false) -public final class QuantitativeFeaturesUDF extends GenericUDF { +public final class QuantitativeFeaturesUDF extends UDFWithOptions { - private String[] featureNames; - private PrimitiveObjectInspector[] inputOIs; - private List<Text> result; + private String[] _featureNames; + private PrimitiveObjectInspector[] _inputOIs; + private List<Text> _result; + + private boolean _emitNull = false; + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("emit_null", false, "Wheather to emit NULL [default: false]"); + return opts; + } + + @Override + protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException { + CommandLine cl = parseOptions(optionValue); + this._emitNull = cl.hasOption("emit_null"); + return cl; + } @Override public ObjectInspector initialize(@Nonnull final ObjectInspector[] argOIs) @@ -56,58 +77,92 @@ public final class QuantitativeFeaturesUDF extends GenericUDF { throw new UDFArgumentException("argOIs.length must be greater that or equals to 2: " + numArgOIs); } - this.featureNames = HiveUtils.getConstStringArray(argOIs[0]); - if (featureNames == null) { + + this._featureNames = HiveUtils.getConstStringArray(argOIs[0]); + if (_featureNames == null) { throw new UDFArgumentException("#featureNames should not be null"); } - int numFeatureNames = featureNames.length; + int numFeatureNames = _featureNames.length; if (numFeatureNames < 1) { throw new UDFArgumentException("#featureNames must be greater than or equals to 1: " + numFeatureNames); } - int numFeatures = numArgOIs - 1; + for (String featureName : _featureNames) { + if (featureName == null) { + throw new UDFArgumentException("featureName should not be null: " + + Arrays.toString(_featureNames)); + } else if (featureName.indexOf(':') != -1) { + throw new UDFArgumentException("featureName should not 
include colon: " + + featureName); + } + } + + final int numFeatures; + final int lastArgIndex = numArgOIs - 1; + if (lastArgIndex > numFeatureNames) { + if (lastArgIndex == (numFeatureNames + 1) + && HiveUtils.isConstString(argOIs[lastArgIndex])) { + String optionValue = HiveUtils.getConstString(argOIs[lastArgIndex]); + processOptions(optionValue); + numFeatures = numArgOIs - 2; + } else { + throw new UDFArgumentException( + "Unexpected arguments for _FUNC_" + + "(const array<string> featureNames, feature1, feature2, .. [, const string options])"); + } + } else { + numFeatures = lastArgIndex; + } if (numFeatureNames != numFeatures) { - throw new UDFArgumentException("#featureNames '" + numFeatureNames - + "' != #arguments '" + numFeatures + "'"); + throw new UDFArgumentLengthException("#featureNames '" + numFeatureNames + + "' != #features '" + numFeatures + "'"); } - this.inputOIs = new PrimitiveObjectInspector[numFeatures]; + this._inputOIs = new PrimitiveObjectInspector[numFeatures]; for (int i = 0; i < numFeatures; i++) { ObjectInspector oi = argOIs[i + 1]; - inputOIs[i] = HiveUtils.asDoubleCompatibleOI(oi); + _inputOIs[i] = HiveUtils.asDoubleCompatibleOI(oi); } - this.result = new ArrayList<Text>(numFeatures); + this._result = new ArrayList<Text>(numFeatures); return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector); } @Override public List<Text> evaluate(@Nonnull final DeferredObject[] arguments) throws HiveException { - result.clear(); + _result.clear(); - final int size = arguments.length - 1; + final int size = _featureNames.length; for (int i = 0; i < size; i++) { Object argument = arguments[i + 1].get(); if (argument == null) { + if (_emitNull) { + _result.add(null); + } continue; } - PrimitiveObjectInspector oi = inputOIs[i]; + PrimitiveObjectInspector oi = _inputOIs[i]; if (oi.getPrimitiveCategory() == PrimitiveCategory.STRING) { String s = argument.toString(); if (s.isEmpty()) { + if 
(_emitNull) { + _result.add(null); + } continue; } } final double v = PrimitiveObjectInspectorUtils.getDouble(argument, oi); if (v != 0.d) { - String featureName = featureNames[i]; - Text f = new Text(featureName + ':' + v); - result.add(f); + Text f = new Text(_featureNames[i] + ':' + v); + _result.add(f); + } else if (_emitNull) { + Text f = new Text(_featureNames[i] + ":0"); + _result.add(f); } } - return result; + return _result; } @Override http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/ftvec/trans/VectorizeFeaturesUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/trans/VectorizeFeaturesUDF.java b/core/src/main/java/hivemall/ftvec/trans/VectorizeFeaturesUDF.java index 48bf126..f2ecbb6 100644 --- a/core/src/main/java/hivemall/ftvec/trans/VectorizeFeaturesUDF.java +++ b/core/src/main/java/hivemall/ftvec/trans/VectorizeFeaturesUDF.java @@ -18,6 +18,7 @@ */ package hivemall.ftvec.trans; +import hivemall.UDFWithOptions; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.lang.StringUtils; @@ -27,11 +28,13 @@ import java.util.List; import javax.annotation.Nonnull; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; @@ -40,14 +43,32 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.io.Text; -@Description(name = "vectorize_features", - value = "_FUNC_(array<string> featureNames, ...) - Returns a feature vector array<string>") +@Description( + name = "vectorize_features", + value = "_FUNC_(array<string> featureNames, feature1, feature2, .. [, const string options])" + + " - Returns a feature vector array<string>") @UDFType(deterministic = true, stateful = false) -public final class VectorizeFeaturesUDF extends GenericUDF { +public final class VectorizeFeaturesUDF extends UDFWithOptions { - private String[] featureNames; - private PrimitiveObjectInspector[] inputOIs; - private List<Text> result; + private String[] _featureNames; + private PrimitiveObjectInspector[] _inputOIs; + private List<Text> _result; + + private boolean _emitNull = false; + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("emit_null", false, "Wheather to emit NULL [default: false]"); + return opts; + } + + @Override + protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException { + CommandLine cl = parseOptions(optionValue); + this._emitNull = cl.hasOption("emit_null"); + return cl; + } @Override public ObjectInspector initialize(@Nonnull final ObjectInspector[] argOIs) @@ -57,63 +78,96 @@ public final class VectorizeFeaturesUDF extends GenericUDF { throw new UDFArgumentException("argOIs.length must be greater that or equals to 2: " + numArgOIs); } - this.featureNames = HiveUtils.getConstStringArray(argOIs[0]); - if (featureNames == null) { + + this._featureNames = HiveUtils.getConstStringArray(argOIs[0]); + if (_featureNames == null) { throw new UDFArgumentException("#featureNames should not be null"); } - int numFeatureNames = featureNames.length; + int numFeatureNames = _featureNames.length; if (numFeatureNames < 1) { throw new UDFArgumentException("#featureNames must be greater 
than or equals to 1: " + numFeatureNames); } - int numFeatures = numArgOIs - 1; + for (String featureName : _featureNames) { + if (featureName == null) { + throw new UDFArgumentException("featureName should not be null: " + + Arrays.toString(_featureNames)); + } else if (featureName.indexOf(':') != -1) { + throw new UDFArgumentException("featureName should not include colon: " + + featureName); + } + } + + final int numFeatures; + final int lastArgIndex = numArgOIs - 1; + if (lastArgIndex > numFeatureNames) { + if (lastArgIndex == (numFeatureNames + 1) + && HiveUtils.isConstString(argOIs[lastArgIndex])) { + String optionValue = HiveUtils.getConstString(argOIs[lastArgIndex]); + processOptions(optionValue); + numFeatures = numArgOIs - 2; + } else { + throw new UDFArgumentException( + "Unexpected arguments for _FUNC_" + + "(const array<string> featureNames, feature1, feature2, .. [, const string options])"); + } + } else { + numFeatures = lastArgIndex; + } if (numFeatureNames != numFeatures) { - throw new UDFArgumentException("#featureNames '" + numFeatureNames - + "' != #arguments '" + numFeatures + "'"); + throw new UDFArgumentLengthException("#featureNames '" + numFeatureNames + + "' != #features '" + numFeatures + "'"); } - this.inputOIs = new PrimitiveObjectInspector[numFeatures]; + this._inputOIs = new PrimitiveObjectInspector[numFeatures]; for (int i = 0; i < numFeatures; i++) { ObjectInspector oi = argOIs[i + 1]; - inputOIs[i] = HiveUtils.asPrimitiveObjectInspector(oi); + _inputOIs[i] = HiveUtils.asPrimitiveObjectInspector(oi); } - this.result = new ArrayList<Text>(numFeatures); + this._result = new ArrayList<Text>(numFeatures); return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector); } @Override public List<Text> evaluate(@Nonnull final DeferredObject[] arguments) throws HiveException { - result.clear(); + _result.clear(); - final int size = arguments.length - 1; + final int size = 
_featureNames.length; for (int i = 0; i < size; i++) { Object argument = arguments[i + 1].get(); if (argument == null) { + if (_emitNull) { + _result.add(null); + } continue; } - PrimitiveObjectInspector oi = inputOIs[i]; + PrimitiveObjectInspector oi = _inputOIs[i]; if (oi.getPrimitiveCategory() == PrimitiveCategory.STRING) { String s = PrimitiveObjectInspectorUtils.getString(argument, oi); if (s.isEmpty()) { + if (_emitNull) { + _result.add(null); + } continue; } - if (StringUtils.isNumber(s) == false) {// categorical feature representation - String featureName = featureNames[i]; - Text f = new Text(featureName + '#' + s); - result.add(f); + if (StringUtils.isNumber(s) == false) {// categorical feature representation + Text f = new Text(_featureNames[i] + '#' + s); + _result.add(f); continue; } } - float v = PrimitiveObjectInspectorUtils.getFloat(argument, oi); + final float v = PrimitiveObjectInspectorUtils.getFloat(argument, oi); if (v != 0.f) { - String featureName = featureNames[i]; - Text f = new Text(featureName + ':' + v); - result.add(f); + Text f = new Text(_featureNames[i] + ':' + v); + _result.add(f); + } else if (_emitNull) { + Text f = new Text(_featureNames[i] + ":0"); + _result.add(f); } } - return result; + return _result; } @Override http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/mf/FactorizedModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/mf/FactorizedModel.java b/core/src/main/java/hivemall/mf/FactorizedModel.java index a4bea00..1b7140f 100644 --- a/core/src/main/java/hivemall/mf/FactorizedModel.java +++ b/core/src/main/java/hivemall/mf/FactorizedModel.java @@ -18,7 +18,7 @@ */ package hivemall.mf; -import hivemall.utils.collections.maps.IntOpenHashMap; +import hivemall.utils.collections.maps.IntOpenHashTable; import hivemall.utils.math.MathUtils; import java.util.Random; @@ -42,10 +42,10 @@ public final class 
FactorizedModel { private int minIndex, maxIndex; @Nonnull private Rating meanRating; - private IntOpenHashMap<Rating[]> users; - private IntOpenHashMap<Rating[]> items; - private IntOpenHashMap<Rating> userBias; - private IntOpenHashMap<Rating> itemBias; + private IntOpenHashTable<Rating[]> users; + private IntOpenHashTable<Rating[]> items; + private IntOpenHashTable<Rating> userBias; + private IntOpenHashTable<Rating> itemBias; private final Random[] randU, randI; @@ -67,10 +67,10 @@ public final class FactorizedModel { this.minIndex = 0; this.maxIndex = 0; this.meanRating = ratingInitializer.newRating(meanRating); - this.users = new IntOpenHashMap<Rating[]>(expectedSize); - this.items = new IntOpenHashMap<Rating[]>(expectedSize); - this.userBias = new IntOpenHashMap<Rating>(expectedSize); - this.itemBias = new IntOpenHashMap<Rating>(expectedSize); + this.users = new IntOpenHashTable<Rating[]>(expectedSize); + this.items = new IntOpenHashTable<Rating[]>(expectedSize); + this.userBias = new IntOpenHashTable<Rating>(expectedSize); + this.itemBias = new IntOpenHashTable<Rating>(expectedSize); this.randU = newRandoms(factor, 31L); this.randI = newRandoms(factor, 41L); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/model/AbstractPredictionModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/AbstractPredictionModel.java b/core/src/main/java/hivemall/model/AbstractPredictionModel.java index 95935d3..cd298a7 100644 --- a/core/src/main/java/hivemall/model/AbstractPredictionModel.java +++ b/core/src/main/java/hivemall/model/AbstractPredictionModel.java @@ -22,7 +22,7 @@ import hivemall.annotations.InternalAPI; import hivemall.mix.MixedWeight; import hivemall.mix.MixedWeight.WeightWithCovar; import hivemall.mix.MixedWeight.WeightWithDelta; -import hivemall.utils.collections.maps.IntOpenHashMap; +import 
hivemall.utils.collections.maps.IntOpenHashTable; import hivemall.utils.collections.maps.OpenHashMap; import javax.annotation.Nonnull; @@ -37,7 +37,7 @@ public abstract class AbstractPredictionModel implements PredictionModel { private long numMixed; private boolean cancelMixRequest; - private IntOpenHashMap<MixedWeight> mixedRequests_i; + private IntOpenHashTable<MixedWeight> mixedRequests_i; private OpenHashMap<Object, MixedWeight> mixedRequests_o; public AbstractPredictionModel() { @@ -58,7 +58,7 @@ public abstract class AbstractPredictionModel implements PredictionModel { this.cancelMixRequest = cancelMixRequest; if (cancelMixRequest) { if (isDenseModel()) { - this.mixedRequests_i = new IntOpenHashMap<MixedWeight>(327680); + this.mixedRequests_i = new IntOpenHashTable<MixedWeight>(327680); } else { this.mixedRequests_o = new OpenHashMap<Object, MixedWeight>(327680); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/model/NewSparseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/NewSparseModel.java b/core/src/main/java/hivemall/model/NewSparseModel.java index 8326d22..5c0a6c7 100644 --- a/core/src/main/java/hivemall/model/NewSparseModel.java +++ b/core/src/main/java/hivemall/model/NewSparseModel.java @@ -194,7 +194,7 @@ public final class NewSparseModel extends AbstractPredictionModel { @SuppressWarnings("unchecked") @Override public <K, V extends IWeightValue> IMapIterator<K, V> entries() { - return (IMapIterator<K, V>) weights.entries(); + return (IMapIterator<K, V>) weights.entries(true); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/model/SparseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/SparseModel.java b/core/src/main/java/hivemall/model/SparseModel.java index cb8ab9f..65e751d 100644 
--- a/core/src/main/java/hivemall/model/SparseModel.java +++ b/core/src/main/java/hivemall/model/SparseModel.java @@ -183,7 +183,7 @@ public final class SparseModel extends AbstractPredictionModel { @SuppressWarnings("unchecked") @Override public <K, V extends IWeightValue> IMapIterator<K, V> entries() { - return (IMapIterator<K, V>) weights.entries(); + return (IMapIterator<K, V>) weights.entries(true); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/tools/array/ArrayAvgGenericUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/array/ArrayAvgGenericUDAF.java b/core/src/main/java/hivemall/tools/array/ArrayAvgGenericUDAF.java index a2e3e55..6dbb7d5 100644 --- a/core/src/main/java/hivemall/tools/array/ArrayAvgGenericUDAF.java +++ b/core/src/main/java/hivemall/tools/array/ArrayAvgGenericUDAF.java @@ -18,6 +18,10 @@ */ package hivemall.tools.array; +import static org.apache.hadoop.hive.ql.util.JavaDataModel.JAVA64_ARRAY_META; +import static org.apache.hadoop.hive.ql.util.JavaDataModel.JAVA64_REF; +import static org.apache.hadoop.hive.ql.util.JavaDataModel.PRIMITIVES1; +import static org.apache.hadoop.hive.ql.util.JavaDataModel.PRIMITIVES2; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.hadoop.WritableUtils; @@ -34,6 +38,7 @@ import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType; import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -220,7 +225,8 @@ public 
final class ArrayAvgGenericUDAF extends AbstractGenericUDAFResolver { } } - public static class ArrayAvgAggregationBuffer extends AbstractAggregationBuffer { + @AggregationType(estimable = true) + public static final class ArrayAvgAggregationBuffer extends AbstractAggregationBuffer { int _size; // note that primitive array cannot be serialized by JDK serializer @@ -289,6 +295,15 @@ public final class ArrayAvgGenericUDAF extends AbstractGenericUDAFResolver { } } + @Override + public int estimate() { + if (_size == -1) { + return JAVA64_REF; + } else { + return PRIMITIVES1 + 2 * (JAVA64_ARRAY_META + PRIMITIVES2 * _size); + } + } + } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/buffer/HeapBuffer.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/buffer/HeapBuffer.java b/core/src/main/java/hivemall/utils/buffer/HeapBuffer.java index e0a3c9e..10051a9 100644 --- a/core/src/main/java/hivemall/utils/buffer/HeapBuffer.java +++ b/core/src/main/java/hivemall/utils/buffer/HeapBuffer.java @@ -20,7 +20,6 @@ package hivemall.utils.buffer; import hivemall.utils.lang.NumberUtils; import hivemall.utils.lang.Preconditions; -import hivemall.utils.lang.Primitives; import hivemall.utils.lang.SizeOf; import hivemall.utils.lang.UnsafeUtils; @@ -97,8 +96,8 @@ public final class HeapBuffer { Preconditions.checkArgument(bytes <= _chunkBytes, "Cannot allocate memory greater than %s bytes: %s", _chunkBytes, bytes); - int i = Primitives.castToInt(_position / _chunkBytes); - final int j = Primitives.castToInt(_position % _chunkBytes); + int i = NumberUtils.castToInt(_position / _chunkBytes); + final int j = NumberUtils.castToInt(_position % _chunkBytes); if (bytes > (_chunkBytes - j)) { // cannot allocate the object in the current chunk // so, skip the current chunk @@ -144,7 +143,7 @@ public final class HeapBuffer { public byte getByte(final long ptr) { 
validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); return _UNSAFE.getByte(chunk, j); @@ -152,7 +151,7 @@ public final class HeapBuffer { public void putByte(final long ptr, final byte value) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); _UNSAFE.putByte(chunk, j, value); @@ -160,7 +159,7 @@ public final class HeapBuffer { public int getInt(final long ptr) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); return _UNSAFE.getInt(chunk, j); @@ -168,7 +167,7 @@ public final class HeapBuffer { public void putInt(final long ptr, final int value) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); _UNSAFE.putInt(chunk, j, value); @@ -176,7 +175,7 @@ public final class HeapBuffer { public short getShort(final long ptr) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); return _UNSAFE.getShort(chunk, j); @@ -184,7 +183,7 @@ public final class HeapBuffer { public void putShort(final long ptr, final short value) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); _UNSAFE.putShort(chunk, j, value); @@ -192,7 +191,7 @@ public final class HeapBuffer { public char getChar(final long ptr) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = 
offset(ptr); return _UNSAFE.getChar(chunk, j); @@ -200,14 +199,14 @@ public final class HeapBuffer { public void putChar(final long ptr, final char value) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); _UNSAFE.putChar(chunk, j, value); } public long getLong(final long ptr) { - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); return _UNSAFE.getLong(chunk, j); @@ -215,7 +214,7 @@ public final class HeapBuffer { public void putLong(final long ptr, final long value) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); _UNSAFE.putLong(chunk, j, value); @@ -223,7 +222,7 @@ public final class HeapBuffer { public float getFloat(final long ptr) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); return _UNSAFE.getFloat(chunk, j); @@ -231,7 +230,7 @@ public final class HeapBuffer { public void putFloat(final long ptr, final float value) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); _UNSAFE.putFloat(chunk, j, value); @@ -239,7 +238,7 @@ public final class HeapBuffer { public double getDouble(final long ptr) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); return _UNSAFE.getDouble(chunk, j); @@ -247,7 +246,7 @@ public final class HeapBuffer { public void putDouble(final long ptr, final double value) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / 
_chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); _UNSAFE.putDouble(chunk, j, value); @@ -260,7 +259,7 @@ public final class HeapBuffer { throw new IllegalArgumentException("Cannot put empty array at " + ptr); } - int chunkIdx = Primitives.castToInt(ptr / _chunkBytes); + int chunkIdx = NumberUtils.castToInt(ptr / _chunkBytes); final int[] chunk = _chunks[chunkIdx]; final long base = offset(ptr); for (int i = 0; i < len; i++) { @@ -277,7 +276,7 @@ public final class HeapBuffer { throw new IllegalArgumentException("Cannot put empty array at " + ptr); } - int chunkIdx = Primitives.castToInt(ptr / _chunkBytes); + int chunkIdx = NumberUtils.castToInt(ptr / _chunkBytes); final int[] chunk = _chunks[chunkIdx]; final long base = offset(ptr); for (int i = 0; i < len; i++) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java index f847b15..e9b5c8a 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java @@ -27,8 +27,13 @@ import java.io.ObjectOutput; import java.util.Arrays; /** - * An open-addressing hash table with double hashing - * + * An open-addressing hash table using double hashing. 
+ * + * <pre> + * Primary hash function: h1(k) = k mod m + * Secondary hash function: h2(k) = 1 + (k mod(m-2)) + * </pre> + * * @see http://en.wikipedia.org/wiki/Double_hashing */ public class Int2FloatOpenHashTable implements Externalizable { @@ -37,7 +42,7 @@ public class Int2FloatOpenHashTable implements Externalizable { protected static final byte FULL = 1; protected static final byte REMOVED = 2; - private static final float DEFAULT_LOAD_FACTOR = 0.7f; + private static final float DEFAULT_LOAD_FACTOR = 0.75f; private static final float DEFAULT_GROW_FACTOR = 2.0f; protected final transient float _loadFactor; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java index 5e9e812..8e87fce 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java @@ -27,7 +27,12 @@ import java.io.ObjectOutput; import java.util.Arrays; /** - * An open-addressing hash table with double hashing + * An open-addressing hash table using double hashing. + * + * <pre> + * Primary hash function: h1(k) = k mod m + * Secondary hash function: h2(k) = 1 + (k mod(m-2)) + * </pre> * * @see http://en.wikipedia.org/wiki/Double_hashing */ @@ -37,7 +42,7 @@ public final class Int2IntOpenHashTable implements Externalizable { protected static final byte FULL = 1; protected static final byte REMOVED = 2; - private static final float DEFAULT_LOAD_FACTOR = 0.7f; + private static final float DEFAULT_LOAD_FACTOR = 0.75f; private static final float DEFAULT_GROW_FACTOR = 2.0f; protected final transient float _loadFactor;
