http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java b/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java new file mode 100644 index 0000000..1848529 --- /dev/null +++ b/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.model; + +import hivemall.annotations.InternalAPI; +import hivemall.model.WeightValue.WeightValueWithCovar; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Copyable; +import hivemall.utils.lang.HalfFloat; +import hivemall.utils.math.MathUtils; + +import java.util.Arrays; + +import javax.annotation.Nonnull; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel { + private static final Log logger = LogFactory.getLog(NewSpaceEfficientDenseModel.class); + + private int size; + private short[] weights; + private short[] covars; + + // optional value for MIX + private short[] clocks; + private byte[] deltaUpdates; + + public NewSpaceEfficientDenseModel(int ndims) { + this(ndims, false); + } + + public NewSpaceEfficientDenseModel(int ndims, boolean withCovar) { + super(); + int size = ndims + 1; + this.size = size; + this.weights = new short[size]; + if (withCovar) { + short[] covars = new short[size]; + Arrays.fill(covars, HalfFloat.ONE); + this.covars = covars; + } else { + this.covars = null; + } + this.clocks = null; + this.deltaUpdates = null; + } + + @Override + protected boolean isDenseModel() { + return true; + } + + @Override + public boolean hasCovariance() { + return covars != null; + } + + @Override + public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, + boolean sum_of_gradients) {} + + @Override + public void configureClock() { + if (clocks == null) { + this.clocks = new short[size]; + this.deltaUpdates = new byte[size]; + } + } + + @Override + public boolean hasClock() { + return clocks != null; + } + + @Override + public void resetDeltaUpdates(int feature) { + deltaUpdates[feature] = 0; + } + + private float getWeight(final int i) { + final short w = weights[i]; + return (w == HalfFloat.ZERO) ? 
HalfFloat.ZERO : HalfFloat.halfFloatToFloat(w); + } + + private float getCovar(final int i) { + return HalfFloat.halfFloatToFloat(covars[i]); + } + + @InternalAPI + private void _setWeight(final int i, final float v) { + if (Math.abs(v) >= HalfFloat.MAX_FLOAT) { + throw new IllegalArgumentException("Acceptable maximum weight is " + + HalfFloat.MAX_FLOAT + ": " + v); + } + weights[i] = HalfFloat.floatToHalfFloat(v); + } + + private void setCovar(final int i, final float v) { + HalfFloat.checkRange(v); + covars[i] = HalfFloat.floatToHalfFloat(v); + } + + private void ensureCapacity(final int index) { + if (index >= size) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + int oldSize = size; + logger.info("Expands internal array size from " + oldSize + " to " + newSize + " (" + + bits + " bits)"); + this.size = newSize; + this.weights = Arrays.copyOf(weights, newSize); + if (covars != null) { + this.covars = Arrays.copyOf(covars, newSize); + Arrays.fill(covars, oldSize, newSize, HalfFloat.ONE); + } + if (clocks != null) { + this.clocks = Arrays.copyOf(clocks, newSize); + this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize); + } + } + } + + @SuppressWarnings("unchecked") + @Override + public <T extends IWeightValue> T get(@Nonnull final Object feature) { + final int i = HiveUtils.parseInt(feature); + if (i >= size) { + return null; + } + + if (covars != null) { + return (T) new WeightValueWithCovar(getWeight(i), getCovar(i)); + } else { + return (T) new WeightValue(getWeight(i)); + } + } + + @Override + public <T extends IWeightValue> void set(@Nonnull final Object feature, @Nonnull final T value) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + float weight = value.get(); + _setWeight(i, weight); + float covar = 1.f; + boolean hasCovar = value.hasCovariance(); + if (hasCovar) { + covar = value.getCovariance(); + setCovar(i, covar); + } + short clock = 0; + int delta = 0; + if (clocks != null && value.isTouched()) { + 
clock = (short) (clocks[i] + 1); + clocks[i] = clock; + delta = deltaUpdates[i] + 1; + assert (delta > 0) : delta; + deltaUpdates[i] = (byte) delta; + } + + onUpdate(i, weight, covar, clock, delta, hasCovar); + } + + @Override + public void delete(@Nonnull final Object feature) { + final int i = HiveUtils.parseInt(feature); + if (i >= size) { + return; + } + _setWeight(i, 0.f); + if (covars != null) { + setCovar(i, 1.f); + } + // avoid clock/delta + } + + @Override + public float getWeight(@Nonnull final Object feature) { + int i = HiveUtils.parseInt(feature); + if (i >= size) { + return 0f; + } + return getWeight(i); + } + + @Override + public void setWeight(@Nonnull final Object feature, final float value) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + _setWeight(i, value); + } + + @Override + public float getCovariance(@Nonnull final Object feature) { + int i = HiveUtils.parseInt(feature); + if (i >= size) { + return 1f; + } + return getCovar(i); + } + + @Override + protected void _set(@Nonnull final Object feature, final float weight, final short clock) { + int i = ((Integer) feature).intValue(); + ensureCapacity(i); + _setWeight(i, weight); + clocks[i] = clock; + deltaUpdates[i] = 0; + } + + @Override + protected void _set(@Nonnull final Object feature, final float weight, final float covar, + final short clock) { + int i = ((Integer) feature).intValue(); + ensureCapacity(i); + _setWeight(i, weight); + setCovar(i, covar); + clocks[i] = clock; + deltaUpdates[i] = 0; + } + + @Override + public int size() { + return size; + } + + @Override + public boolean contains(@Nonnull final Object feature) { + int i = HiveUtils.parseInt(feature); + if (i >= size) { + return false; + } + float w = getWeight(i); + return w != 0.f; + } + + @SuppressWarnings("unchecked") + @Override + public <K, V extends IWeightValue> IMapIterator<K, V> entries() { + return (IMapIterator<K, V>) new Itr(); + } + + private final class Itr implements IMapIterator<Number, 
IWeightValue> { + + private int cursor; + private final WeightValueWithCovar tmpWeight; + + private Itr() { + this.cursor = -1; + this.tmpWeight = new WeightValueWithCovar(); + } + + @Override + public boolean hasNext() { + return cursor < size; + } + + @Override + public int next() { + ++cursor; + if (!hasNext()) { + return -1; + } + return cursor; + } + + @Override + public Integer getKey() { + return cursor; + } + + @Override + public IWeightValue getValue() { + if (covars == null) { + float w = getWeight(cursor); + WeightValue v = new WeightValue(w); + v.setTouched(w != 0f); + return v; + } else { + float w = getWeight(cursor); + float cov = getCovar(cursor); + WeightValueWithCovar v = new WeightValueWithCovar(w, cov); + v.setTouched(w != 0.f || cov != 1.f); + return v; + } + } + + @Override + public <T extends Copyable<IWeightValue>> void getValue(@Nonnull final T probe) { + float w = getWeight(cursor); + tmpWeight.value = w; + float cov = 1.f; + if (covars != null) { + cov = getCovar(cursor); + tmpWeight.setCovariance(cov); + } + tmpWeight.setTouched(w != 0.f || cov != 1.f); + probe.copyFrom(tmpWeight); + } + + } + +}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/model/NewSparseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/NewSparseModel.java b/core/src/main/java/hivemall/model/NewSparseModel.java new file mode 100644 index 0000000..e312ae4 --- /dev/null +++ b/core/src/main/java/hivemall/model/NewSparseModel.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.model; + +import hivemall.model.WeightValueWithClock.WeightValueParamsF1Clock; +import hivemall.model.WeightValueWithClock.WeightValueParamsF2Clock; +import hivemall.model.WeightValueWithClock.WeightValueParamsF3Clock; +import hivemall.model.WeightValueWithClock.WeightValueWithCovarClock; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.collections.maps.OpenHashMap; + +import javax.annotation.Nonnull; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +public final class NewSparseModel extends AbstractPredictionModel { + private static final Log logger = LogFactory.getLog(NewSparseModel.class); + + private final OpenHashMap<Object, IWeightValue> weights; + private final boolean hasCovar; + private boolean clockEnabled; + + public NewSparseModel(int size) { + this(size, false); + } + + public NewSparseModel(int size, boolean hasCovar) { + super(); + this.weights = new OpenHashMap<Object, IWeightValue>(size); + this.hasCovar = hasCovar; + this.clockEnabled = false; + } + + @Override + protected boolean isDenseModel() { + return false; + } + + @Override + public boolean hasCovariance() { + return hasCovar; + } + + @Override + public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x, + boolean sum_of_gradients) {} + + @Override + public void configureClock() { + this.clockEnabled = true; + } + + @Override + public boolean hasClock() { + return clockEnabled; + } + + @SuppressWarnings("unchecked") + @Override + public <T extends IWeightValue> T get(@Nonnull final Object feature) { + return (T) weights.get(feature); + } + + @Override + public <T extends IWeightValue> void set(@Nonnull final Object feature, @Nonnull final T value) { + assert (feature != null); + assert (value != null); + + final IWeightValue wrapperValue = wrapIfRequired(value); + + if (clockEnabled && value.isTouched()) { + IWeightValue old = weights.get(feature); + if (old != null) { 
+ short newclock = (short) (old.getClock() + (short) 1); + wrapperValue.setClock(newclock); + int newDelta = old.getDeltaUpdates() + 1; + wrapperValue.setDeltaUpdates((byte) newDelta); + } + } + weights.put(feature, wrapperValue); + + onUpdate(feature, wrapperValue); + } + + @Override + public void delete(@Nonnull final Object feature) { + weights.remove(feature); + } + + @Nonnull + private IWeightValue wrapIfRequired(@Nonnull final IWeightValue value) { + final IWeightValue wrapper; + if (clockEnabled) { + switch (value.getType()) { + case NoParams: + wrapper = new WeightValueWithClock(value); + break; + case ParamsCovar: + wrapper = new WeightValueWithCovarClock(value); + break; + case ParamsF1: + wrapper = new WeightValueParamsF1Clock(value); + break; + case ParamsF2: + wrapper = new WeightValueParamsF2Clock(value); + break; + case ParamsF3: + wrapper = new WeightValueParamsF3Clock(value); + break; + default: + throw new IllegalStateException("Unexpected value type: " + value.getType()); + } + } else { + wrapper = value; + } + return wrapper; + } + + @Override + public float getWeight(@Nonnull final Object feature) { + IWeightValue v = weights.get(feature); + return v == null ? 0.f : v.get(); + } + + @Override + public void setWeight(@Nonnull final Object feature, final float value) { + if (weights.containsKey(feature)) { + IWeightValue weight = weights.get(feature); + weight.set(value); + } else { + IWeightValue weight = new WeightValue(value); + weight.set(value); + weights.put(feature, weight); + } + } + + @Override + public float getCovariance(@Nonnull final Object feature) { + IWeightValue v = weights.get(feature); + return v == null ? 
1.f : v.getCovariance(); + } + + @Override + protected void _set(@Nonnull final Object feature, final float weight, final short clock) { + final IWeightValue w = weights.get(feature); + if (w == null) { + logger.warn("Previous weight not found: " + feature); + throw new IllegalStateException("Previous weight not found " + feature); + } + w.set(weight); + w.setClock(clock); + w.setDeltaUpdates(BYTE0); + } + + @Override + protected void _set(@Nonnull final Object feature, final float weight, final float covar, + final short clock) { + final IWeightValue w = weights.get(feature); + if (w == null) { + logger.warn("Previous weight not found: " + feature); + throw new IllegalStateException("Previous weight not found: " + feature); + } + w.set(weight); + w.setCovariance(covar); + w.setClock(clock); + w.setDeltaUpdates(BYTE0); + } + + @Override + public int size() { + return weights.size(); + } + + @Override + public boolean contains(@Nonnull final Object feature) { + return weights.containsKey(feature); + } + + @SuppressWarnings("unchecked") + @Override + public <K, V extends IWeightValue> IMapIterator<K, V> entries() { + return (IMapIterator<K, V>) weights.entries(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/model/PredictionModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/PredictionModel.java b/core/src/main/java/hivemall/model/PredictionModel.java index 71f67d5..3b6e766 100644 --- a/core/src/main/java/hivemall/model/PredictionModel.java +++ b/core/src/main/java/hivemall/model/PredictionModel.java @@ -26,9 +26,10 @@ import javax.annotation.Nullable; public interface PredictionModel extends MixedModel { + @Nullable ModelUpdateHandler getUpdateHandler(); - void configureMix(ModelUpdateHandler handler, boolean cancelMixRequest); + void configureMix(@Nonnull ModelUpdateHandler handler, boolean cancelMixRequest); long 
getNumMixed(); @@ -56,6 +57,8 @@ public interface PredictionModel extends MixedModel { float getWeight(@Nonnull Object feature); + void setWeight(@Nonnull Object feature, float value); + float getCovariance(@Nonnull Object feature); <K, V extends IWeightValue> IMapIterator<K, V> entries(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java b/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java index 46f5d6e..a638939 100644 --- a/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java +++ b/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java @@ -167,7 +167,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { @SuppressWarnings("unchecked") @Override - public <T extends IWeightValue> T get(Object feature) { + public <T extends IWeightValue> T get(@Nonnull final Object feature) { final int i = HiveUtils.parseInt(feature); if (i >= size) { return null; @@ -190,7 +190,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { } @Override - public <T extends IWeightValue> void set(Object feature, T value) { + public <T extends IWeightValue> void set(@Nonnull final Object feature, @Nonnull final T value) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); float weight = value.get(); @@ -224,7 +224,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { } @Override - public void delete(@Nonnull Object feature) { + public void delete(@Nonnull final Object feature) { final int i = HiveUtils.parseInt(feature); if (i >= size) { return; @@ -246,7 +246,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { } @Override - public float getWeight(Object feature) { + public float getWeight(@Nonnull final Object 
feature) { int i = HiveUtils.parseInt(feature); if (i >= size) { return 0f; @@ -255,7 +255,12 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { } @Override - public float getCovariance(Object feature) { + public void setWeight(@Nonnull Object feature, float value) { + throw new UnsupportedOperationException(); + } + + @Override + public float getCovariance(@Nonnull final Object feature) { int i = HiveUtils.parseInt(feature); if (i >= size) { return 1f; @@ -264,7 +269,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { } @Override - protected void _set(Object feature, float weight, short clock) { + protected void _set(@Nonnull final Object feature, final float weight, final short clock) { int i = ((Integer) feature).intValue(); ensureCapacity(i); setWeight(i, weight); @@ -273,7 +278,8 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { } @Override - protected void _set(Object feature, float weight, float covar, short clock) { + protected void _set(@Nonnull final Object feature, final float weight, final float covar, + final short clock) { int i = ((Integer) feature).intValue(); ensureCapacity(i); setWeight(i, weight); @@ -288,7 +294,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { } @Override - public boolean contains(Object feature) { + public boolean contains(@Nonnull final Object feature) { int i = HiveUtils.parseInt(feature); if (i >= size) { return false; @@ -349,7 +355,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel { } @Override - public <T extends Copyable<IWeightValue>> void getValue(T probe) { + public <T extends Copyable<IWeightValue>> void getValue(@Nonnull final T probe) { float w = getWeight(cursor); tmpWeight.value = w; float cov = 1.f; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/model/SparseModel.java 
---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/SparseModel.java b/core/src/main/java/hivemall/model/SparseModel.java index a2b4708..ec26552 100644 --- a/core/src/main/java/hivemall/model/SparseModel.java +++ b/core/src/main/java/hivemall/model/SparseModel.java @@ -20,6 +20,7 @@ package hivemall.model; import hivemall.model.WeightValueWithClock.WeightValueParamsF1Clock; import hivemall.model.WeightValueWithClock.WeightValueParamsF2Clock; +import hivemall.model.WeightValueWithClock.WeightValueParamsF3Clock; import hivemall.model.WeightValueWithClock.WeightValueWithCovarClock; import hivemall.utils.collections.IMapIterator; import hivemall.utils.collections.maps.OpenHashMap; @@ -69,12 +70,12 @@ public final class SparseModel extends AbstractPredictionModel { @SuppressWarnings("unchecked") @Override - public <T extends IWeightValue> T get(final Object feature) { + public <T extends IWeightValue> T get(@Nonnull final Object feature) { return (T) weights.get(feature); } @Override - public <T extends IWeightValue> void set(final Object feature, final T value) { + public <T extends IWeightValue> void set(@Nonnull final Object feature, @Nonnull final T value) { assert (feature != null); assert (value != null); @@ -99,7 +100,8 @@ public final class SparseModel extends AbstractPredictionModel { weights.remove(feature); } - private IWeightValue wrapIfRequired(final IWeightValue value) { + @Nonnull + private IWeightValue wrapIfRequired(@Nonnull final IWeightValue value) { final IWeightValue wrapper; if (clockEnabled) { switch (value.getType()) { @@ -115,6 +117,9 @@ public final class SparseModel extends AbstractPredictionModel { case ParamsF2: wrapper = new WeightValueParamsF2Clock(value); break; + case ParamsF3: + wrapper = new WeightValueParamsF3Clock(value); + break; default: throw new IllegalStateException("Unexpected value type: " + value.getType()); } @@ -125,19 +130,24 @@ public final class 
SparseModel extends AbstractPredictionModel { } @Override - public float getWeight(final Object feature) { + public float getWeight(@Nonnull final Object feature) { IWeightValue v = weights.get(feature); return v == null ? 0.f : v.get(); } @Override - public float getCovariance(final Object feature) { + public void setWeight(@Nonnull Object feature, float value) { + throw new UnsupportedOperationException(); + } + + @Override + public float getCovariance(@Nonnull final Object feature) { IWeightValue v = weights.get(feature); return v == null ? 1.f : v.getCovariance(); } @Override - protected void _set(final Object feature, final float weight, final short clock) { + protected void _set(@Nonnull final Object feature, final float weight, final short clock) { final IWeightValue w = weights.get(feature); if (w == null) { logger.warn("Previous weight not found: " + feature); @@ -149,7 +159,7 @@ public final class SparseModel extends AbstractPredictionModel { } @Override - protected void _set(final Object feature, final float weight, final float covar, + protected void _set(@Nonnull final Object feature, final float weight, final float covar, final short clock) { final IWeightValue w = weights.get(feature); if (w == null) { @@ -168,7 +178,7 @@ public final class SparseModel extends AbstractPredictionModel { } @Override - public boolean contains(final Object feature) { + public boolean contains(@Nonnull final Object feature) { return weights.containsKey(feature); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java b/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java index 67d05e5..5c2ded1 100644 --- a/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java +++ 
b/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java @@ -107,7 +107,7 @@ public final class SynchronizedModelWrapper implements PredictionModel { } @Override - public boolean contains(Object feature) { + public boolean contains(@Nonnull final Object feature) { try { lock.lock(); return model.contains(feature); @@ -117,7 +117,7 @@ public final class SynchronizedModelWrapper implements PredictionModel { } @Override - public <T extends IWeightValue> T get(Object feature) { + public <T extends IWeightValue> T get(@Nonnull final Object feature) { try { lock.lock(); return model.get(feature); @@ -127,7 +127,7 @@ public final class SynchronizedModelWrapper implements PredictionModel { } @Override - public <T extends IWeightValue> void set(Object feature, T value) { + public <T extends IWeightValue> void set(@Nonnull final Object feature, @Nonnull final T value) { try { lock.lock(); model.set(feature, value); @@ -137,7 +137,7 @@ public final class SynchronizedModelWrapper implements PredictionModel { } @Override - public void delete(@Nonnull Object feature) { + public void delete(@Nonnull final Object feature) { try { lock.lock(); model.delete(feature); @@ -147,7 +147,7 @@ public final class SynchronizedModelWrapper implements PredictionModel { } @Override - public float getWeight(Object feature) { + public float getWeight(@Nonnull final Object feature) { try { lock.lock(); return model.getWeight(feature); @@ -157,7 +157,17 @@ public final class SynchronizedModelWrapper implements PredictionModel { } @Override - public float getCovariance(Object feature) { + public void setWeight(@Nonnull final Object feature, final float value) { + try { + lock.lock(); + model.setWeight(feature, value); + } finally { + lock.unlock(); + } + } + + @Override + public float getCovariance(@Nonnull final Object feature) { try { lock.lock(); return model.getCovariance(feature); @@ -167,7 +177,8 @@ public final class SynchronizedModelWrapper implements PredictionModel { } @Override - 
public void set(@Nonnull Object feature, float weight, float covar, short clock) { + public void set(@Nonnull final Object feature, final float weight, final float covar, + final short clock) { try { lock.lock(); model.set(feature, weight, covar, clock); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/model/WeightValue.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/WeightValue.java b/core/src/main/java/hivemall/model/WeightValue.java index 4e19fef..2fee76a 100644 --- a/core/src/main/java/hivemall/model/WeightValue.java +++ b/core/src/main/java/hivemall/model/WeightValue.java @@ -77,15 +77,50 @@ public class WeightValue implements IWeightValue { } @Override + public void setSumOfSquaredGradients(float value) { + throw new UnsupportedOperationException(); + } + + @Override public float getSumOfSquaredDeltaX() { return 0.f; } @Override + public void setSumOfSquaredDeltaX(float value) { + throw new UnsupportedOperationException(); + } + + @Override public float getSumOfGradients() { return 0.f; } + @Override + public void setSumOfGradients(float value) { + throw new UnsupportedOperationException(); + } + + @Override + public float getM() { + return 0.f; + } + + @Override + public void setM(float value) { + throw new UnsupportedOperationException(); + } + + @Override + public float getV() { + return 0.f; + } + + @Override + public void setV(float value) { + throw new UnsupportedOperationException(); + } + /** * @return whether touched in training or not */ @@ -137,7 +172,7 @@ public class WeightValue implements IWeightValue { } public static final class WeightValueParamsF1 extends WeightValue { - private final float f1; + private float f1; public WeightValueParamsF1(float weight, float f1) { super(weight); @@ -162,14 +197,19 @@ public class WeightValue implements IWeightValue { return f1; } + @Override + public void 
setSumOfSquaredGradients(float value) { + this.f1 = value; + } + } /** * WeightValue with Sum of Squared Gradients */ public static final class WeightValueParamsF2 extends WeightValue { - private final float f1; - private final float f2; + private float f1; + private float f2; public WeightValueParamsF2(float weight, float f1, float f2) { super(weight); @@ -198,15 +238,131 @@ public class WeightValue implements IWeightValue { } @Override + public void setSumOfSquaredGradients(float value) { + this.f1 = value; + } + + @Override public final float getSumOfSquaredDeltaX() { return f2; } @Override + public void setSumOfSquaredDeltaX(float value) { + this.f2 = value; + } + + @Override public float getSumOfGradients() { return f2; } + @Override + public void setSumOfGradients(float value) { + this.f2 = value; + } + + @Override + public float getM() { + return f1; + } + + @Override + public void setM(float value) { + this.f1 = value; + } + + @Override + public float getV() { + return f2; + } + + @Override + public void setV(float value) { + this.f2 = value; + } + + } + + public static final class WeightValueParamsF3 extends WeightValue { + private float f1; + private float f2; + private float f3; + + public WeightValueParamsF3(float weight, float f1, float f2, float f3) { + super(weight); + this.f1 = f1; + this.f2 = f2; + this.f3 = f3; + } + + @Override + public WeightValueType getType() { + return WeightValueType.ParamsF3; + } + + @Override + public float getFloatParams(@Nonnegative final int i) { + if(i == 1) { + return f1; + } else if(i == 2) { + return f2; + } else if (i == 3) { + return f3; + } + throw new IllegalArgumentException("getFloatParams(" + i + ") should not be called"); + } + + @Override + public final float getSumOfSquaredGradients() { + return f1; + } + + @Override + public void setSumOfSquaredGradients(float value) { + this.f1 = value; + } + + @Override + public final float getSumOfSquaredDeltaX() { + return f2; + } + + @Override + public void 
setSumOfSquaredDeltaX(float value) { + this.f2 = value; + } + + @Override + public float getSumOfGradients() { + return f3; + } + + @Override + public void setSumOfGradients(float value) { + this.f3 = value; + } + + @Override + public float getM() { + return f1; + } + + @Override + public void setM(float value) { + this.f1 = value; + } + + @Override + public float getV() { + return f2; + } + + @Override + public void setV(float value) { + this.f2 = value; + } + } public static final class WeightValueWithCovar extends WeightValue { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/model/WeightValueWithClock.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/WeightValueWithClock.java b/core/src/main/java/hivemall/model/WeightValueWithClock.java index e419c5d..524fa94 100644 --- a/core/src/main/java/hivemall/model/WeightValueWithClock.java +++ b/core/src/main/java/hivemall/model/WeightValueWithClock.java @@ -79,15 +79,50 @@ public class WeightValueWithClock implements IWeightValue { } @Override + public void setSumOfSquaredGradients(float value) { + throw new UnsupportedOperationException(); + } + + @Override public float getSumOfSquaredDeltaX() { return 0.f; } @Override + public void setSumOfSquaredDeltaX(float value) { + throw new UnsupportedOperationException(); + } + + @Override public float getSumOfGradients() { return 0.f; } + @Override + public void setSumOfGradients(float value) { + throw new UnsupportedOperationException(); + } + + @Override + public float getM() { + return 0.f; + } + + @Override + public void setM(float value) { + throw new UnsupportedOperationException(); + } + + @Override + public float getV() { + return 0.f; + } + + @Override + public void setV(float value) { + throw new UnsupportedOperationException(); + } + /** * @return whether touched in training or not */ @@ -144,7 +179,7 @@ public class WeightValueWithClock 
implements IWeightValue { * WeightValue with Sum of Squared Gradients */ public static final class WeightValueParamsF1Clock extends WeightValueWithClock { - private final float f1; + private float f1; public WeightValueParamsF1Clock(float value, float f1) { super(value); @@ -174,11 +209,16 @@ public class WeightValueWithClock implements IWeightValue { return f1; } + @Override + public void setSumOfSquaredGradients(float value) { + this.f1 = value; + } + } public static final class WeightValueParamsF2Clock extends WeightValueWithClock { - private final float f1; - private final float f2; + private float f1; + private float f2; public WeightValueParamsF2Clock(float value, float f1, float f2) { super(value); @@ -213,15 +253,136 @@ public class WeightValueWithClock implements IWeightValue { } @Override + public void setSumOfSquaredGradients(float value) { + this.f1 = value; + } + + @Override + public float getSumOfSquaredDeltaX() { + return f2; + } + + @Override + public void setSumOfSquaredDeltaX(float value) { + this.f2 = value; + } + + @Override + public float getSumOfGradients() { + return f2; + } + + @Override + public void setSumOfGradients(float value) { + this.f2 = value; + } + @Override + public float getM() { + return f1; + } + + @Override + public void setM(float value) { + this.f1 = value; + } + + @Override + public float getV() { + return f2; + } + + @Override + public void setV(float value) { + this.f2 = value; + } + + } + + public static final class WeightValueParamsF3Clock extends WeightValueWithClock { + private float f1; + private float f2; + private float f3; + + public WeightValueParamsF3Clock(float value, float f1, float f2, float f3) { + super(value); + this.f1 = f1; + this.f2 = f2; + this.f3 = f3; + } + + public WeightValueParamsF3Clock(IWeightValue src) { + super(src); + this.f1 = src.getFloatParams(1); + this.f2 = src.getFloatParams(2); + this.f3 = src.getFloatParams(3); + } + + @Override + public WeightValueType getType() { + return 
WeightValueType.ParamsF3; + } + + @Override + public float getFloatParams(@Nonnegative final int i) { + if(i == 1) { + return f1; + } else if(i == 2) { + return f2; + } else if(i == 3) { + return f3; + } + throw new IllegalArgumentException("getFloatParams(" + i + ") should not be called"); + } + + @Override + public float getSumOfSquaredGradients() { + return f1; + } + + @Override + public void setSumOfSquaredGradients(float value) { + this.f1 = value; + } + + @Override public float getSumOfSquaredDeltaX() { return f2; } @Override + public void setSumOfSquaredDeltaX(float value) { + this.f2 = value; + } + + @Override public float getSumOfGradients() { + return f3; + } + + @Override + public void setSumOfGradients(float value) { + this.f3 = value; + } + @Override + public float getM() { + return f1; + } + + @Override + public void setM(float value) { + this.f1 = value; + } + + @Override + public float getV() { return f2; } + @Override + public void setV(float value) { + this.f2 = value; + } + } public static final class WeightValueWithCovarClock extends WeightValueWithClock { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java new file mode 100644 index 0000000..2bf030b --- /dev/null +++ b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.optimizer; + +import hivemall.model.IWeightValue; +import hivemall.model.WeightValue; +import hivemall.optimizer.Optimizer.OptimizerBase; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.math.MathUtils; + +import java.util.Arrays; +import java.util.Map; + +import javax.annotation.Nonnull; +import javax.annotation.concurrent.NotThreadSafe; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +public final class DenseOptimizerFactory { + private static final Log logger = LogFactory.getLog(DenseOptimizerFactory.class); + + @Nonnull + public static Optimizer create(int ndims, @Nonnull Map<String, String> options) { + final String optimizerName = options.get("optimizer"); + if (optimizerName != null) { + OptimizerBase optimizerImpl; + if (optimizerName.toLowerCase().equals("sgd")) { + optimizerImpl = new Optimizer.SGD(options); + } else if (optimizerName.toLowerCase().equals("adadelta")) { + optimizerImpl = new AdaDelta(ndims, options); + } else if (optimizerName.toLowerCase().equals("adagrad")) { + optimizerImpl = new AdaGrad(ndims, options); + } else if (optimizerName.toLowerCase().equals("adam")) { + optimizerImpl = new Adam(ndims, options); + } else { + throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName); + } + + logger.info("set " + optimizerImpl.getClass().getSimpleName() + " as an optimizer: " + + options); + + // If a regularization type is "RDA", wrap the optimizer with `Optimizer#RDA`. 
+ if (options.get("regularization") != null + && options.get("regularization").toLowerCase().equals("rda")) { + optimizerImpl = new AdagradRDA(ndims, optimizerImpl, options); + } + + return optimizerImpl; + } + throw new IllegalArgumentException("`optimizer` not defined"); + } + + @NotThreadSafe + static final class AdaDelta extends Optimizer.AdaDelta { + + @Nonnull + private final IWeightValue weightValueReused; + + @Nonnull + private float[] sum_of_squared_gradients; + @Nonnull + private float[] sum_of_squared_delta_x; + + public AdaDelta(int ndims, Map<String, String> options) { + super(options); + this.weightValueReused = new WeightValue.WeightValueParamsF2(0.f, 0.f, 0.f); + this.sum_of_squared_gradients = new float[ndims]; + this.sum_of_squared_delta_x = new float[ndims]; + } + + @Override + public float update(@Nonnull Object feature, float weight, float gradient) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weightValueReused.set(weight); + weightValueReused.setSumOfSquaredGradients(sum_of_squared_gradients[i]); + weightValueReused.setSumOfSquaredDeltaX(sum_of_squared_delta_x[i]); + update(weightValueReused, gradient); + sum_of_squared_gradients[i] = weightValueReused.getSumOfSquaredGradients(); + sum_of_squared_delta_x[i] = weightValueReused.getSumOfSquaredDeltaX(); + return weightValueReused.get(); + } + + private void ensureCapacity(final int index) { + if (index >= sum_of_squared_gradients.length) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize); + this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize); + } + } + + } + + @NotThreadSafe + static final class AdaGrad extends Optimizer.AdaGrad { + + private final IWeightValue weightValueReused; + + private float[] sum_of_squared_gradients; + + public AdaGrad(int ndims, Map<String, String> options) { + super(options); + this.weightValueReused = new 
WeightValue.WeightValueParamsF1(0.f, 0.f); + this.sum_of_squared_gradients = new float[ndims]; + } + + @Override + public float update(@Nonnull Object feature, float weight, float gradient) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weightValueReused.set(weight); + weightValueReused.setSumOfSquaredGradients(sum_of_squared_gradients[i]); + update(weightValueReused, gradient); + sum_of_squared_gradients[i] = weightValueReused.getSumOfSquaredGradients(); + return weightValueReused.get(); + } + + private void ensureCapacity(final int index) { + if (index >= sum_of_squared_gradients.length) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize); + } + } + + } + + @NotThreadSafe + static final class Adam extends Optimizer.Adam { + + @Nonnull + private final IWeightValue weightValueReused; + + @Nonnull + private float[] val_m; + @Nonnull + private float[] val_v; + + public Adam(int ndims, Map<String, String> options) { + super(options); + this.weightValueReused = new WeightValue.WeightValueParamsF2(0.f, 0.f, 0.f); + this.val_m = new float[ndims]; + this.val_v = new float[ndims]; + } + + @Override + public float update(@Nonnull Object feature, float weight, float gradient) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weightValueReused.set(weight); + weightValueReused.setM(val_m[i]); + weightValueReused.setV(val_v[i]); + update(weightValueReused, gradient); + val_m[i] = weightValueReused.getM(); + val_v[i] = weightValueReused.getV(); + return weightValueReused.get(); + } + + private void ensureCapacity(final int index) { + if (index >= val_m.length) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + this.val_m = Arrays.copyOf(val_m, newSize); + this.val_v = Arrays.copyOf(val_v, newSize); + } + } + + } + + @NotThreadSafe + static final class AdagradRDA extends Optimizer.AdagradRDA { + + @Nonnull + 
private final IWeightValue weightValueReused; + + @Nonnull + private float[] sum_of_gradients; + + public AdagradRDA(int ndims, final OptimizerBase optimizerImpl, Map<String, String> options) { + super(optimizerImpl, options); + this.weightValueReused = new WeightValue.WeightValueParamsF3(0.f, 0.f, 0.f, 0.f); + this.sum_of_gradients = new float[ndims]; + } + + @Override + public float update(@Nonnull Object feature, float weight, float gradient) { + int i = HiveUtils.parseInt(feature); + ensureCapacity(i); + weightValueReused.set(weight); + weightValueReused.setSumOfGradients(sum_of_gradients[i]); + update(weightValueReused, gradient); + sum_of_gradients[i] = weightValueReused.getSumOfGradients(); + return weightValueReused.get(); + } + + private void ensureCapacity(final int index) { + if (index >= sum_of_gradients.length) { + int bits = MathUtils.bitsRequired(index); + int newSize = (1 << bits) + 1; + this.sum_of_gradients = Arrays.copyOf(sum_of_gradients, newSize); + } + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/optimizer/EtaEstimator.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/EtaEstimator.java b/core/src/main/java/hivemall/optimizer/EtaEstimator.java new file mode 100644 index 0000000..17b39d1 --- /dev/null +++ b/core/src/main/java/hivemall/optimizer/EtaEstimator.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.optimizer; + +import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.Primitives; + +import java.util.Map; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.commons.cli.CommandLine; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; + +public abstract class EtaEstimator { + + protected final float eta0; + + public EtaEstimator(float eta0) { + this.eta0 = eta0; + } + + public float eta0() { + return eta0; + } + + public abstract float eta(long t); + + public void update(@Nonnegative float multipler) {} + + public static final class FixedEtaEstimator extends EtaEstimator { + + public FixedEtaEstimator(float eta) { + super(eta); + } + + @Override + public float eta(long t) { + return eta0; + } + + } + + public static final class SimpleEtaEstimator extends EtaEstimator { + + private final float finalEta; + private final double total_steps; + + public SimpleEtaEstimator(float eta0, long total_steps) { + super(eta0); + this.finalEta = (float) (eta0 / 2.d); + this.total_steps = total_steps; + } + + @Override + public float eta(final long t) { + if (t > total_steps) { + return finalEta; + } + return (float) (eta0 / (1.d + (t / total_steps))); + } + + } + + public static final class InvscalingEtaEstimator extends EtaEstimator { + + private final double power_t; + + public InvscalingEtaEstimator(float eta0, double power_t) { + super(eta0); + this.power_t = power_t; + } + + @Override + public float eta(final long t) { + return (float) 
(eta0 / Math.pow(t, power_t)); + } + + } + + /** + * bold driver: Gemulla et al., Large-scale matrix factorization with distributed stochastic gradient descent, KDD 2011. + */ + public static final class AdjustingEtaEstimator extends EtaEstimator { + + private float eta; + + public AdjustingEtaEstimator(float eta) { + super(eta); + this.eta = eta; + } + + @Override + public float eta(long t) { + return eta; + } + + @Override + public void update(@Nonnegative float multipler) { + float newEta = eta * multipler; + if (!NumberUtils.isFinite(newEta)) { + // avoid NaN or INFINITY + return; + } + this.eta = Math.min(eta0, newEta); // never be larger than eta0 + } + + } + + @Nonnull + public static EtaEstimator get(@Nullable CommandLine cl) throws UDFArgumentException { + return get(cl, 0.1f); + } + + @Nonnull + public static EtaEstimator get(@Nullable CommandLine cl, float defaultEta0) + throws UDFArgumentException { + if (cl == null) { + return new InvscalingEtaEstimator(defaultEta0, 0.1d); + } + + if (cl.hasOption("boldDriver")) { + float eta = Primitives.parseFloat(cl.getOptionValue("eta"), 0.3f); + return new AdjustingEtaEstimator(eta); + } + + String etaValue = cl.getOptionValue("eta"); + if (etaValue != null) { + float eta = Float.parseFloat(etaValue); + return new FixedEtaEstimator(eta); + } + + float eta0 = Primitives.parseFloat(cl.getOptionValue("eta0"), defaultEta0); + if (cl.hasOption("t")) { + long t = Long.parseLong(cl.getOptionValue("t")); + return new SimpleEtaEstimator(eta0, t); + } + + double power_t = Primitives.parseDouble(cl.getOptionValue("power_t"), 0.1d); + return new InvscalingEtaEstimator(eta0, power_t); + } + + @Nonnull + public static EtaEstimator get(@Nonnull final Map<String, String> options) + throws IllegalArgumentException { + final String etaScheme = options.get("eta"); + if (etaScheme == null) { + return new InvscalingEtaEstimator(0.1f, 0.1d); + } + + float eta0 = 0.1f; + if (options.containsKey("eta0")) { + eta0 = 
Float.parseFloat(options.get("eta0")); + } + + if ("fixed".equalsIgnoreCase(etaScheme)) { + return new FixedEtaEstimator(eta0); + } else if ("simple".equalsIgnoreCase(etaScheme)) { + final long t; + if (options.containsKey("total_steps")) { + t = Long.parseLong(options.get("total_steps")); + } else { + throw new IllegalArgumentException( + "-total_steps MUST be provided when `-eta simple` is specified"); + } + return new SimpleEtaEstimator(eta0, t); + } else if ("inv".equalsIgnoreCase(etaScheme) || "inverse".equalsIgnoreCase(etaScheme)) { + double power_t = 0.1; + if (options.containsKey("power_t")) { + power_t = Double.parseDouble(options.get("power_t")); + } + return new InvscalingEtaEstimator(eta0, power_t); + } else { + throw new IllegalArgumentException("Unsupported ETA name: " + etaScheme); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/optimizer/LossFunctions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/LossFunctions.java b/core/src/main/java/hivemall/optimizer/LossFunctions.java new file mode 100644 index 0000000..0dff4aa --- /dev/null +++ b/core/src/main/java/hivemall/optimizer/LossFunctions.java @@ -0,0 +1,609 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.optimizer; + +import hivemall.utils.math.MathUtils; + +/** + * @link https://github.com/JohnLangford/vowpal_wabbit/wiki/Loss-functions + */ +public final class LossFunctions { + + public enum LossType { + SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, HuberLoss, HingeLoss, LogLoss, + SquaredHingeLoss, ModifiedHuberLoss + } + + public static LossFunction getLossFunction(String type) { + if ("SquaredLoss".equalsIgnoreCase(type)) { + return new SquaredLoss(); + } else if ("QuantileLoss".equalsIgnoreCase(type)) { + return new QuantileLoss(); + } else if ("EpsilonInsensitiveLoss".equalsIgnoreCase(type)) { + return new EpsilonInsensitiveLoss(); + } else if ("HuberLoss".equalsIgnoreCase(type)) { + return new HuberLoss(); + } else if ("HingeLoss".equalsIgnoreCase(type)) { + return new HingeLoss(); + } else if ("LogLoss".equalsIgnoreCase(type)) { + return new LogLoss(); + } else if ("SquaredHingeLoss".equalsIgnoreCase(type)) { + return new SquaredHingeLoss(); + } else if ("ModifiedHuberLoss".equalsIgnoreCase(type)) { + return new ModifiedHuberLoss(); + } + throw new IllegalArgumentException("Unsupported loss function name: " + type); + } + + public static LossFunction getLossFunction(LossType type) { + switch (type) { + case SquaredLoss: + return new SquaredLoss(); + case QuantileLoss: + return new QuantileLoss(); + case EpsilonInsensitiveLoss: + return new EpsilonInsensitiveLoss(); + case HuberLoss: + return new HuberLoss(); + case HingeLoss: + return new HingeLoss(); + case LogLoss: + return new LogLoss(); + case SquaredHingeLoss: + return new SquaredHingeLoss(); + case ModifiedHuberLoss: + return new ModifiedHuberLoss(); + default: + throw new IllegalArgumentException("Unsupported loss function name: " + type); + } + } + + public interface LossFunction { + + /** + * Evaluate the loss function. 
+ * + * @param p The prediction, p = w^T x + * @param y The true value (aka target) + * @return The loss evaluated at `p` and `y`. + */ + public float loss(float p, float y); + + public double loss(double p, double y); + + /** + * Evaluate the derivative of the loss function with respect to the prediction `p`. + * + * @param p The prediction, p = w^T x + * @param y The true value (aka target) + * @return The derivative of the loss function w.r.t. `p`. + */ + public float dloss(float p, float y); + + public boolean forBinaryClassification(); + + public boolean forRegression(); + + public LossType getType(); + + } + + public static abstract class RegressionLoss implements LossFunction { + + @Override + public boolean forBinaryClassification() { + return false; + } + + @Override + public boolean forRegression() { + return true; + } + } + + public static abstract class BinaryLoss implements LossFunction { + + protected static void checkTarget(float y) { + if (!(y == 1.f || y == -1.f)) { + throw new IllegalArgumentException("target must be [+1,-1]: " + y); + } + } + + protected static void checkTarget(double y) { + if (!(y == 1.d || y == -1.d)) { + throw new IllegalArgumentException("target must be [+1,-1]: " + y); + } + } + + @Override + public boolean forBinaryClassification() { + return true; + } + + @Override + public boolean forRegression() { + return false; + } + } + + /** + * Squared loss for regression problems. + * + * If you're trying to minimize the mean error, use squared-loss. 
+ */ + public static final class SquaredLoss extends RegressionLoss { + + @Override + public float loss(float p, float y) { + final float z = p - y; + return z * z * 0.5f; + } + + @Override + public double loss(double p, double y) { + final double z = p - y; + return z * z * 0.5d; + } + + @Override + public float dloss(float p, float y) { + return p - y; // 2 (p - y) / 2 + } + + @Override + public LossType getType() { + return LossType.SquaredLoss; + } + } + + /** + * Quantile loss is useful to predict rank/order and you do not mind the mean error to increase as long as you get the relative order correct. + * + * @link http://en.wikipedia.org/wiki/Quantile_regression + */ + public static final class QuantileLoss extends RegressionLoss { + + private float tau; + + public QuantileLoss() { + this.tau = 0.5f; + } + + public QuantileLoss(float tau) { + setTau(tau); + } + + public void setTau(float tau) { + if (tau <= 0 || tau >= 1.0) { + throw new IllegalArgumentException("tau must be in range (0, 1): " + tau); + } + this.tau = tau; + } + + @Override + public float loss(float p, float y) { + float e = y - p; + if (e > 0.f) { + return tau * e; + } else { + return -(1.f - tau) * e; + } + } + + @Override + public double loss(double p, double y) { + double e = y - p; + if (e > 0.d) { + return tau * e; + } else { + return -(1.d - tau) * e; + } + } + + @Override + public float dloss(float p, float y) { + float e = y - p; + if (e == 0.f) { + return 0.f; + } + return (e > 0.f) ? -tau : (1.f - tau); + } + + @Override + public LossType getType() { + return LossType.QuantileLoss; + } + } + + /** + * Epsilon-Insensitive loss used by Support Vector Regression (SVR). 
<code>loss = max(0, |y - p| - epsilon)</code> + */ + public static final class EpsilonInsensitiveLoss extends RegressionLoss { + + private float epsilon; + + public EpsilonInsensitiveLoss() { + this(0.1f); + } + + public EpsilonInsensitiveLoss(float epsilon) { + this.epsilon = epsilon; + } + + public void setEpsilon(float epsilon) { + this.epsilon = epsilon; + } + + @Override + public float loss(float p, float y) { + float loss = Math.abs(y - p) - epsilon; + return (loss > 0.f) ? loss : 0.f; + } + + @Override + public double loss(double p, double y) { + double loss = Math.abs(y - p) - epsilon; + return (loss > 0.d) ? loss : 0.d; + } + + @Override + public float dloss(float p, float y) { + if ((y - p) > epsilon) {// real value > predicted value - epsilon + return -1.f; + } + if ((p - y) > epsilon) {// real value < predicted value - epsilon + return 1.f; + } + return 0.f; + } + + @Override + public LossType getType() { + return LossType.EpsilonInsensitiveLoss; + } + } + + /** + * Huber regression loss. + * + * Variant of the SquaredLoss which is robust to outliers. 
+ * + * @link https://en.wikipedia.org/wiki/Huber_Loss_Function + */ + public static final class HuberLoss extends RegressionLoss { + + private float c; + + public HuberLoss() { + this(1.f); // i.e., beyond 1 standard deviation, the loss becomes linear + } + + public HuberLoss(float c) { + this.c = c; + } + + public void setC(float c) { + this.c = c; + } + + @Override + public float loss(float p, float y) { + final float r = p - y; + final float rAbs = Math.abs(r); + if (rAbs <= c) { + return 0.5f * r * r; + } + return c * rAbs - (0.5f * c * c); + } + + @Override + public double loss(double p, double y) { + final double r = p - y; + final double rAbs = Math.abs(r); + if (rAbs <= c) { + return 0.5d * r * r; + } + return c * rAbs - (0.5d * c * c); + } + + @Override + public float dloss(float p, float y) { + final float r = p - y; + final float rAbs = Math.abs(r); + if (rAbs <= c) { + return r; + } else if (r > 0.f) { + return c; + } + return -c; + } + + @Override + public LossType getType() { + return LossType.HuberLoss; + } + } + + /** + * Hinge loss for binary classification tasks with y in {-1,1}. + */ + public static final class HingeLoss extends BinaryLoss { + + private float threshold; + + public HingeLoss() { + this(1.f); + } + + /** + * @param threshold Margin threshold. When threshold=1.0, one gets the loss used by SVM. When threshold=0.0, one gets the loss used by the + * Perceptron. + */ + public HingeLoss(float threshold) { + this.threshold = threshold; + } + + public void setThreshold(float threshold) { + this.threshold = threshold; + } + + @Override + public float loss(float p, float y) { + float loss = hingeLoss(p, y, threshold); + return (loss > 0.f) ? loss : 0.f; + } + + @Override + public double loss(double p, double y) { + double loss = hingeLoss(p, y, threshold); + return (loss > 0.d) ? loss : 0.d; + } + + @Override + public float dloss(float p, float y) { + float loss = hingeLoss(p, y, threshold); + return (loss > 0.f) ? 
-y : 0.f; + } + + @Override + public LossType getType() { + return LossType.HingeLoss; + } + } + + /** + * Logistic regression loss for binary classification with y in {-1, 1}. + */ + public static final class LogLoss extends BinaryLoss { + + /** + * <code>logloss(p,y) = log(1+exp(-p*y))</code> + */ + @Override + public float loss(float p, float y) { + checkTarget(y); + + final float z = y * p; + if (z > 18.f) { + return (float) Math.exp(-z); + } + if (z < -18.f) { + return -z; + } + return (float) Math.log(1.d + Math.exp(-z)); + } + + @Override + public double loss(double p, double y) { + checkTarget(y); + + final double z = y * p; + if (z > 18.d) { + return Math.exp(-z); + } + if (z < -18.d) { + return -z; + } + return Math.log(1.d + Math.exp(-z)); + } + + @Override + public float dloss(float p, float y) { + checkTarget(y); + + float z = y * p; + if (z > 18.f) { + return (float) Math.exp(-z) * -y; + } + if (z < -18.f) { + return -y; + } + return -y / ((float) Math.exp(z) + 1.f); + } + + @Override + public LossType getType() { + return LossType.LogLoss; + } + } + + /** + * Squared Hinge loss for binary classification tasks with y in {-1,1}. + */ + public static final class SquaredHingeLoss extends BinaryLoss { + + @Override + public float loss(float p, float y) { + return squaredHingeLoss(p, y); + } + + @Override + public double loss(double p, double y) { + return squaredHingeLoss(p, y); + } + + @Override + public float dloss(float p, float y) { + checkTarget(y); + + float d = 1 - (y * p); + return (d > 0.f) ? -2.f * d * y : 0.f; + } + + @Override + public LossType getType() { + return LossType.SquaredHingeLoss; + } + } + + /** + * Modified Huber loss for binary classification with y in {-1, 1}. + * + * Equivalent to quadratically smoothed SVM with gamma = 2. 
+ */ + public static final class ModifiedHuberLoss extends BinaryLoss { + + @Override + public float loss(float p, float y) { + final float z = p * y; + if (z >= 1.f) { + return 0.f; + } else if (z >= -1.f) { + return (1.f - z) * (1.f - z); + } + return -4.f * z; + } + + @Override + public double loss(double p, double y) { + final double z = p * y; + if (z >= 1.d) { + return 0.d; + } else if (z >= -1.d) { + return (1.d - z) * (1.d - z); + } + return -4.d * z; + } + + @Override + public float dloss(float p, float y) { + final float z = p * y; + if (z >= 1.f) { + return 0.f; + } else if (z >= -1.f) { + return 2.f * (1.f - z) * -y; + } + return -4.f * y; + } + + @Override + public LossType getType() { + return LossType.ModifiedHuberLoss; + } + } + + public static float logisticLoss(final float target, final float predicted) { + if (predicted > -100.d) { + return target - (float) MathUtils.sigmoid(predicted); + } else { + return target; + } + } + + public static float logLoss(final float p, final float y) { + BinaryLoss.checkTarget(y); + + final float z = y * p; + if (z > 18.f) { + return (float) Math.exp(-z); + } + if (z < -18.f) { + return -z; + } + return (float) Math.log(1.d + Math.exp(-z)); + } + + public static double logLoss(final double p, final double y) { + BinaryLoss.checkTarget(y); + + final double z = y * p; + if (z > 18.d) { + return Math.exp(-z); + } + if (z < -18.d) { + return -z; + } + return Math.log(1.d + Math.exp(-z)); + } + + public static float squaredLoss(float p, float y) { + final float z = p - y; + return z * z * 0.5f; + } + + public static double squaredLoss(double p, double y) { + final double z = p - y; + return z * z * 0.5d; + } + + public static float hingeLoss(final float p, final float y, final float threshold) { + BinaryLoss.checkTarget(y); + + float z = y * p; + return threshold - z; + } + + public static double hingeLoss(final double p, final double y, final double threshold) { + BinaryLoss.checkTarget(y); + + double z = y * p; + 
return threshold - z; + } + + public static float hingeLoss(float p, float y) { + return hingeLoss(p, y, 1.f); + } + + public static double hingeLoss(double p, double y) { + return hingeLoss(p, y, 1.d); + } + + public static float squaredHingeLoss(final float p, final float y) { + BinaryLoss.checkTarget(y); + + float z = y * p; + float d = 1.f - z; + return (d > 0.f) ? (d * d) : 0.f; + } + + public static double squaredHingeLoss(final double p, final double y) { + BinaryLoss.checkTarget(y); + + double z = y * p; + double d = 1.d - z; + return (d > 0.d) ? d * d : 0.d; + } + + /** + * Math.abs(target - predicted) - epsilon + */ + public static float epsilonInsensitiveLoss(float predicted, float target, float epsilon) { + return Math.abs(target - predicted) - epsilon; + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/optimizer/Optimizer.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java b/core/src/main/java/hivemall/optimizer/Optimizer.java new file mode 100644 index 0000000..ad70e61 --- /dev/null +++ b/core/src/main/java/hivemall/optimizer/Optimizer.java @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.optimizer; + +import hivemall.model.IWeightValue; +import hivemall.model.WeightValue; + +import java.util.Map; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import javax.annotation.concurrent.NotThreadSafe; + +public interface Optimizer { + + /** + * Update the weights of models + */ + float update(@Nonnull Object feature, float weight, float gradient); + + /** + * Count up #step to tune learning rate + */ + void proceedStep(); + + @NotThreadSafe + static abstract class OptimizerBase implements Optimizer { + + @Nonnull + protected final EtaEstimator _eta; + @Nonnull + protected final Regularization _reg; + @Nonnegative + protected int _numStep = 1; + + public OptimizerBase(final Map<String, String> options) { + this._eta = EtaEstimator.get(options); + this._reg = Regularization.get(options); + } + + @Override + public void proceedStep() { + _numStep++; + } + + /** + * Update the given weight by the given gradient. 
+ */ + protected float update(@Nonnull final IWeightValue weight, float gradient) { + float g = _reg.regularize(weight.get(), gradient); + float delta = computeDelta(weight, g); + float newWeight = weight.get() - _eta.eta(_numStep) * delta; + weight.set(newWeight); + return newWeight; + } + + /** + * Compute a delta to update + */ + protected float computeDelta(@Nonnull final IWeightValue weight, float gradient) { + return gradient; + } + + } + + static final class SGD extends OptimizerBase { + + private final IWeightValue weightValueReused; + + public SGD(final Map<String, String> options) { + super(options); + this.weightValueReused = new WeightValue(0.f); + } + + @Override + public float update(@Nonnull Object feature, float weight, float gradient) { + weightValueReused.set(weight); + update(weightValueReused, gradient); + return weightValueReused.get(); + } + + } + + static abstract class AdaGrad extends OptimizerBase { + + private final float eps; + private final float scale; + + public AdaGrad(Map<String, String> options) { + super(options); + float eps = 1.0f; + float scale = 100.0f; + if (options.containsKey("eps")) { + eps = Float.parseFloat(options.get("eps")); + } + if (options.containsKey("scale")) { + scale = Float.parseFloat(options.get("scale")); + } + this.eps = eps; + this.scale = scale; + } + + @Override + protected float computeDelta(@Nonnull final IWeightValue weight, float gradient) { + float new_scaled_sum_sqgrad = weight.getSumOfSquaredGradients() + gradient + * (gradient / scale); + weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad); + return gradient / ((float) Math.sqrt(new_scaled_sum_sqgrad * scale) + eps); + } + + } + + static abstract class AdaDelta extends OptimizerBase { + + private final float decay; + private final float eps; + private final float scale; + + public AdaDelta(Map<String, String> options) { + super(options); + float decay = 0.95f; + float eps = 1e-6f; + float scale = 100.0f; + if (options.containsKey("decay")) { + 
decay = Float.parseFloat(options.get("decay")); + } + if (options.containsKey("eps")) { + eps = Float.parseFloat(options.get("eps")); + } + if (options.containsKey("scale")) { + scale = Float.parseFloat(options.get("scale")); + } + this.decay = decay; + this.eps = eps; + this.scale = scale; + } + + @Override + protected float computeDelta(@Nonnull final IWeightValue weight, float gradient) { + float old_scaled_sum_sqgrad = weight.getSumOfSquaredGradients(); + float old_sum_squared_delta_x = weight.getSumOfSquaredDeltaX(); + float new_scaled_sum_sqgrad = (decay * old_scaled_sum_sqgrad) + + ((1.f - decay) * gradient * (gradient / scale)); + float delta = (float) Math.sqrt((old_sum_squared_delta_x + eps) + / (new_scaled_sum_sqgrad * scale + eps)) + * gradient; + float new_sum_squared_delta_x = (decay * old_sum_squared_delta_x) + + ((1.f - decay) * delta * delta); + weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad); + weight.setSumOfSquaredDeltaX(new_sum_squared_delta_x); + return delta; + } + + } + + /** + * Adam, an algorithm for first-order gradient-based optimization of stochastic objective functions, based on adaptive estimates of lower-order + * moments. + * + * - D. P. Kingma and J. L. Ba: "ADAM: A Method for Stochastic Optimization." arXiv preprint arXiv:1412.6980v8, 2014. 
+ */ + static abstract class Adam extends OptimizerBase { + + private final float beta; + private final float gamma; + private final float eps_hat; + + public Adam(Map<String, String> options) { + super(options); + float beta = 0.9f; + float gamma = 0.999f; + float eps_hat = 1e-8f; + if (options.containsKey("beta")) { + beta = Float.parseFloat(options.get("beta")); + } + if (options.containsKey("gamma")) { + gamma = Float.parseFloat(options.get("gamma")); + } + if (options.containsKey("eps_hat")) { + eps_hat = Float.parseFloat(options.get("eps_hat")); + } + this.beta = beta; + this.gamma = gamma; + this.eps_hat = eps_hat; + } + + @Override + protected float computeDelta(@Nonnull final IWeightValue weight, float gradient) { + float val_m = beta * weight.getM() + (1.f - beta) * gradient; + float val_v = gamma * weight.getV() + (float) ((1.f - gamma) * Math.pow(gradient, 2.0)); + float val_m_hat = val_m / (float) (1.f - Math.pow(beta, _numStep)); + float val_v_hat = val_v / (float) (1.f - Math.pow(gamma, _numStep)); + float delta = val_m_hat / (float) (Math.sqrt(val_v_hat) + eps_hat); + weight.setM(val_m); + weight.setV(val_v); + return delta; + } + + } + + static abstract class AdagradRDA extends OptimizerBase { + + private final OptimizerBase optimizerImpl; + + private final float lambda; + + public AdagradRDA(final OptimizerBase optimizerImpl, Map<String, String> options) { + super(options); + // We assume `optimizerImpl` has the `AdaGrad` implementation only + if (!(optimizerImpl instanceof AdaGrad)) { + throw new IllegalArgumentException(optimizerImpl.getClass().getSimpleName() + + " currently does not support RDA regularization"); + } + float lambda = 1e-6f; + if (options.containsKey("lambda")) { + lambda = Float.parseFloat(options.get("lambda")); + } + this.optimizerImpl = optimizerImpl; + this.lambda = lambda; + } + + @Override + protected float update(@Nonnull final IWeightValue weight, float gradient) { + float new_sum_grad = weight.getSumOfGradients() + 
gradient; + // sign(u_{t,i}) + float sign = (new_sum_grad > 0.f) ? 1.f : -1.f; + // |u_{t,i}|/t - \lambda + float meansOfGradients = (sign * new_sum_grad / _numStep) - lambda; + if (meansOfGradients < 0.f) { + // x_{t,i} = 0 + weight.set(0.f); + weight.setSumOfSquaredGradients(0.f); + weight.setSumOfGradients(0.f); + return 0.f; + } else { + // x_{t,i} = -sign(u_{t,i}) * \frac{\eta t}{\sqrt{G_{t,ii}}}(|u_{t,i}|/t - \lambda) + float newWeight = -1.f * sign * _eta.eta(_numStep) * _numStep + * optimizerImpl.computeDelta(weight, meansOfGradients); + weight.set(newWeight); + weight.setSumOfGradients(new_sum_grad); + return newWeight; + } + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3848ea60/core/src/main/java/hivemall/optimizer/OptimizerOptions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/optimizer/OptimizerOptions.java b/core/src/main/java/hivemall/optimizer/OptimizerOptions.java new file mode 100644 index 0000000..19fecb1 --- /dev/null +++ b/core/src/main/java/hivemall/optimizer/OptimizerOptions.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
*/
package hivemall.optimizer;

import java.util.HashMap;
import java.util.Map;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;

/**
 * Static helpers for the optimizer-related command-line options shared by learners.
 *
 * <p>This utility only defines default settings, registers CLI options, and copies parsed
 * values into a plain {@code Map<String, String>} consumed by optimizer implementations.</p>
 */
public final class OptimizerOptions {

    private OptimizerOptions() {} // utility class; not instantiable

    /**
     * Returns a fresh, mutable map holding the default optimizer configuration
     * (AdaGrad with RDA regularization).
     *
     * @return mutable map of option name to value; callers may add/override entries
     */
    @Nonnull
    public static Map<String, String> create() {
        Map<String, String> opts = new HashMap<String, String>();
        opts.put("optimizer", "adagrad");
        // NOTE(review): value is upper-case "RDA" while the help text below advertises
        // lower-case "rda" — confirm the consumer parses this case-insensitively.
        opts.put("regularization", "RDA");
        return opts;
    }

    /**
     * Registers every optimizer-related option on the given {@link Options} descriptor.
     *
     * @param opts target option descriptor to populate
     */
    public static void setup(@Nonnull Options opts) {
        opts.addOption("opt", "optimizer", true,
            "Optimizer to update weights [default: adagrad, sgd, adadelta, adam]");
        opts.addOption("eps", true, "Denominator value of AdaDelta/AdaGrad [default 1e-6]");
        opts.addOption("rho", "decay", true, "Decay rate of AdaDelta [default 0.95]");
        // regularization
        opts.addOption("reg", "regularization", true,
            "Regularization type [default: rda, l1, l2, elasticnet]");
        opts.addOption("l1_ratio", true,
            "Ratio of L1 regularizer as a part of Elastic Net regularization [default: 0.5]");
        opts.addOption("lambda", true, "Regularization term [default 0.0001]");
        // learning rates
        opts.addOption("eta", true, "Learning rate scheme [default: inverse/inv, fixed, simple]");
        opts.addOption("eta0", true, "The initial learning rate [default 0.1]");
        opts.addOption("t", "total_steps", true, "a total of n_samples * epochs time steps");
        opts.addOption("power_t", true,
            "The exponent for inverse scaling learning rate [default 0.1]");
        // other
        // Fixed help text: was "[100.0]", now consistent with the other defaults above.
        opts.addOption("scale", true, "Scaling factor for cumulative weights [default 100.0]");
    }

    /**
     * Copies every parsed option from {@code cl} into {@code options}, preferring the long
     * option name as the key and falling back to the short name. A {@code null} command line
     * is a no-op.
     *
     * <p>NOTE(review): the method name contains a typo ("propcess"); kept as-is for source
     * compatibility with existing callers.</p>
     *
     * @param cl parsed command line, or {@code null}
     * @param options destination map, mutated in place
     */
    public static void propcessOptions(@Nullable CommandLine cl,
            @Nonnull Map<String, String> options) {
        if (cl == null) {
            return; // nothing parsed; leave defaults untouched
        }
        for (Option opt : cl.getOptions()) {
            String optName = opt.getLongOpt();
            if (optName == null) {
                optName = opt.getOpt(); // option defines only a short name
            }
            options.put(optName, opt.getValue());
        }
    }

}
