// core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashMap.java
// (new file; see http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashMap.java)
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
//
//   Copyright (C) 2010 catchpole.net
//
//   Licensed under the Apache License, Version 2.0 (the "License");
//   you may not use this file except in compliance with the License.
//   You may obtain a copy of the License at
//
//       http://www.apache.org/licenses/LICENSE-2.0
//
//   Unless required by applicable law or agreed to in writing, software
//   distributed under the License is distributed on an "AS IS" BASIS,
//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//   See the License for the specific language governing permissions and
//   limitations under the License.
//
package hivemall.utils.collections.maps;

import hivemall.utils.hashing.HashUtils;
import hivemall.utils.math.MathUtils;

import java.util.Arrays;

import javax.annotation.Nonnull;
import javax.annotation.concurrent.NotThreadSafe;

/**
 * A space efficient open-addressing HashMap implementation with integer keys and long values.
 *
 * Unlike {@link Int2LongOpenHashTable}, it keeps only one key array and one value array and no
 * per-slot state array. A key of {@code 0} in the key array marks an empty slot, so the mapping
 * for key {@code 0} itself is held in dedicated fields.
 *
 * The arrays are sized to binary powers (256, 512 etc) rather than to prime numbers. This allows
 * the hash offset calculation to be a simple binary masking operation.
 *
 * The index into the arrays is determined by masking a portion of the key hash and shifting it to
 * provide a series of small buckets within the array. To insert an entry, a sweep is searched
 * until an empty key slot is found. A sweep is 4 times the length of a bucket, to reduce the need
 * to rehash. If no empty slot is found within a sweep, the table size is doubled.
 *
 * While performance is high, the slowest situation is where lookup occurs for entries that do not
 * exist, as an entire sweep area must be searched. However, this HashMap is more space efficient
 * than other open-addressing HashMap implementations as in fastutil.
 *
 * Note that {@link MapIterator} is destructive: advancing it removes the entry it previously
 * pointed at.
 */
@NotThreadSafe
public final class Int2LongOpenHashMap {

    // special treatment for key=0, because 0 marks an empty slot in keys[]
    private boolean hasKey0 = false;
    private long value0 = 0L;

    // keys[i] == 0 means slot i is empty; values[i] is kept at 0L for empty slots
    private int[] keys;
    private long[] values;

    // total number of entries in this table (including the key=0 entry)
    private int size;
    // number of bits for the value table (eg. 8 bits = 256 entries)
    private int bits;
    // the number of bits in each sweep zone.
    private int sweepbits;
    // the size of a sweep (2 to the power of sweepbits, times 4)
    private int sweep;
    // the sweepmask used to create sweep zone offsets
    private int sweepmask;
    // true once a slot has been emptied in place (remove / iterator) since the arrays were last
    // built or cleared. While false, the first empty slot of a sweep proves a key is absent.
    private boolean hasVacatedSlots = false;

    /**
     * @param size expected number of entries; values below 256 are treated as 256
     */
    public Int2LongOpenHashMap(int size) {
        resize(MathUtils.bitsRequired(size < 256 ? 256 : size));
    }

    /**
     * Associates {@code value} with {@code key}, replacing any existing mapping.
     *
     * @return the previous value, or 0 if there was no mapping for {@code key}
     */
    public long put(final int key, final long value) {
        if (key == 0) {
            if (!hasKey0) {
                this.hasKey0 = true;
                size++;
            }
            long old = value0;
            this.value0 = value;
            return old;
        }

        final int slot = findOrInsertSlot(key);
        final int index = (slot < 0) ? ~slot : slot;
        final long previous = values[index]; // 0L for a freshly claimed slot
        values[index] = value;
        return previous;
    }

    /**
     * Associates {@code value} with {@code key} only if {@code key} has no mapping yet.
     *
     * @return the existing value if {@code key} is already mapped (the map is left unchanged),
     *         otherwise 0
     */
    public long putIfAbsent(final int key, final long value) {
        if (key == 0) {
            if (hasKey0) {
                return value0;
            }
            this.hasKey0 = true;
            long old = value0;
            this.value0 = value;
            size++;
            return old;
        }

        final int slot = findOrInsertSlot(key);
        if (slot >= 0) {// already mapped
            return values[slot];
        }
        final int index = ~slot;
        final long previous = values[index]; // 0L for a freshly claimed slot
        values[index] = value;
        return previous;
    }

    /**
     * @return the value mapped to {@code key}, or 0 if there is no mapping
     */
    public long get(final int key) {
        return get(key, 0L);
    }

    /**
     * @return the value mapped to {@code key}, or {@code defaultValue} if there is no mapping
     */
    public long get(final int key, final long defaultValue) {
        if (key == 0) {
            return hasKey0 ? value0 : defaultValue;
        }

        final int index = indexOf(key);
        return (index >= 0) ? values[index] : defaultValue;
    }

    /**
     * Removes the mapping for {@code key}.
     *
     * @return the removed value, or {@code defaultValue} if there was no mapping
     */
    public long remove(final int key, final long defaultValue) {
        if (key == 0) {
            if (hasKey0) {
                this.hasKey0 = false;
                long old = value0;
                this.value0 = 0L;
                size--;
                return old;
            } else {
                return defaultValue;
            }
        }

        final int index = indexOf(key);
        if (index < 0) {
            return defaultValue;
        }
        keys[index] = 0;
        final long previous = values[index];
        values[index] = 0L;
        size--;
        // later inserts must scan the whole sweep before reusing an empty slot
        this.hasVacatedSlots = true;
        return previous;
    }

    public int size() {
        return size;
    }

    public boolean isEmpty() {
        return size == 0;
    }

    public boolean containsKey(final int key) {
        if (key == 0) {
            return hasKey0;
        }
        return indexOf(key) >= 0;
    }

    /** Removes all mappings while keeping the current capacity. */
    public void clear() {
        this.hasKey0 = false;
        this.value0 = 0L;
        Arrays.fill(keys, 0);
        Arrays.fill(values, 0L);
        this.size = 0;
        this.hasVacatedSlots = false;
    }

    @Override
    public String toString() {
        return this.getClass().getSimpleName() + ' ' + size;
    }

    /**
     * Searches the sweep of a non-zero {@code key}.
     *
     * @return the slot index holding {@code key}, or -1 if it is not stored
     */
    private int indexOf(final int key) {
        int off = getBucketOffset(key);
        final int end = off + sweep;
        for (; off < end; off++) {
            if (keys[off] == key) {
                return off;
            }
        }
        return -1;
    }

    /**
     * Locates the slot of a non-zero {@code key}, claiming an empty slot (and growing the table
     * when the sweep is full) if the key is not stored yet.
     *
     * @return the slot index ({@code >= 0}) when the key was already stored; otherwise the
     *         bitwise complement ({@code ~index}, always negative) of the newly claimed slot
     */
    private int findOrInsertSlot(final int key) {
        for (;;) {
            final int start = getBucketOffset(key);
            final int end = start + sweep;
            int vacant = -1;
            for (int off = start; off < end; off++) {
                final int searchKey = keys[off];
                if (searchKey == key) {// existing mapping
                    return off;
                }
                if (searchKey == 0) {
                    if (!hasVacatedSlots) {
                        // No slot was ever emptied in place, so every stored key sits before
                        // the first empty slot of its sweep: the key cannot be further on.
                        vacant = off;
                        break;
                    }
                    // A slot may have been vacated in front of the key; remember the first
                    // empty slot but keep looking so that no duplicate entry is created.
                    if (vacant == -1) {
                        vacant = off;
                    }
                }
            }
            if (vacant != -1) {// insert
                keys[vacant] = key;
                size++;
                return ~vacant;
            }
            // sweep is full: double the table and retry
            resize(this.bits + 1);
        }
    }

    /**
     * (Re)builds the arrays for a table of {@code 2^bits} slots plus one trailing sweep, and
     * re-inserts the entries of the previous arrays, if any.
     */
    private void resize(final int bits) {
        this.bits = bits;
        this.sweepbits = bits / 4;
        this.sweep = MathUtils.powerOf(2, sweepbits) * 4;
        this.sweepmask = MathUtils.bitMask(bits - sweepbits) << sweepbits;

        // remember old values so we can recreate the entries
        final int[] existingKeys = this.keys;
        final long[] existingValues = this.values;

        // create the arrays; the extra sweep lets the last bucket be scanned without wrapping
        this.values = new long[MathUtils.powerOf(2, bits) + sweep];
        this.keys = new int[values.length];
        this.size = hasKey0 ? 1 : 0;
        this.hasVacatedSlots = false; // fresh arrays contain no in-place removals

        // re-add the previous entries if resizing
        if (existingKeys != null) {
            for (int i = 0; i < existingKeys.length; i++) {
                final int k = existingKeys[i];
                if (k != 0) {
                    put(k, existingValues[i]);
                }
            }
        }
    }

    /** @return the first slot index of the sweep to be searched for {@code key} */
    private int getBucketOffset(final int key) {
        return (HashUtils.fnv1a(key) << sweepbits) & sweepmask;
    }

    @Nonnull
    public MapIterator entries() {
        return new MapIterator();
    }

    /**
     * A cursor over the entries of the enclosing map.
     *
     * Call {@link #next()} to advance, then {@link #getKey()} / {@link #getValue()} to read the
     * current entry. The iteration is destructive: each call to {@link #next()} removes the entry
     * the cursor pointed at before the call, so a fully consumed iterator leaves the map without
     * the visited entries.
     */
    public final class MapIterator {

        // index of the entry to be returned next; -1 stands for the key=0 entry
        int nextEntry;
        // index of the current entry; -1 stands for the key=0 entry, -2 for "next() not yet called"
        int lastEntry = -2;

        MapIterator() {
            this.nextEntry = nextEntry(-1);
        }

        /** find the index of next full entry */
        int nextEntry(int index) {
            if (index == -1) {
                if (hasKey0) {
                    return -1;
                } else {
                    index = 0;
                }
            }
            while (index < keys.length && keys[index] == 0) {
                index++;
            }
            return index;
        }

        public boolean hasNext() {
            return nextEntry < keys.length;
        }

        /**
         * Removes the current entry (if any) from the map and advances to the next one.
         *
         * @return false if there are no more entries
         */
        public boolean next() {
            free(lastEntry);
            if (!hasNext()) {
                return false;
            }
            int curEntry = nextEntry;
            this.lastEntry = curEntry;
            this.nextEntry = nextEntry(curEntry + 1);
            return true;
        }

        public int getKey() {
            if (lastEntry >= 0 && lastEntry < keys.length) {
                return keys[lastEntry];
            } else if (lastEntry == -1) {
                return 0;
            } else {
                throw new IllegalStateException(
                    "next() should be called before getKey(). lastEntry=" + lastEntry
                            + ", keys.length=" + keys.length);
            }
        }

        public long getValue() {
            if (lastEntry >= 0 && lastEntry < keys.length) {
                return values[lastEntry];
            } else if (lastEntry == -1) {
                return value0;
            } else {
                throw new IllegalStateException(
                    "next() should be called before getValue(). lastEntry=" + lastEntry
                            + ", keys.length=" + keys.length);
            }
        }

        /** Removes the entry at {@code index} from the map, keeping {@code size} in sync. */
        private void free(int index) {
            if (index >= 0) {
                if (index >= keys.length) {
                    throw new IllegalStateException("index=" + index + ", keys.length="
                            + keys.length);
                }
                if (keys[index] != 0) {// guard: free() may run more than once for an index
                    keys[index] = 0;
                    size--;
                    hasVacatedSlots = true;
                }
                values[index] = 0L;
            } else if (index == -1) {
                if (hasKey0) {// guard: free() may run more than once for key=0
                    hasKey0 = false;
                    size--;
                }
                value0 = 0L;
            }
            // index may be -2
        }

    }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java index 68eb42f..22acdb4 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java @@ -33,7 +33,12 @@ import java.util.Arrays; import javax.annotation.Nonnull; /** - * An open-addressing hash table with double hashing + * An open-addressing hash table using double hashing. + * + * <pre> + * Primary hash function: h1(k) = k mod m + * Secondary hash function: h2(k) = 1 + (k mod(m-2)) + * </pre> * * @see http://en.wikipedia.org/wiki/Double_hashing */ @@ -44,7 +49,7 @@ public class Int2LongOpenHashTable implements Externalizable { protected static final byte REMOVED = 2; public static final int DEFAULT_SIZE = 65536; - public static final float DEFAULT_LOAD_FACTOR = 0.7f; + public static final float DEFAULT_LOAD_FACTOR = 0.75f; public static final float DEFAULT_GROW_FACTOR = 2.0f; protected final transient float _loadFactor; @@ -123,23 +128,23 @@ public class Int2LongOpenHashTable implements Externalizable { return _states; } - public boolean containsKey(int key) { + public boolean containsKey(final int key) { return findKey(key) >= 0; } /** * @return -1.f if not found */ - public long get(int key) { - int i = findKey(key); + public long get(final int key) { + final int i = findKey(key); if (i < 0) { return defaultReturnValue; } return _values[i]; } - public long put(int key, long value) { - int hash = keyHash(key); + public long put(final int key, final long value) { + final int hash = keyHash(key); int keyLength = _keys.length; int keyIdx = hash % keyLength; @@ -149,9 
+154,9 @@ public class Int2LongOpenHashTable implements Externalizable { keyIdx = hash % keyLength; } - int[] keys = _keys; - long[] values = _values; - byte[] states = _states; + final int[] keys = _keys; + final long[] values = _values; + final byte[] states = _states; if (states[keyIdx] == FULL) {// double hashing if (keys[keyIdx] == key) { @@ -160,7 +165,7 @@ public class Int2LongOpenHashTable implements Externalizable { return old; } // try second hash - int decr = 1 + (hash % (keyLength - 2)); + final int decr = 1 + (hash % (keyLength - 2)); for (;;) { keyIdx -= decr; if (keyIdx < 0) { @@ -184,8 +189,8 @@ public class Int2LongOpenHashTable implements Externalizable { } /** Return weather the required slot is free for new entry */ - protected boolean isFree(int index, int key) { - byte stat = _states[index]; + protected boolean isFree(final int index, final int key) { + final byte stat = _states[index]; if (stat == FREE) { return true; } @@ -196,7 +201,7 @@ public class Int2LongOpenHashTable implements Externalizable { } /** @return expanded or not */ - protected boolean preAddEntry(int index) { + protected boolean preAddEntry(final int index) { if ((_used + 1) >= _threshold) {// too filled int newCapacity = Math.round(_keys.length * _growFactor); ensureCapacity(newCapacity); @@ -205,19 +210,19 @@ public class Int2LongOpenHashTable implements Externalizable { return false; } - protected int findKey(int key) { - int[] keys = _keys; - byte[] states = _states; - int keyLength = keys.length; + protected int findKey(final int key) { + final int[] keys = _keys; + final byte[] states = _states; + final int keyLength = keys.length; - int hash = keyHash(key); + final int hash = keyHash(key); int keyIdx = hash % keyLength; if (states[keyIdx] != FREE) { if (states[keyIdx] == FULL && keys[keyIdx] == key) { return keyIdx; } // try second hash - int decr = 1 + (hash % (keyLength - 2)); + final int decr = 1 + (hash % (keyLength - 2)); for (;;) { keyIdx -= decr; if (keyIdx < 
0) { @@ -234,13 +239,13 @@ public class Int2LongOpenHashTable implements Externalizable { return -1; } - public long remove(int key) { - int[] keys = _keys; - long[] values = _values; - byte[] states = _states; - int keyLength = keys.length; + public long remove(final int key) { + final int[] keys = _keys; + final long[] values = _values; + final byte[] states = _states; + final int keyLength = keys.length; - int hash = keyHash(key); + final int hash = keyHash(key); int keyIdx = hash % keyLength; if (states[keyIdx] != FREE) { if (states[keyIdx] == FULL && keys[keyIdx] == key) { @@ -250,7 +255,7 @@ public class Int2LongOpenHashTable implements Externalizable { return old; } // second hash - int decr = 1 + (hash % (keyLength - 2)); + final int decr = 1 + (hash % (keyLength - 2)); for (;;) { keyIdx -= decr; if (keyIdx < 0) { @@ -283,21 +288,22 @@ public class Int2LongOpenHashTable implements Externalizable { this._used = 0; } - public IMapIterator entries() { + @Nonnull + public MapIterator entries() { return new MapIterator(); } @Override public String toString() { int len = size() * 10 + 2; - StringBuilder buf = new StringBuilder(len); + final StringBuilder buf = new StringBuilder(len); buf.append('{'); - IMapIterator i = entries(); - while (i.next() != -1) { - buf.append(i.getKey()); + final MapIterator itor = entries(); + while (itor.next() != -1) { + buf.append(itor.getKey()); buf.append('='); - buf.append(i.getValue()); - if (i.hasNext()) { + buf.append(itor.getValue()); + if (itor.hasNext()) { buf.append(','); } } @@ -305,30 +311,30 @@ public class Int2LongOpenHashTable implements Externalizable { return buf.toString(); } - protected void ensureCapacity(int newCapacity) { + protected void ensureCapacity(final int newCapacity) { int prime = Primes.findLeastPrimeNumber(newCapacity); rehash(prime); this._threshold = Math.round(prime * _loadFactor); } - private void rehash(int newCapacity) { + private void rehash(final int newCapacity) { int oldCapacity = 
_keys.length; if (newCapacity <= oldCapacity) { throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity); } - int[] newkeys = new int[newCapacity]; - long[] newValues = new long[newCapacity]; - byte[] newStates = new byte[newCapacity]; + final int[] newkeys = new int[newCapacity]; + final long[] newValues = new long[newCapacity]; + final byte[] newStates = new byte[newCapacity]; int used = 0; for (int i = 0; i < oldCapacity; i++) { if (_states[i] == FULL) { used++; - int k = _keys[i]; - long v = _values[i]; - int hash = keyHash(k); + final int k = _keys[i]; + final long v = _values[i]; + final int hash = keyHash(k); int keyIdx = hash % newCapacity; if (newStates[keyIdx] == FULL) {// second hashing - int decr = 1 + (hash % (newCapacity - 2)); + final int decr = 1 + (hash % (newCapacity - 2)); while (newStates[keyIdx] != FREE) { keyIdx -= decr; if (keyIdx < 0) { @@ -347,7 +353,7 @@ public class Int2LongOpenHashTable implements Externalizable { this._used = used; } - private static int keyHash(int key) { + private static int keyHash(final int key) { return key & 0x7fffffff; } @@ -437,22 +443,7 @@ public class Int2LongOpenHashTable implements Externalizable { } } - public interface IMapIterator { - - public boolean hasNext(); - - /** - * @return -1 if not found - */ - public int next(); - - public int getKey(); - - public long getValue(); - - } - - private final class MapIterator implements IMapIterator { + public final class MapIterator { int nextEntry; int lastEntry = -1; @@ -473,6 +464,9 @@ public class Int2LongOpenHashTable implements Externalizable { return nextEntry < _keys.length; } + /** + * @return -1 if not found + */ public int next() { if (!hasNext()) { return -1; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java ---------------------------------------------------------------------- diff --git 
a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java deleted file mode 100644 index 5ce34a4..0000000 --- a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java +++ /dev/null @@ -1,467 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.utils.collections.maps; - -import hivemall.utils.math.Primes; - -import java.io.Externalizable; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import java.util.Arrays; - -/** - * An open-addressing hash table with double hashing - * - * @see http://en.wikipedia.org/wiki/Double_hashing - */ -public class IntOpenHashMap<V> implements Externalizable { - private static final long serialVersionUID = -8162355845665353513L; - - protected static final byte FREE = 0; - protected static final byte FULL = 1; - protected static final byte REMOVED = 2; - - private static final float DEFAULT_LOAD_FACTOR = 0.7f; - private static final float DEFAULT_GROW_FACTOR = 2.0f; - - protected final transient float _loadFactor; - protected final transient float _growFactor; - - protected int _used = 0; - protected int _threshold; - - protected int[] _keys; - protected V[] _values; - protected byte[] _states; - - @SuppressWarnings("unchecked") - protected IntOpenHashMap(int size, float loadFactor, float growFactor, boolean forcePrime) { - if (size < 1) { - throw new IllegalArgumentException(); - } - this._loadFactor = loadFactor; - this._growFactor = growFactor; - int actualSize = forcePrime ? 
Primes.findLeastPrimeNumber(size) : size; - this._keys = new int[actualSize]; - this._values = (V[]) new Object[actualSize]; - this._states = new byte[actualSize]; - this._threshold = Math.round(actualSize * _loadFactor); - } - - public IntOpenHashMap(int size, float loadFactor, float growFactor) { - this(size, loadFactor, growFactor, true); - } - - public IntOpenHashMap(int size) { - this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true); - } - - public IntOpenHashMap() {// required for serialization - this._loadFactor = DEFAULT_LOAD_FACTOR; - this._growFactor = DEFAULT_GROW_FACTOR; - } - - public boolean containsKey(int key) { - return findKey(key) >= 0; - } - - public final V get(final int key) { - final int i = findKey(key); - if (i < 0) { - return null; - } - recordAccess(i); - return _values[i]; - } - - public V put(final int key, final V value) { - final int hash = keyHash(key); - int keyLength = _keys.length; - int keyIdx = hash % keyLength; - - final boolean expanded = preAddEntry(keyIdx); - if (expanded) { - keyLength = _keys.length; - keyIdx = hash % keyLength; - } - - final int[] keys = _keys; - final V[] values = _values; - final byte[] states = _states; - - if (states[keyIdx] == FULL) {// double hashing - if (keys[keyIdx] == key) { - V old = values[keyIdx]; - values[keyIdx] = value; - recordAccess(keyIdx); - return old; - } - // try second hash - final int decr = 1 + (hash % (keyLength - 2)); - for (;;) { - keyIdx -= decr; - if (keyIdx < 0) { - keyIdx += keyLength; - } - if (isFree(keyIdx, key)) { - break; - } - if (states[keyIdx] == FULL && keys[keyIdx] == key) { - V old = values[keyIdx]; - values[keyIdx] = value; - recordAccess(keyIdx); - return old; - } - } - } - keys[keyIdx] = key; - values[keyIdx] = value; - states[keyIdx] = FULL; - ++_used; - postAddEntry(keyIdx); - return null; - } - - public V putIfAbsent(final int key, final V value) { - final int hash = keyHash(key); - int keyLength = _keys.length; - int keyIdx = hash % keyLength; - - 
final boolean expanded = preAddEntry(keyIdx); - if (expanded) { - keyLength = _keys.length; - keyIdx = hash % keyLength; - } - - final int[] keys = _keys; - final V[] values = _values; - final byte[] states = _states; - - if (states[keyIdx] == FULL) {// second hashing - if (keys[keyIdx] == key) { - return values[keyIdx]; - } - // try second hash - final int decr = 1 + (hash % (keyLength - 2)); - for (;;) { - keyIdx -= decr; - if (keyIdx < 0) { - keyIdx += keyLength; - } - if (isFree(keyIdx, key)) { - break; - } - if (states[keyIdx] == FULL && keys[keyIdx] == key) { - return values[keyIdx]; - } - } - } - keys[keyIdx] = key; - values[keyIdx] = value; - states[keyIdx] = FULL; - _used++; - postAddEntry(keyIdx); - return null; - } - - /** Return weather the required slot is free for new entry */ - protected boolean isFree(int index, int key) { - byte stat = _states[index]; - if (stat == FREE) { - return true; - } - if (stat == REMOVED && _keys[index] == key) { - return true; - } - return false; - } - - /** @return expanded or not */ - protected boolean preAddEntry(int index) { - if ((_used + 1) >= _threshold) {// too filled - int newCapacity = Math.round(_keys.length * _growFactor); - ensureCapacity(newCapacity); - return true; - } - return false; - } - - protected void postAddEntry(int index) {} - - private int findKey(int key) { - int[] keys = _keys; - byte[] states = _states; - int keyLength = keys.length; - - int hash = keyHash(key); - int keyIdx = hash % keyLength; - if (states[keyIdx] != FREE) { - if (states[keyIdx] == FULL && keys[keyIdx] == key) { - return keyIdx; - } - // try second hash - int decr = 1 + (hash % (keyLength - 2)); - for (;;) { - keyIdx -= decr; - if (keyIdx < 0) { - keyIdx += keyLength; - } - if (isFree(keyIdx, key)) { - return -1; - } - if (states[keyIdx] == FULL && keys[keyIdx] == key) { - return keyIdx; - } - } - } - return -1; - } - - public V remove(int key) { - int[] keys = _keys; - V[] values = _values; - byte[] states = _states; - int 
keyLength = keys.length; - - int hash = keyHash(key); - int keyIdx = hash % keyLength; - if (states[keyIdx] != FREE) { - if (states[keyIdx] == FULL && keys[keyIdx] == key) { - V old = values[keyIdx]; - states[keyIdx] = REMOVED; - --_used; - recordRemoval(keyIdx); - return old; - } - // second hash - int decr = 1 + (hash % (keyLength - 2)); - for (;;) { - keyIdx -= decr; - if (keyIdx < 0) { - keyIdx += keyLength; - } - if (states[keyIdx] == FREE) { - return null; - } - if (states[keyIdx] == FULL && keys[keyIdx] == key) { - V old = values[keyIdx]; - states[keyIdx] = REMOVED; - --_used; - recordRemoval(keyIdx); - return old; - } - } - } - return null; - } - - public int size() { - return _used; - } - - public void clear() { - Arrays.fill(_states, FREE); - this._used = 0; - } - - @SuppressWarnings("unchecked") - public IMapIterator<V> entries() { - return new MapIterator(); - } - - @Override - public String toString() { - int len = size() * 10 + 2; - StringBuilder buf = new StringBuilder(len); - buf.append('{'); - IMapIterator<V> i = entries(); - while (i.next() != -1) { - buf.append(i.getKey()); - buf.append('='); - buf.append(i.getValue()); - if (i.hasNext()) { - buf.append(','); - } - } - buf.append('}'); - return buf.toString(); - } - - private void ensureCapacity(int newCapacity) { - int prime = Primes.findLeastPrimeNumber(newCapacity); - rehash(prime); - this._threshold = Math.round(prime * _loadFactor); - } - - @SuppressWarnings("unchecked") - protected void rehash(int newCapacity) { - int oldCapacity = _keys.length; - if (newCapacity <= oldCapacity) { - throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity); - } - final int[] oldKeys = _keys; - final V[] oldValues = _values; - final byte[] oldStates = _states; - int[] newkeys = new int[newCapacity]; - V[] newValues = (V[]) new Object[newCapacity]; - byte[] newStates = new byte[newCapacity]; - int used = 0; - for (int i = 0; i < oldCapacity; i++) { - if (oldStates[i] == FULL) { - 
used++; - int k = oldKeys[i]; - V v = oldValues[i]; - int hash = keyHash(k); - int keyIdx = hash % newCapacity; - if (newStates[keyIdx] == FULL) {// second hashing - int decr = 1 + (hash % (newCapacity - 2)); - while (newStates[keyIdx] != FREE) { - keyIdx -= decr; - if (keyIdx < 0) { - keyIdx += newCapacity; - } - } - } - newkeys[keyIdx] = k; - newValues[keyIdx] = v; - newStates[keyIdx] = FULL; - } - } - this._keys = newkeys; - this._values = newValues; - this._states = newStates; - this._used = used; - } - - private static int keyHash(int key) { - return key & 0x7fffffff; - } - - protected void recordAccess(int idx) {} - - protected void recordRemoval(int idx) {} - - public void writeExternal(ObjectOutput out) throws IOException { - out.writeInt(_threshold); - out.writeInt(_used); - - out.writeInt(_keys.length); - IMapIterator<V> i = entries(); - while (i.next() != -1) { - out.writeInt(i.getKey()); - out.writeObject(i.getValue()); - } - } - - @SuppressWarnings("unchecked") - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - this._threshold = in.readInt(); - this._used = in.readInt(); - - int keylen = in.readInt(); - int[] keys = new int[keylen]; - V[] values = (V[]) new Object[keylen]; - byte[] states = new byte[keylen]; - for (int i = 0; i < _used; i++) { - int k = in.readInt(); - V v = (V) in.readObject(); - int hash = keyHash(k); - int keyIdx = hash % keylen; - if (states[keyIdx] != FREE) {// second hash - int decr = 1 + (hash % (keylen - 2)); - for (;;) { - keyIdx -= decr; - if (keyIdx < 0) { - keyIdx += keylen; - } - if (states[keyIdx] == FREE) { - break; - } - } - } - states[keyIdx] = FULL; - keys[keyIdx] = k; - values[keyIdx] = v; - } - this._keys = keys; - this._values = values; - this._states = states; - } - - public interface IMapIterator<V> { - - public boolean hasNext(); - - public int next(); - - public int getKey(); - - public V getValue(); - - } - - @SuppressWarnings("rawtypes") - private final class MapIterator 
implements IMapIterator { - - int nextEntry; - int lastEntry = -1; - - MapIterator() { - this.nextEntry = nextEntry(0); - } - - /** find the index of next full entry */ - int nextEntry(int index) { - while (index < _keys.length && _states[index] != FULL) { - index++; - } - return index; - } - - public boolean hasNext() { - return nextEntry < _keys.length; - } - - public int next() { - if (!hasNext()) { - return -1; - } - int curEntry = nextEntry; - this.lastEntry = curEntry; - this.nextEntry = nextEntry(curEntry + 1); - return curEntry; - } - - public int getKey() { - if (lastEntry == -1) { - throw new IllegalStateException(); - } - return _keys[lastEntry]; - } - - public V getValue() { - if (lastEntry == -1) { - throw new IllegalStateException(); - } - return _values[lastEntry]; - } - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java index dcb64d1..dbade74 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java @@ -25,54 +25,68 @@ import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.Arrays; -import java.util.HashMap; import javax.annotation.Nonnull; /** - * An open-addressing hash table with double-hashing that requires less memory to {@link HashMap}. + * An open-addressing hash table using double hashing. 
+ * + * <pre> + * Primary hash function: h1(k) = k mod m + * Secondary hash function: h2(k) = 1 + (k mod(m-2)) + * </pre> + * + * @see http://en.wikipedia.org/wiki/Double_hashing */ public final class IntOpenHashTable<V> implements Externalizable { + private static final long serialVersionUID = -8162355845665353513L; - public static final float DEFAULT_LOAD_FACTOR = 0.7f; + public static final float DEFAULT_LOAD_FACTOR = 0.75f; public static final float DEFAULT_GROW_FACTOR = 2.0f; - public static final byte FREE = 0; - public static final byte FULL = 1; - public static final byte REMOVED = 2; + protected static final byte FREE = 0; + protected static final byte FULL = 1; + protected static final byte REMOVED = 2; protected/* final */float _loadFactor; protected/* final */float _growFactor; - protected int _used = 0; + protected int _used; protected int _threshold; protected int[] _keys; protected V[] _values; protected byte[] _states; - public IntOpenHashTable() {} // for Externalizable + public IntOpenHashTable() {} // for Externalizable public IntOpenHashTable(int size) { - this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR); + this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true); } - @SuppressWarnings("unchecked") public IntOpenHashTable(int size, float loadFactor, float growFactor) { + this(size, loadFactor, growFactor, true); + } + + @SuppressWarnings("unchecked") + protected IntOpenHashTable(int size, float loadFactor, float growFactor, boolean forcePrime) { if (size < 1) { throw new IllegalArgumentException(); } this._loadFactor = loadFactor; this._growFactor = growFactor; - int actualSize = Primes.findLeastPrimeNumber(size); + this._used = 0; + int actualSize = forcePrime ? 
Primes.findLeastPrimeNumber(size) : size; + this._threshold = Math.round(actualSize * _loadFactor); this._keys = new int[actualSize]; this._values = (V[]) new Object[actualSize]; this._states = new byte[actualSize]; - this._threshold = Math.round(actualSize * _loadFactor); } public IntOpenHashTable(@Nonnull int[] keys, @Nonnull V[] values, @Nonnull byte[] states, int used) { + this._loadFactor = DEFAULT_LOAD_FACTOR; + this._growFactor = DEFAULT_GROW_FACTOR; this._used = used; this._threshold = keys.length; this._keys = keys; @@ -80,14 +94,17 @@ public final class IntOpenHashTable<V> implements Externalizable { this._states = states; } + @Nonnull public int[] getKeys() { return _keys; } + @Nonnull public Object[] getValues() { return _values; } + @Nonnull public byte[] getStates() { return _states; } @@ -109,7 +126,7 @@ public final class IntOpenHashTable<V> implements Externalizable { int keyLength = _keys.length; int keyIdx = hash % keyLength; - boolean expanded = preAddEntry(keyIdx); + final boolean expanded = preAddEntry(keyIdx); if (expanded) { keyLength = _keys.length; keyIdx = hash % keyLength; @@ -119,14 +136,14 @@ public final class IntOpenHashTable<V> implements Externalizable { final V[] values = _values; final byte[] states = _states; - if (states[keyIdx] == FULL) { + if (states[keyIdx] == FULL) {// double hashing if (keys[keyIdx] == key) { V old = values[keyIdx]; values[keyIdx] = value; return old; } // try second hash - int decr = 1 + (hash % (keyLength - 2)); + final int decr = 1 + (hash % (keyLength - 2)); for (;;) { keyIdx -= decr; if (keyIdx < 0) { @@ -149,10 +166,50 @@ public final class IntOpenHashTable<V> implements Externalizable { return null; } + public V putIfAbsent(final int key, final V value) { + final int hash = keyHash(key); + int keyLength = _keys.length; + int keyIdx = hash % keyLength; + + final boolean expanded = preAddEntry(keyIdx); + if (expanded) { + keyLength = _keys.length; + keyIdx = hash % keyLength; + } + + final int[] keys 
= _keys; + final V[] values = _values; + final byte[] states = _states; + + if (states[keyIdx] == FULL) {// second hashing + if (keys[keyIdx] == key) { + return values[keyIdx]; + } + // try second hash + final int decr = 1 + (hash % (keyLength - 2)); + for (;;) { + keyIdx -= decr; + if (keyIdx < 0) { + keyIdx += keyLength; + } + if (isFree(keyIdx, key)) { + break; + } + if (states[keyIdx] == FULL && keys[keyIdx] == key) { + return values[keyIdx]; + } + } + } + keys[keyIdx] = key; + values[keyIdx] = value; + states[keyIdx] = FULL; + _used++; + return null; + } /** Return weather the required slot is free for new entry */ - protected boolean isFree(int index, int key) { - byte stat = _states[index]; + protected boolean isFree(final int index, final int key) { + final byte stat = _states[index]; if (stat == FREE) { return true; } @@ -163,8 +220,8 @@ public final class IntOpenHashTable<V> implements Externalizable { } /** @return expanded or not */ - protected boolean preAddEntry(int index) { - if ((_used + 1) >= _threshold) {// filled enough + protected boolean preAddEntry(final int index) { + if ((_used + 1) >= _threshold) {// too filled int newCapacity = Math.round(_keys.length * _growFactor); ensureCapacity(newCapacity); return true; @@ -172,7 +229,7 @@ public final class IntOpenHashTable<V> implements Externalizable { return false; } - protected int findKey(final int key) { + private int findKey(final int key) { final int[] keys = _keys; final byte[] states = _states; final int keyLength = keys.length; @@ -184,7 +241,7 @@ public final class IntOpenHashTable<V> implements Externalizable { return keyIdx; } // try second hash - int decr = 1 + (hash % (keyLength - 2)); + final int decr = 1 + (hash % (keyLength - 2)); for (;;) { keyIdx -= decr; if (keyIdx < 0) { @@ -217,7 +274,7 @@ public final class IntOpenHashTable<V> implements Externalizable { return old; } // second hash - int decr = 1 + (hash % (keyLength - 2)); + final int decr = 1 + (hash % (keyLength - 2)); 
for (;;) { keyIdx -= decr; if (keyIdx < 0) { @@ -255,28 +312,49 @@ public final class IntOpenHashTable<V> implements Externalizable { this._used = 0; } - protected void ensureCapacity(int newCapacity) { + @Override + public String toString() { + int len = size() * 10 + 2; + final StringBuilder buf = new StringBuilder(len); + buf.append('{'); + final IMapIterator<V> i = entries(); + while (i.next() != -1) { + buf.append(i.getKey()); + buf.append('='); + buf.append(i.getValue()); + if (i.hasNext()) { + buf.append(','); + } + } + buf.append('}'); + return buf.toString(); + } + + private void ensureCapacity(final int newCapacity) { int prime = Primes.findLeastPrimeNumber(newCapacity); rehash(prime); this._threshold = Math.round(prime * _loadFactor); } @SuppressWarnings("unchecked") - private void rehash(int newCapacity) { + private void rehash(final int newCapacity) { int oldCapacity = _keys.length; if (newCapacity <= oldCapacity) { throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity); } + final int[] oldKeys = _keys; + final V[] oldValues = _values; + final byte[] oldStates = _states; final int[] newkeys = new int[newCapacity]; final V[] newValues = (V[]) new Object[newCapacity]; final byte[] newStates = new byte[newCapacity]; int used = 0; for (int i = 0; i < oldCapacity; i++) { - if (_states[i] == FULL) { + if (oldStates[i] == FULL) { used++; - int k = _keys[i]; - V v = _values[i]; - int hash = keyHash(k); + final int k = oldKeys[i]; + final V v = oldValues[i]; + final int hash = keyHash(k); int keyIdx = hash % newCapacity; if (newStates[keyIdx] == FULL) {// second hashing int decr = 1 + (hash % (newCapacity - 2)); @@ -287,9 +365,9 @@ public final class IntOpenHashTable<V> implements Externalizable { } } } - newStates[keyIdx] = FULL; newkeys[keyIdx] = k; newValues[keyIdx] = v; + newStates[keyIdx] = FULL; } } this._keys = newkeys; @@ -303,7 +381,7 @@ public final class IntOpenHashTable<V> implements Externalizable { } @Override - public 
void writeExternal(ObjectOutput out) throws IOException { + public void writeExternal(@Nonnull final ObjectOutput out) throws IOException { out.writeFloat(_loadFactor); out.writeFloat(_growFactor); out.writeInt(_used); @@ -319,8 +397,8 @@ public final class IntOpenHashTable<V> implements Externalizable { } @SuppressWarnings("unchecked") - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + public void readExternal(@Nonnull final ObjectInput in) throws IOException, + ClassNotFoundException { this._loadFactor = in.readFloat(); this._growFactor = in.readFloat(); this._used = in.readInt(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java index c758824..b4356ff 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java @@ -27,7 +27,12 @@ import java.io.ObjectOutput; import java.util.Arrays; /** - * An open-addressing hash table with double hashing + * An open-addressing hash table using double hashing. 
+ * + * <pre> + * Primary hash function: h1(k) = k mod m + * Secondary hash function: h2(k) = 1 + (k mod(m-2)) + * </pre> * * @see http://en.wikipedia.org/wiki/Double_hashing */ @@ -37,7 +42,7 @@ public final class Long2DoubleOpenHashTable implements Externalizable { protected static final byte FULL = 1; protected static final byte REMOVED = 2; - private static final float DEFAULT_LOAD_FACTOR = 0.7f; + private static final float DEFAULT_LOAD_FACTOR = 0.75f; private static final float DEFAULT_GROW_FACTOR = 2.0f; protected final transient float _loadFactor; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java index 6a7f39f..6b0ab59 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java @@ -27,9 +27,14 @@ import java.io.ObjectOutput; import java.util.Arrays; /** - * An open-addressing hash table with float hashing + * An open-addressing hash table using double hashing. 
+ * + * <pre> + * Primary hash function: h1(k) = k mod m + * Secondary hash function: h2(k) = 1 + (k mod(m-2)) + * </pre> * - * @see http://en.wikipedia.org/wiki/float_hashing + * @see http://en.wikipedia.org/wiki/Double_hashing */ public final class Long2FloatOpenHashTable implements Externalizable { @@ -37,7 +42,7 @@ public final class Long2FloatOpenHashTable implements Externalizable { protected static final byte FULL = 1; protected static final byte REMOVED = 2; - private static final float DEFAULT_LOAD_FACTOR = 0.7f; + private static final float DEFAULT_LOAD_FACTOR = 0.75f; private static final float DEFAULT_GROW_FACTOR = 2.0f; protected final transient float _loadFactor; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java index 51b8f12..1ca4c40 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java @@ -27,7 +27,12 @@ import java.io.ObjectOutput; import java.util.Arrays; /** - * An open-addressing hash table with double hashing + * An open-addressing hash table using double hashing. 
+ * + * <pre> + * Primary hash function: h1(k) = k mod m + * Secondary hash function: h2(k) = 1 + (k mod(m-2)) + * </pre> * * @see http://en.wikipedia.org/wiki/Double_hashing */ @@ -37,7 +42,7 @@ public final class Long2IntOpenHashTable implements Externalizable { protected static final byte FULL = 1; protected static final byte REMOVED = 2; - private static final float DEFAULT_LOAD_FACTOR = 0.7f; + private static final float DEFAULT_LOAD_FACTOR = 0.75f; private static final float DEFAULT_GROW_FACTOR = 2.0f; protected final transient float _loadFactor; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java b/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java index 152447a..f5ee1e6 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java +++ b/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java @@ -48,16 +48,29 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; +import javax.annotation.CheckForNull; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + /** - * An optimized Hashed Map implementation. - * <p/> - * <p> - * This Hashmap does not allow nulls to be used as keys or values. - * <p/> - * <p> + * A space efficient open-addressing HashMap implementation. + * + * Unlike {@link OpenHashTable}, it maintains single arrays for keys and object references. + * * It uses single open hashing arrays sized to binary powers (256, 512 etc) rather than those - * divisable by prime numbers. This allows the hash offset calculation to be a simple binary masking + * divisible by prime numbers. This allows the hash offset calculation to be a simple binary masking * operation. 
+ * + * The index into the arrays is determined by masking a portion of the key and shifting it to + * provide a series of small buckets within the array. To insert an entry the a sweep is searched + * until an empty key space is found. A sweep is 4 times the length of a bucket, to reduce the need + * to rehash. If no key space is found within a sweep, the table size is doubled. + * + * While performance is high, the slowest situation is where lookup occurs for entries that do not + * exist, as an entire sweep area must be searched. However, this HashMap is more space efficient + * than other open-addressing HashMap implementations as in fastutil. + * + * Note that this HashMap does not allow nulls to be used as keys. */ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable { private K[] keys; @@ -80,21 +93,21 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable { resize(MathUtils.bitsRequired(size < 256 ? 256 : size)); } - public V put(K key, V value) { + @Nullable + public V put(@CheckForNull final K key, @Nullable final V value) { if (key == null) { throw new NullPointerException(this.getClass().getName() + " key"); } for (;;) { int off = getBucketOffset(key); - int end = off + sweep; + final int end = off + sweep; for (; off < end; off++) { - K searchKey = keys[off]; + final K searchKey = keys[off]; if (searchKey == null) { // insert keys[off] = key; size++; - V previous = values[off]; values[off] = value; return previous; @@ -109,9 +122,36 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable { } } - public V get(Object key) { + @Nullable + public V putIfAbsent(@CheckForNull final K key, @Nullable final V value) { + if (key == null) { + throw new NullPointerException(this.getClass().getName() + " key"); + } + + for (;;) { + int off = getBucketOffset(key); + final int end = off + sweep; + for (; off < end; off++) { + final K searchKey = keys[off]; + if (searchKey == null) { + // insert + 
keys[off] = key; + size++; + V previous = values[off]; + values[off] = value; + return previous; + } else if (compare(searchKey, key)) { + return values[off]; + } + } + resize(this.bits + 1); + } + } + + @Nullable + public V get(@Nonnull final Object key) { int off = getBucketOffset(key); - int end = sweep + off; + final int end = sweep + off; for (; off < end; off++) { if (keys[off] != null && compare(keys[off], key)) { return values[off]; @@ -120,9 +160,10 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable { return null; } - public V remove(Object key) { + @Nullable + public V remove(@Nonnull final Object key) { int off = getBucketOffset(key); - int end = sweep + off; + final int end = sweep + off; for (; off < end; off++) { if (keys[off] != null && compare(keys[off], key)) { keys[off] = null; @@ -139,7 +180,7 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable { return size; } - public void putAll(Map<? extends K, ? extends V> m) { + public void putAll(@Nonnull final Map<? extends K, ? 
extends V> m) { for (K key : m.keySet()) { put(key, m.get(key)); } @@ -149,11 +190,11 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable { return size == 0; } - public boolean containsKey(Object key) { + public boolean containsKey(@Nonnull final Object key) { return get(key) != null; } - public boolean containsValue(Object value) { + public boolean containsValue(@Nonnull final Object value) { for (V v : values) { if (v != null && compare(v, value)) { return true; @@ -165,11 +206,12 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable { public void clear() { Arrays.fill(keys, null); Arrays.fill(values, null); - size = 0; + this.size = 0; } + @Nonnull public Set<K> keySet() { - Set<K> set = new HashSet<K>(); + final Set<K> set = new HashSet<K>(); for (K key : keys) { if (key != null) { set.add(key); @@ -178,8 +220,9 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable { return set; } + @Nonnull public Collection<V> values() { - Collection<V> list = new ArrayList<V>(); + final Collection<V> list = new ArrayList<V>(); for (V value : values) { if (value != null) { list.add(value); @@ -188,8 +231,9 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable { return list; } + @Nonnull public Set<Entry<K, V>> entrySet() { - Set<Entry<K, V>> set = new HashSet<Entry<K, V>>(); + final Set<Entry<K, V>> set = new HashSet<Entry<K, V>>(); for (K key : keys) { if (key != null) { set.add(new MapEntry<K, V>(this, key)); @@ -207,19 +251,23 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable { this.key = key; } + @Override public K getKey() { return key; } + @Override public V getValue() { return map.get(key); } + @Override public V setValue(V value) { return map.put(key, value); } } + @Override public void writeExternal(ObjectOutput out) throws IOException { // remember the number of bits out.writeInt(this.bits); @@ -235,6 +283,7 @@ public final class OpenHashMap<K, V> 
implements Map<K, V>, Externalizable { } @SuppressWarnings("unchecked") + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { // resize to old bit size int bitSize = in.readInt(); @@ -250,19 +299,19 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable { @Override public String toString() { - return this.getClass().getSimpleName() + ' ' + this.size; + return this.getClass().getSimpleName() + ' ' + size; } @SuppressWarnings("unchecked") - private void resize(int bits) { + private void resize(final int bits) { this.bits = bits; this.sweepbits = bits / 4; this.sweep = MathUtils.powerOf(2, sweepbits) * 4; - this.sweepmask = MathUtils.bitMask(bits - this.sweepbits) << sweepbits; + this.sweepmask = MathUtils.bitMask(bits - sweepbits) << sweepbits; // remember old values so we can recreate the entries - K[] existingKeys = this.keys; - V[] existingValues = this.values; + final K[] existingKeys = this.keys; + final V[] existingValues = this.values; // create the arrays this.values = (V[]) new Object[MathUtils.powerOf(2, bits) + sweep]; @@ -272,31 +321,38 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable { // re-add the previous entries if resizing if (existingKeys != null) { for (int x = 0; x < existingKeys.length; x++) { - if (existingKeys[x] != null) { - put(existingKeys[x], existingValues[x]); + final K k = existingKeys[x]; + if (k != null) { + put(k, existingValues[x]); } } } } - private int getBucketOffset(Object key) { - return (key.hashCode() << this.sweepbits) & this.sweepmask; + private int getBucketOffset(@Nonnull final Object key) { + return (key.hashCode() << sweepbits) & sweepmask; } - private static boolean compare(final Object v1, final Object v2) { + private static boolean compare(@Nonnull final Object v1, @Nonnull final Object v2) { return v1 == v2 || v1.equals(v2); } public IMapIterator<K, V> entries() { - return new MapIterator(); + return new MapIterator(false); 
+ } + + public IMapIterator<K, V> entries(boolean releaseSeen) { + return new MapIterator(releaseSeen); } private final class MapIterator implements IMapIterator<K, V> { + final boolean releaseSeen; int nextEntry; int lastEntry = -1; - MapIterator() { + MapIterator(boolean releaseSeen) { + this.releaseSeen = releaseSeen; this.nextEntry = nextEntry(0); } @@ -315,7 +371,9 @@ public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable { @Override public int next() { - free(lastEntry); + if (releaseSeen) { + free(lastEntry); + } if (!hasNext()) { return -1; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java index 7fec9b0..4599bfc 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java @@ -27,16 +27,22 @@ import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.Arrays; -import java.util.HashMap; import javax.annotation.Nonnull; /** - * An open-addressing hash table with double-hashing that requires less memory to {@link HashMap}. + * An open-addressing hash table using double-hashing. 
+ * + * <pre> + * Primary hash function: h1(k) = k mod m + * Secondary hash function: h2(k) = 1 + (k mod(m-2)) + * </pre> + * + * @see http://en.wikipedia.org/wiki/Double_hashing */ public final class OpenHashTable<K, V> implements Externalizable { - public static final float DEFAULT_LOAD_FACTOR = 0.7f; + public static final float DEFAULT_LOAD_FACTOR = 0.75f; public static final float DEFAULT_GROW_FACTOR = 2.0f; protected static final byte FREE = 0; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index 0b68de8..db56b82 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -289,12 +289,21 @@ public final class HiveUtils { } } - @Nonnull public static boolean isListOI(@Nonnull final ObjectInspector oi) { Category category = oi.getCategory(); return category == Category.LIST; } + public static boolean isStringListOI(@Nonnull final ObjectInspector oi) + throws UDFArgumentException { + Category category = oi.getCategory(); + if (category != Category.LIST) { + throw new UDFArgumentException("Expected List OI but was: " + oi); + } + ListObjectInspector listOI = (ListObjectInspector) oi; + return isStringOI(listOI.getListElementObjectInspector()); + } + public static boolean isMapOI(@Nonnull final ObjectInspector oi) { return oi.getCategory() == Category.MAP; } @@ -670,6 +679,36 @@ public final class HiveUtils { } @Nullable + public static float[] asFloatArray(@Nullable final Object argObj, + @Nonnull final ListObjectInspector listOI, + @Nonnull final PrimitiveObjectInspector elemOI) throws UDFArgumentException { + return asFloatArray(argObj, listOI, elemOI, true); + } + + @Nullable + public static float[] 
asFloatArray(@Nullable final Object argObj, + @Nonnull final ListObjectInspector listOI, + @Nonnull final PrimitiveObjectInspector elemOI, final boolean avoidNull) + throws UDFArgumentException { + if (argObj == null) { + return null; + } + final int length = listOI.getListLength(argObj); + final float[] ary = new float[length]; + for (int i = 0; i < length; i++) { + Object o = listOI.getListElement(argObj, i); + if (o == null) { + if (avoidNull) { + continue; + } + throw new UDFArgumentException("Found null at index " + i); + } + ary[i] = PrimitiveObjectInspectorUtils.getFloat(o, elemOI); + } + return ary; + } + + @Nullable public static double[] asDoubleArray(@Nullable final Object argObj, @Nonnull final ListObjectInspector listOI, @Nonnull final PrimitiveObjectInspector elemOI) throws UDFArgumentException { @@ -694,8 +733,7 @@ public final class HiveUtils { } throw new UDFArgumentException("Found null at index " + i); } - double d = PrimitiveObjectInspectorUtils.getDouble(o, elemOI); - ary[i] = d; + ary[i] = PrimitiveObjectInspectorUtils.getDouble(o, elemOI); } return ary; } @@ -721,8 +759,7 @@ public final class HiveUtils { } throw new UDFArgumentException("Found null at index " + i); } - double d = PrimitiveObjectInspectorUtils.getDouble(o, elemOI); - out[i] = d; + out[i] = PrimitiveObjectInspectorUtils.getDouble(o, elemOI); } return; } @@ -746,8 +783,7 @@ public final class HiveUtils { out[i] = nullValue; continue; } - double d = PrimitiveObjectInspectorUtils.getDouble(o, elemOI); - out[i] = d; + out[i] = PrimitiveObjectInspectorUtils.getDouble(o, elemOI); } return; } @@ -766,11 +802,11 @@ public final class HiveUtils { int count = 0; final int length = listOI.getListLength(argObj); for (int i = 0; i < length; i++) { - Object o = listOI.getListElement(argObj, i); + final Object o = listOI.getListElement(argObj, i); if (o == null) { continue; } - int index = PrimitiveObjectInspectorUtils.getInt(o, elemOI); + final int index = 
PrimitiveObjectInspectorUtils.getInt(o, elemOI); if (index < 0) { throw new UDFArgumentException("Negative index is not allowed: " + index); } @@ -955,6 +991,26 @@ public final class HiveUtils { } @Nonnull + public static PrimitiveObjectInspector asFloatingPointOI(@Nonnull final ObjectInspector argOI) + throws UDFArgumentTypeException { + if (argOI.getCategory() != Category.PRIMITIVE) { + throw new UDFArgumentTypeException(0, "Only primitive type arguments are accepted but " + + argOI.getTypeName() + " is passed."); + } + final PrimitiveObjectInspector oi = (PrimitiveObjectInspector) argOI; + switch (oi.getPrimitiveCategory()) { + case FLOAT: + case DOUBLE: + break; + default: + throw new UDFArgumentTypeException(0, + "Only numeric or string type arguments are accepted but " + argOI.getTypeName() + + " is passed."); + } + return oi; + } + + @Nonnull public static PrimitiveObjectInspector asNumberOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentTypeException { if (argOI.getCategory() != Category.PRIMITIVE) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/hashing/HashUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hashing/HashUtils.java b/core/src/main/java/hivemall/utils/hashing/HashUtils.java new file mode 100644 index 0000000..710d8f6 --- /dev/null +++ b/core/src/main/java/hivemall/utils/hashing/HashUtils.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.hashing; + +public final class HashUtils { + + private HashUtils() {} + + public static int jenkins32(int k) { + k = (k + 0x7ed55d16) + (k << 12); + k = (k ^ 0xc761c23c) ^ (k >> 19); + k = (k + 0x165667b1) + (k << 5); + k = (k + 0xd3a2646c) ^ (k << 9); + k = (k + 0xfd7046c5) + (k << 3); + k = (k ^ 0xb55a4f09) ^ (k >> 16); + return k; + } + + public static int murmurHash3(int k) { + k ^= k >>> 16; + k *= 0x85ebca6b; + k ^= k >>> 13; + k *= 0xc2b2ae35; + k ^= k >>> 16; + return k; + } + + public static int fnv1a(final int k) { + int hash = 0x811c9dc5; + for (int i = 0; i < 4; i++) { + hash ^= k << (i * 8); + hash *= 0x01000193; + } + return hash; + } + + /** + * https://gist.github.com/badboy/6267743 + */ + public static int hash32shift(int k) { + k = ~k + (k << 15); // key = (key << 15) - key - 1; + k = k ^ (k >>> 12); + k = k + (k << 2); + k = k ^ (k >>> 4); + k = k * 2057; // key = (key + (key << 3)) + (key << 11); + k = k ^ (k >>> 16); + return k; + } + + public static int hash32shiftmult(int k) { + k = (k ^ 61) ^ (k >>> 16); + k = k + (k << 3); + k = k ^ (k >>> 4); + k = k * 0x27d4eb2d; + k = k ^ (k >>> 15); + return k; + } + + /** + * http://burtleburtle.net/bob/hash/integer.html + */ + public static int hash7shifts(int k) { + k -= (k << 6); + k ^= (k >> 17); + k -= (k << 9); + k ^= (k << 4); + k -= (k << 3); + k ^= (k << 10); + k ^= (k >> 15); + return k; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/lang/NumberUtils.java 
---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/NumberUtils.java b/core/src/main/java/hivemall/utils/lang/NumberUtils.java index 0d3f895..4b04f04 100644 --- a/core/src/main/java/hivemall/utils/lang/NumberUtils.java +++ b/core/src/main/java/hivemall/utils/lang/NumberUtils.java @@ -107,4 +107,72 @@ public final class NumberUtils { return true; } + /** + * @throws ArithmeticException + */ + public static int castToInt(final long value) { + final int result = (int) value; + if (result != value) { + throw new ArithmeticException("Out of range: " + value); + } + return result; + } + + /** + * @throws ArithmeticException + */ + public static short castToShort(final int value) { + final short result = (short) value; + if (result != value) { + throw new ArithmeticException("Out of range: " + value); + } + return result; + } + + /** + * Cast Double to Float. + * + * @throws ArithmeticException + */ + public static float castToFloat(final double v) { + if ((v < Float.MIN_VALUE) || (v > Float.MAX_VALUE)) { + throw new ArithmeticException("Double value is out of Float range: " + v); + } + return (float) v; + } + + /** + * Cast Double to Float. + * + * @return v if v is Float range; Float.MIN_VALUE or Float.MAX_VALUE otherwise + */ + public static float safeCast(final double v) { + if (v < Float.MIN_VALUE) { + return Float.MIN_VALUE; + } else if (v > Float.MAX_VALUE) { + return Float.MAX_VALUE; + } + return (float) v; + } + + /** + * Cast Double to Float. + * + * @return v if v is Float range; defaultValue otherwise + */ + public static float safeCast(final double v, final float defaultValue) { + if ((v < Float.MIN_VALUE) || (v > Float.MAX_VALUE)) { + return defaultValue; + } + return (float) v; + } + + public static int toUnsignedShort(final short v) { + return v & 0xFFFF; // convert to range 0-65535 from -32768-32767. 
+ } + + public static int toUnsignedInt(final byte x) { + return ((int) x) & 0xff; + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/lang/Primitives.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/Primitives.java b/core/src/main/java/hivemall/utils/lang/Primitives.java index 2ec012c..7d43da1 100644 --- a/core/src/main/java/hivemall/utils/lang/Primitives.java +++ b/core/src/main/java/hivemall/utils/lang/Primitives.java @@ -26,14 +26,6 @@ public final class Primitives { private Primitives() {} - public static int toUnsignedShort(final short v) { - return v & 0xFFFF; // convert to range 0-65535 from -32768-32767. - } - - public static int toUnsignedInt(final byte x) { - return ((int) x) & 0xff; - } - public static short parseShort(final String s, final short defaultValue) { if (s == null) { return defaultValue; @@ -92,22 +84,6 @@ public final class Primitives { b[off] = (byte) (val >>> 8); } - public static int toIntExact(final long longValue) { - final int casted = (int) longValue; - if (casted != longValue) { - throw new ArithmeticException("integer overflow: " + longValue); - } - return casted; - } - - public static int castToInt(final long value) { - final int result = (int) value; - if (result != value) { - throw new IllegalArgumentException("Out of range: " + value); - } - return result; - } - public static long toLong(final int high, final int low) { return ((long) high << 32) | ((long) low & 0xffffffffL); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/main/java/hivemall/utils/math/MathUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java index 3f41b6f..6162adb 100644 --- a/core/src/main/java/hivemall/utils/math/MathUtils.java +++ 
b/core/src/main/java/hivemall/utils/math/MathUtils.java @@ -264,7 +264,7 @@ public final class MathUtils { return r; } - public static boolean equals(@Nonnull final float value, final float expected, final float delta) { + public static boolean equals(final float value, final float expected, final float delta) { if (Double.isNaN(value)) { return false; } @@ -274,8 +274,7 @@ public final class MathUtils { return true; } - public static boolean equals(@Nonnull final double value, final double expected, - final double delta) { + public static boolean equals(final double value, final double expected, final double delta) { if (Double.isNaN(value)) { return false; } @@ -285,26 +284,34 @@ public final class MathUtils { return true; } - public static boolean almostEquals(@Nonnull final float value, final float expected) { + public static boolean almostEquals(final float value, final float expected) { return equals(value, expected, 1E-15f); } - public static boolean almostEquals(@Nonnull final double value, final double expected) { + public static boolean almostEquals(final double value, final double expected) { return equals(value, expected, 1E-15d); } - public static boolean closeToZero(@Nonnull final float value) { - if (Math.abs(value) > 1E-15f) { - return false; + public static boolean closeToZero(final float value) { + return closeToZero(value, 1E-15f); + } + + public static boolean closeToZero(final float value, @Nonnegative final float tol) { + if (value == 0.f) { + return true; } - return true; + return Math.abs(value) <= tol; } - public static boolean closeToZero(@Nonnull final double value) { - if (Math.abs(value) > 1E-15d) { - return false; + public static boolean closeToZero(final double value) { + return closeToZero(value, 1E-15d); + } + + public static boolean closeToZero(final double value, @Nonnegative final double tol) { + if (value == 0.d) { + return true; } - return true; + return Math.abs(value) <= tol; } public static double sign(final double x) { 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java b/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java deleted file mode 100644 index 076387f..0000000 --- a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.fm; - -import hivemall.utils.buffer.HeapBuffer; -import hivemall.utils.collections.maps.Int2LongOpenHashTable; - -import java.io.IOException; - -import org.junit.Assert; -import org.junit.Test; - -public class FFMPredictionModelTest { - - @Test - public void testSerialize() throws IOException, ClassNotFoundException { - final int factors = 3; - final int entrySize = Entry.sizeOf(factors); - - HeapBuffer buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE); - Int2LongOpenHashTable map = Int2LongOpenHashTable.newInstance(); - - Entry e1 = new Entry(buf, factors, buf.allocate(entrySize)); - e1.setW(1f); - e1.setV(new float[] {1f, -1f, -1f}); - - Entry e2 = new Entry(buf, factors, buf.allocate(entrySize)); - e2.setW(2f); - e2.setV(new float[] {1f, 2f, -1f}); - - Entry e3 = new Entry(buf, factors, buf.allocate(entrySize)); - e3.setW(3f); - e3.setV(new float[] {1f, 2f, 3f}); - - map.put(1, e1.getOffset()); - map.put(2, e2.getOffset()); - map.put(3, e3.getOffset()); - - FFMPredictionModel expected = new FFMPredictionModel(map, buf, 0.d, 3, - Feature.DEFAULT_NUM_FEATURES, Feature.DEFAULT_NUM_FIELDS); - byte[] b = expected.serialize(); - - FFMPredictionModel actual = FFMPredictionModel.deserialize(b, b.length); - Assert.assertEquals(3, actual.getNumFactors()); - Assert.assertEquals(Feature.DEFAULT_NUM_FEATURES, actual.getNumFeatures()); - Assert.assertEquals(Feature.DEFAULT_NUM_FIELDS, actual.getNumFields()); - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/test/java/hivemall/fm/FeatureTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/fm/FeatureTest.java b/core/src/test/java/hivemall/fm/FeatureTest.java index 25e5671..911a4a5 100644 --- a/core/src/test/java/hivemall/fm/FeatureTest.java +++ b/core/src/test/java/hivemall/fm/FeatureTest.java @@ -34,7 +34,7 @@ public class FeatureTest { @Test public void testParseFFMFeature() throws 
HiveException { - IntFeature f1 = Feature.parseFFMFeature("2:1163:0.3651"); + IntFeature f1 = Feature.parseFFMFeature("2:1163:0.3651", -1); Assert.assertEquals(2, f1.getField()); Assert.assertEquals(1163, f1.getFeatureIndex()); Assert.assertEquals("1163", f1.getFeature()); @@ -85,4 +85,9 @@ public class FeatureTest { Feature.parseFeature("2:1163:0.3651", true); } + @Test(expected = HiveException.class) + public void testParseFeatureZeroIndex() throws HiveException { + Feature.parseFFMFeature("0:0.3652"); + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3410ba64/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java index 792ede1..3b219c6 100644 --- a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java +++ b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java @@ -23,11 +23,11 @@ import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.util.ArrayList; +import java.util.List; import java.util.zip.GZIPInputStream; import javax.annotation.Nonnull; -import org.apache.commons.lang.StringUtils; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; @@ -44,32 +44,29 @@ public class FieldAwareFactorizationMachineUDTFTest { @Test public void testSGD() throws HiveException, IOException { - runTest("Pure SGD test", - "-classification -factors 10 -w0 -seed 43 -disable_adagrad -disable_ftrl", 0.60f); + runTest("Pure SGD test", "-opt sgd -classification -factors 10 -w0 -seed 43", 0.60f); } @Test - public void testSGDWithFTRL() throws HiveException, IOException { - 
runTest("SGD w/ FTRL test", "-classification -factors 10 -w0 -seed 43 -disable_adagrad", - 0.60f); + public void testAdaGrad() throws HiveException, IOException { + runTest("AdaGrad test", "-opt adagrad -classification -factors 10 -w0 -seed 43", 0.30f); } @Test public void testAdaGradNoCoeff() throws HiveException, IOException { - runTest("AdaGrad No Coeff test", "-classification -factors 10 -w0 -seed 43 -no_coeff", - 0.30f); + runTest("AdaGrad No Coeff test", + "-opt adagrad -no_coeff -classification -factors 10 -w0 -seed 43", 0.30f); } @Test - public void testAdaGradNoFTRL() throws HiveException, IOException { - runTest("AdaGrad w/o FTRL test", "-classification -factors 10 -w0 -seed 43 -disable_ftrl", - 0.30f); + public void testFTRL() throws HiveException, IOException { + runTest("FTRL test", "-opt ftrl -classification -factors 10 -w0 -seed 43", 0.30f); } @Test - public void testAdaGradDefault() throws HiveException, IOException { - runTest("AdaGrad DEFAULT (adagrad for V + FTRL for W)", - "-classification -factors 10 -w0 -seed 43", 0.30f); + public void testFTRLNoCoeff() throws HiveException, IOException { + runTest("FTRL Coeff test", "-opt ftrl -no_coeff -classification -factors 10 -w0 -seed 43", + 0.30f); } private static void runTest(String testName, String testOptions, float lossThreshold) @@ -100,30 +97,22 @@ public class FieldAwareFactorizationMachineUDTFTest { if (input == null) { break; } - ArrayList<String> featureStrings = new ArrayList<String>(); - ArrayList<StringFeature> features = new ArrayList<StringFeature>(); - - //make StringFeature for each word = data point - String remaining = input; - int wordCut = remaining.indexOf(' '); - while (wordCut != -1) { - featureStrings.add(remaining.substring(0, wordCut)); - remaining = remaining.substring(wordCut + 1); - wordCut = remaining.indexOf(' '); - } - int end = featureStrings.size(); - double y = Double.parseDouble(featureStrings.get(0)); + String[] featureStrings = input.split(" "); + + double y = 
Double.parseDouble(featureStrings[0]); if (y == 0) { y = -1;//LibFFM data uses {0, 1}; Hivemall uses {-1, 1} } - for (int wordNumber = 1; wordNumber < end; ++wordNumber) { - String entireFeature = featureStrings.get(wordNumber); - int featureCut = StringUtils.ordinalIndexOf(entireFeature, ":", 2); - String feature = entireFeature.substring(0, featureCut); - double value = Double.parseDouble(entireFeature.substring(featureCut + 1)); - features.add(new StringFeature(feature, value)); + + final List<String> features = new ArrayList<String>(featureStrings.length - 1); + for (int j = 1; j < featureStrings.length; ++j) { + String[] splitted = featureStrings[j].split(":"); + Assert.assertEquals(3, splitted.length); + int index = Integer.parseInt(splitted[1]) + 1; + String f = splitted[0] + ':' + index + ':' + splitted[2]; + features.add(f); } - udtf.process(new Object[] {toStringArray(features), y}); + udtf.process(new Object[] {features, y}); } cumul = udtf._cvState.getCumulativeLoss(); loss = (cumul - loss) / lines; @@ -143,15 +132,6 @@ public class FieldAwareFactorizationMachineUDTFTest { return new BufferedReader(new InputStreamReader(is)); } - private static String[] toStringArray(ArrayList<StringFeature> x) { - final int size = x.size(); - final String[] ret = new String[size]; - for (int i = 0; i < size; i++) { - ret[i] = x.get(i).toString(); - } - return ret; - } - private static void println(String line) { if (DEBUG) { System.out.println(line);