http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java new file mode 100644 index 0000000..6a7f39f --- /dev/null +++ b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java @@ -0,0 +1,429 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.utils.collections.maps;

import hivemall.utils.math.Primes;

import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Arrays;

/**
 * An open-addressing hash table from {@code long} keys to {@code float} values that resolves
 * collisions with double hashing.
 *
 * <p>
 * Every slot is in one of three states: {@link #FREE}, {@link #FULL} or {@link #REMOVED} (a
 * tombstone left behind by {@link #remove(long)}). A tombstone is reused only by the key that
 * left it; tombstones are purged when the table is rebuilt.
 *
 * <p>
 * Instances are not synchronized.
 *
 * @see <a href="http://en.wikipedia.org/wiki/Double_hashing">Double hashing</a>
 */
public final class Long2FloatOpenHashTable implements Externalizable {

    protected static final byte FREE = 0;
    protected static final byte FULL = 1;
    protected static final byte REMOVED = 2;

    private static final float DEFAULT_LOAD_FACTOR = 0.7f;
    private static final float DEFAULT_GROW_FACTOR = 2.0f;

    protected final transient float _loadFactor;
    protected final transient float _growFactor;

    /** number of FULL slots */
    protected int _used = 0;
    /** number of REMOVED slots (tombstones); never serialized because rebuilds drop them */
    private int _removed = 0;
    /** the table grows once {@code _used + 1} reaches this value */
    protected int _threshold;
    protected float defaultReturnValue = 0.f;

    protected long[] _keys;
    protected float[] _values;
    protected byte[] _states;

    /**
     * @param size requested number of slots, must be at least 1
     * @param loadFactor fraction of the capacity that may be used before the table grows
     * @param growFactor multiplier applied to the capacity when the table grows
     * @param forcePrime if true, the capacity is taken from
     *        {@link Primes#findLeastPrimeNumber(int)} instead of using {@code size} as is
     * @throws IllegalArgumentException if {@code size < 1}
     */
    protected Long2FloatOpenHashTable(int size, float loadFactor, float growFactor,
            boolean forcePrime) {
        if (size < 1) {
            throw new IllegalArgumentException();
        }
        this._loadFactor = loadFactor;
        this._growFactor = growFactor;
        int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
        this._keys = new long[actualSize];
        this._values = new float[actualSize];
        this._states = new byte[actualSize];
        this._threshold = (int) (actualSize * _loadFactor);
    }

    /**
     * Creates a table with a prime capacity and the given (fractional) load and grow factors.
     */
    public Long2FloatOpenHashTable(int size, float loadFactor, float growFactor) {
        this(size, loadFactor, growFactor, true);
    }

    /**
     * Kept for backward compatibility only: integral factors cannot express a load factor such
     * as 0.7.
     *
     * @deprecated use {@link #Long2FloatOpenHashTable(int, float, float)} instead
     */
    @Deprecated
    public Long2FloatOpenHashTable(int size, int loadFactor, int growFactor) {
        this(size, (float) loadFactor, (float) growFactor, true);
    }

    public Long2FloatOpenHashTable(int size) {
        this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
    }

    /**
     * Required for serialization. The arrays stay unallocated until
     * {@link #readExternal(ObjectInput)} is called.
     */
    public Long2FloatOpenHashTable() {
        this._loadFactor = DEFAULT_LOAD_FACTOR;
        this._growFactor = DEFAULT_GROW_FACTOR;
    }

    /** Sets the value returned by lookups and removals that miss. */
    public void defaultReturnValue(float v) {
        this.defaultReturnValue = v;
    }

    public boolean containsKey(final long key) {
        return _findKey(key) >= 0;
    }

    /**
     * @return defaultReturnValue if not found
     */
    public float get(final long key) {
        return get(key, defaultReturnValue);
    }

    /**
     * @return {@code defaultValue} if not found
     */
    public float get(final long key, final float defaultValue) {
        final int i = _findKey(key);
        if (i < 0) {
            return defaultValue;
        }
        return _values[i];
    }

    /**
     * @param index a slot index as returned by {@link #_findKey(long)}
     * @return the value stored at {@code index}, or defaultReturnValue if {@code index} is
     *         negative
     */
    public float _get(final int index) {
        if (index < 0) {
            return defaultReturnValue;
        }
        return _values[index];
    }

    /**
     * Associates {@code value} with {@code key}.
     *
     * @return the previous value, or defaultReturnValue if the key was not present
     */
    public float put(final long key, final float value) {
        final int hash = keyHash(key);
        // may rebuild the arrays, so the slot is located only afterwards
        preAddEntry(hash % _keys.length);

        final int keyIdx = probe(key, hash);
        if (_states[keyIdx] == FULL) {// key already present: overwrite
            final float old = _values[keyIdx];
            _values[keyIdx] = value;
            return old;
        }
        occupy(keyIdx, key);
        _values[keyIdx] = value;
        return defaultReturnValue;
    }

    /**
     * Walks the double-hashing probe sequence of {@code key}, starting at its home slot.
     *
     * <p>
     * The home slot is examined exactly like every later slot. Inserting into a home slot that
     * holds another key's tombstone without probing further could duplicate a key that is
     * stored later in the sequence.
     *
     * @return the index of the FULL slot holding {@code key}, or otherwise the index of the slot
     *         into which {@code key} has to be inserted (see {@link #isFree(int, long)})
     */
    private int probe(final long key, final int hash) {
        final long[] keys = _keys;
        final byte[] states = _states;
        final int keyLength = keys.length;

        int keyIdx = hash % keyLength;
        if (isFree(keyIdx, key) || (states[keyIdx] == FULL && keys[keyIdx] == key)) {
            return keyIdx;
        }
        // second hash gives the step width
        final int decr = 1 + (hash % (keyLength - 2));
        for (;;) {
            keyIdx -= decr;
            if (keyIdx < 0) {
                keyIdx += keyLength;
            }
            if (isFree(keyIdx, key)) {
                return keyIdx;
            }
            if (states[keyIdx] == FULL && keys[keyIdx] == key) {
                return keyIdx;
            }
        }
    }

    /** Marks the given non-FULL slot as holding {@code key} and updates the counters. */
    private void occupy(final int keyIdx, final long key) {
        if (_states[keyIdx] == REMOVED) {
            --_removed;
        }
        _keys[keyIdx] = key;
        _states[keyIdx] = FULL;
        ++_used;
    }

    /**
     * Returns whether the slot can take a new entry for {@code key}: it is either FREE or a
     * tombstone previously left by the same key.
     */
    protected boolean isFree(final int index, final long key) {
        final byte stat = _states[index];
        if (stat == FREE) {
            return true;
        }
        if (stat == REMOVED && _keys[index] == key) {
            return true;
        }
        return false;
    }

    /**
     * Makes room for one more entry.
     *
     * @param index home slot of the entry about to be added (currently unused)
     * @return true if the arrays were rebuilt, i.e. previously computed slot indexes are stale
     */
    protected boolean preAddEntry(final int index) {
        if ((_used + 1) >= _threshold) {// too filled
            int newCapacity = Math.round(_keys.length * _growFactor);
            ensureCapacity(newCapacity);
            return true;
        }
        if (_removed > 0 && (_used + _removed + 1) >= _keys.length) {
            // Tombstones are about to consume the last FREE slot. Probe loops end only at a
            // FREE slot or at the key itself, so purge the tombstones in place.
            rebuild(_keys.length);
            return true;
        }
        return false;
    }

    /**
     * @return the slot index of {@code key}, or -1 if not found
     */
    public int _findKey(final long key) {
        final int keyIdx = probe(key, keyHash(key));
        return (_states[keyIdx] == FULL) ? keyIdx : -1;
    }

    /**
     * Removes {@code key}, leaving a tombstone in its slot.
     *
     * @return the removed value, or defaultReturnValue if the key was not present
     */
    public float remove(final long key) {
        final int keyIdx = probe(key, keyHash(key));
        if (_states[keyIdx] != FULL) {
            return defaultReturnValue;
        }
        final float old = _values[keyIdx];
        _states[keyIdx] = REMOVED;
        --_used;
        ++_removed;
        return old;
    }

    public int size() {
        return _used;
    }

    /** Removes all entries (and tombstones); the capacity is kept. */
    public void clear() {
        Arrays.fill(_states, FREE);
        this._used = 0;
        this._removed = 0;
    }

    public IMapIterator entries() {
        return new MapIterator();
    }

    @Override
    public String toString() {
        int len = size() * 10 + 2;
        StringBuilder buf = new StringBuilder(len);
        buf.append('{');
        IMapIterator i = entries();
        while (i.next() != -1) {
            buf.append(i.getKey());
            buf.append('=');
            buf.append(i.getValue());
            if (i.hasNext()) {
                buf.append(',');
            }
        }
        buf.append('}');
        return buf.toString();
    }

    /**
     * Grows the table to the prime capacity that {@link Primes#findLeastPrimeNumber(int)} returns
     * for {@code newCapacity} and recomputes the threshold.
     *
     * @throws IllegalArgumentException if that capacity is not larger than the current one
     */
    protected void ensureCapacity(final int newCapacity) {
        int prime = Primes.findLeastPrimeNumber(newCapacity);
        rehash(prime);
        this._threshold = Math.round(prime * _loadFactor);
    }

    private void rehash(final int newCapacity) {
        int oldCapacity = _keys.length;
        if (newCapacity <= oldCapacity) {
            throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
        }
        rebuild(newCapacity);
    }

    /**
     * Re-inserts all FULL entries into fresh arrays of the given capacity, dropping tombstones.
     */
    private void rebuild(final int newCapacity) {
        final int oldCapacity = _keys.length;
        final long[] newkeys = new long[newCapacity];
        final float[] newValues = new float[newCapacity];
        final byte[] newStates = new byte[newCapacity];
        int used = 0;
        for (int i = 0; i < oldCapacity; i++) {
            if (_states[i] == FULL) {
                used++;
                long k = _keys[i];
                float v = _values[i];
                int hash = keyHash(k);
                int keyIdx = hash % newCapacity;
                if (newStates[keyIdx] == FULL) {// second hashing
                    int decr = 1 + (hash % (newCapacity - 2));
                    while (newStates[keyIdx] != FREE) {
                        keyIdx -= decr;
                        if (keyIdx < 0) {
                            keyIdx += newCapacity;
                        }
                    }
                }
                newkeys[keyIdx] = k;
                newValues[keyIdx] = v;
                newStates[keyIdx] = FULL;
            }
        }
        this._keys = newkeys;
        this._values = newValues;
        this._states = newStates;
        this._used = used;
        this._removed = 0;
    }

    /** Folds the 64-bit key into a non-negative 32-bit hash. */
    private static int keyHash(final long key) {
        return (int) (key ^ (key >>> 32)) & 0x7FFFFFFF;
    }

    /**
     * Writes the threshold, the entry count, the capacity and then every (key, value) pair.
     */
    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        out.writeInt(_threshold);
        out.writeInt(_used);

        out.writeInt(_keys.length);
        IMapIterator i = entries();
        while (i.next() != -1) {
            out.writeLong(i.getKey());
            out.writeFloat(i.getValue());
        }
    }

    /**
     * Restores the state written by {@link #writeExternal(ObjectOutput)}, re-inserting every
     * entry into freshly allocated arrays.
     */
    @Override
    public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
        this._threshold = in.readInt();
        this._used = in.readInt();
        this._removed = 0;

        final int keylen = in.readInt();
        final long[] keys = new long[keylen];
        final float[] values = new float[keylen];
        final byte[] states = new byte[keylen];
        for (int i = 0; i < _used; i++) {
            long k = in.readLong();
            float v = in.readFloat();
            int hash = keyHash(k);
            int keyIdx = hash % keylen;
            if (states[keyIdx] != FREE) {// second hash
                int decr = 1 + (hash % (keylen - 2));
                for (;;) {
                    keyIdx -= decr;
                    if (keyIdx < 0) {
                        keyIdx += keylen;
                    }
                    if (states[keyIdx] == FREE) {
                        break;
                    }
                }
            }
            states[keyIdx] = FULL;
            keys[keyIdx] = k;
            values[keyIdx] = v;
        }
        this._keys = keys;
        this._values = values;
        this._states = states;
    }

    /** Cursor-style iterator: call {@link #next()} first, then read the current entry. */
    public interface IMapIterator {

        public boolean hasNext();

        /**
         * Advances to the next entry.
         *
         * @return the slot index of the entry, or -1 if there is no further entry
         */
        public int next();

        public long getKey();

        public float getValue();

    }

    private final class MapIterator implements IMapIterator {

        int nextEntry;
        int lastEntry = -1;

        MapIterator() {
            this.nextEntry = nextEntry(0);
        }

        /** find the index of next full entry */
        int nextEntry(int index) {
            while (index < _keys.length && _states[index] != FULL) {
                index++;
            }
            return index;
        }

        @Override
        public boolean hasNext() {
            return nextEntry < _keys.length;
        }

        @Override
        public int next() {
            if (!hasNext()) {
                return -1;
            }
            int curEntry = nextEntry;
            this.lastEntry = curEntry;
            this.nextEntry = nextEntry(curEntry + 1);
            return curEntry;
        }

        @Override
        public long getKey() {
            if (lastEntry == -1) {
                throw new IllegalStateException();
            }
            return _keys[lastEntry];
        }

        @Override
        public float getValue() {
            if (lastEntry == -1) {
                throw new IllegalStateException();
            }
            return _values[lastEntry];
        }
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java new file mode 100644 index 0000000..51b8f12 --- /dev/null +++ b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java @@ -0,0 +1,473 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.utils.collections.maps;

import hivemall.utils.math.Primes;

import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Arrays;

/**
 * An open-addressing hash table from {@code long} keys to {@code int} values that resolves
 * collisions with double hashing.
 *
 * <p>
 * Every slot is in one of three states: {@link #FREE}, {@link #FULL} or {@link #REMOVED} (a
 * tombstone left behind by {@link #remove(long)}). A tombstone is reused only by the key that
 * left it; tombstones are purged when the table is rebuilt.
 *
 * <p>
 * Instances are not synchronized.
 *
 * @see <a href="http://en.wikipedia.org/wiki/Double_hashing">Double hashing</a>
 */
public final class Long2IntOpenHashTable implements Externalizable {

    protected static final byte FREE = 0;
    protected static final byte FULL = 1;
    protected static final byte REMOVED = 2;

    private static final float DEFAULT_LOAD_FACTOR = 0.7f;
    private static final float DEFAULT_GROW_FACTOR = 2.0f;

    protected final transient float _loadFactor;
    protected final transient float _growFactor;

    /** number of FULL slots */
    protected int _used = 0;
    /** number of REMOVED slots (tombstones); never serialized because rebuilds drop them */
    private int _removed = 0;
    /** the table grows once {@code _used + 1} reaches this value */
    protected int _threshold;
    protected int defaultReturnValue = -1;

    protected long[] _keys;
    protected int[] _values;
    protected byte[] _states;

    /**
     * @param size requested number of slots, must be at least 1
     * @param loadFactor fraction of the capacity that may be used before the table grows
     * @param growFactor multiplier applied to the capacity when the table grows
     * @param forcePrime if true, the capacity is taken from
     *        {@link Primes#findLeastPrimeNumber(int)} instead of using {@code size} as is
     * @throws IllegalArgumentException if {@code size < 1}
     */
    protected Long2IntOpenHashTable(int size, float loadFactor, float growFactor, boolean forcePrime) {
        if (size < 1) {
            throw new IllegalArgumentException();
        }
        this._loadFactor = loadFactor;
        this._growFactor = growFactor;
        int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size;
        this._keys = new long[actualSize];
        this._values = new int[actualSize];
        this._states = new byte[actualSize];
        this._threshold = (int) (actualSize * _loadFactor);
    }

    /**
     * Creates a table with a prime capacity and the given (fractional) load and grow factors.
     */
    public Long2IntOpenHashTable(int size, float loadFactor, float growFactor) {
        this(size, loadFactor, growFactor, true);
    }

    /**
     * Kept for backward compatibility only: integral factors cannot express a load factor such
     * as 0.7.
     *
     * @deprecated use {@link #Long2IntOpenHashTable(int, float, float)} instead
     */
    @Deprecated
    public Long2IntOpenHashTable(int size, int loadFactor, int growFactor) {
        this(size, (float) loadFactor, (float) growFactor, true);
    }

    public Long2IntOpenHashTable(int size) {
        this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true);
    }

    /**
     * Required for serialization. The arrays stay unallocated until
     * {@link #readExternal(ObjectInput)} is called.
     */
    public Long2IntOpenHashTable() {
        this._loadFactor = DEFAULT_LOAD_FACTOR;
        this._growFactor = DEFAULT_GROW_FACTOR;
    }

    /** Sets the value returned by lookups and removals that miss. */
    public void defaultReturnValue(int v) {
        this.defaultReturnValue = v;
    }

    public boolean containsKey(final long key) {
        return _findKey(key) >= 0;
    }

    /**
     * @return defaultReturnValue if not found
     */
    public int get(final long key) {
        return get(key, defaultReturnValue);
    }

    /**
     * @return {@code defaultValue} if not found
     */
    public int get(final long key, final int defaultValue) {
        final int i = _findKey(key);
        if (i < 0) {
            return defaultValue;
        }
        return _values[i];
    }

    /**
     * @param index a slot index as returned by {@link #_findKey(long)}
     * @return the value stored at {@code index}, or defaultReturnValue if {@code index} is
     *         negative
     */
    public int _get(final int index) {
        if (index < 0) {
            return defaultReturnValue;
        }
        return _values[index];
    }

    /**
     * Associates {@code value} with {@code key}.
     *
     * @return the previous value, or defaultReturnValue if the key was not present
     */
    public int put(final long key, final int value) {
        final int hash = keyHash(key);
        // may rebuild the arrays, so the slot is located only afterwards
        preAddEntry(hash % _keys.length);

        final int keyIdx = probe(key, hash);
        if (_states[keyIdx] == FULL) {// key already present: overwrite
            final int old = _values[keyIdx];
            _values[keyIdx] = value;
            return old;
        }
        occupy(keyIdx, key);
        _values[keyIdx] = value;
        return defaultReturnValue;
    }

    /**
     * Adds {@code delta} to the value of {@code key}; a missing key is created with the value
     * {@code delta}.
     *
     * @return the value before the increment, or defaultReturnValue if the key was not present
     */
    public int incr(final long key, final int delta) {
        final int hash = keyHash(key);
        // may rebuild the arrays, so the slot is located only afterwards
        preAddEntry(hash % _keys.length);

        final int keyIdx = probe(key, hash);
        if (_states[keyIdx] == FULL) {
            final int old = _values[keyIdx];
            _values[keyIdx] += delta;
            return old;
        }
        occupy(keyIdx, key);
        // Assign rather than add: clear() and remove() do not reset _values, so the slot may
        // still contain a stale value from an earlier entry.
        _values[keyIdx] = delta;
        return defaultReturnValue;
    }

    /**
     * Walks the double-hashing probe sequence of {@code key}, starting at its home slot.
     *
     * <p>
     * The home slot is examined exactly like every later slot. Inserting into a home slot that
     * holds another key's tombstone without probing further could duplicate a key that is
     * stored later in the sequence.
     *
     * @return the index of the FULL slot holding {@code key}, or otherwise the index of the slot
     *         into which {@code key} has to be inserted (see {@link #isFree(int, long)})
     */
    private int probe(final long key, final int hash) {
        final long[] keys = _keys;
        final byte[] states = _states;
        final int keyLength = keys.length;

        int keyIdx = hash % keyLength;
        if (isFree(keyIdx, key) || (states[keyIdx] == FULL && keys[keyIdx] == key)) {
            return keyIdx;
        }
        // second hash gives the step width
        final int decr = 1 + (hash % (keyLength - 2));
        for (;;) {
            keyIdx -= decr;
            if (keyIdx < 0) {
                keyIdx += keyLength;
            }
            if (isFree(keyIdx, key)) {
                return keyIdx;
            }
            if (states[keyIdx] == FULL && keys[keyIdx] == key) {
                return keyIdx;
            }
        }
    }

    /** Marks the given non-FULL slot as holding {@code key} and updates the counters. */
    private void occupy(final int keyIdx, final long key) {
        if (_states[keyIdx] == REMOVED) {
            --_removed;
        }
        _keys[keyIdx] = key;
        _states[keyIdx] = FULL;
        ++_used;
    }

    /**
     * Returns whether the slot can take a new entry for {@code key}: it is either FREE or a
     * tombstone previously left by the same key.
     */
    protected boolean isFree(final int index, final long key) {
        final byte stat = _states[index];
        if (stat == FREE) {
            return true;
        }
        if (stat == REMOVED && _keys[index] == key) {
            return true;
        }
        return false;
    }

    /**
     * Makes room for one more entry.
     *
     * @param index home slot of the entry about to be added (currently unused)
     * @return true if the arrays were rebuilt, i.e. previously computed slot indexes are stale
     */
    protected boolean preAddEntry(final int index) {
        if ((_used + 1) >= _threshold) {// too filled
            int newCapacity = Math.round(_keys.length * _growFactor);
            ensureCapacity(newCapacity);
            return true;
        }
        if (_removed > 0 && (_used + _removed + 1) >= _keys.length) {
            // Tombstones are about to consume the last FREE slot. Probe loops end only at a
            // FREE slot or at the key itself, so purge the tombstones in place.
            rebuild(_keys.length);
            return true;
        }
        return false;
    }

    /**
     * @return the slot index of {@code key}, or -1 if not found
     */
    public int _findKey(final long key) {
        final int keyIdx = probe(key, keyHash(key));
        return (_states[keyIdx] == FULL) ? keyIdx : -1;
    }

    /**
     * Removes {@code key}, leaving a tombstone in its slot.
     *
     * @return the removed value, or defaultReturnValue if the key was not present
     */
    public int remove(final long key) {
        final int keyIdx = probe(key, keyHash(key));
        if (_states[keyIdx] != FULL) {
            return defaultReturnValue;
        }
        final int old = _values[keyIdx];
        _states[keyIdx] = REMOVED;
        --_used;
        ++_removed;
        return old;
    }

    public int size() {
        return _used;
    }

    /** Removes all entries (and tombstones); the capacity is kept. */
    public void clear() {
        Arrays.fill(_states, FREE);
        this._used = 0;
        this._removed = 0;
    }

    public IMapIterator entries() {
        return new MapIterator();
    }

    @Override
    public String toString() {
        int len = size() * 10 + 2;
        StringBuilder buf = new StringBuilder(len);
        buf.append('{');
        IMapIterator i = entries();
        while (i.next() != -1) {
            buf.append(i.getKey());
            buf.append('=');
            buf.append(i.getValue());
            if (i.hasNext()) {
                buf.append(',');
            }
        }
        buf.append('}');
        return buf.toString();
    }

    /**
     * Grows the table to the prime capacity that {@link Primes#findLeastPrimeNumber(int)} returns
     * for {@code newCapacity} and recomputes the threshold.
     *
     * @throws IllegalArgumentException if that capacity is not larger than the current one
     */
    protected void ensureCapacity(final int newCapacity) {
        int prime = Primes.findLeastPrimeNumber(newCapacity);
        rehash(prime);
        this._threshold = Math.round(prime * _loadFactor);
    }

    private void rehash(final int newCapacity) {
        int oldCapacity = _keys.length;
        if (newCapacity <= oldCapacity) {
            throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity);
        }
        rebuild(newCapacity);
    }

    /**
     * Re-inserts all FULL entries into fresh arrays of the given capacity, dropping tombstones.
     */
    private void rebuild(final int newCapacity) {
        final int oldCapacity = _keys.length;
        final long[] newkeys = new long[newCapacity];
        final int[] newValues = new int[newCapacity];
        final byte[] newStates = new byte[newCapacity];
        int used = 0;
        for (int i = 0; i < oldCapacity; i++) {
            if (_states[i] == FULL) {
                used++;
                long k = _keys[i];
                int v = _values[i];
                int hash = keyHash(k);
                int keyIdx = hash % newCapacity;
                if (newStates[keyIdx] == FULL) {// second hashing
                    int decr = 1 + (hash % (newCapacity - 2));
                    while (newStates[keyIdx] != FREE) {
                        keyIdx -= decr;
                        if (keyIdx < 0) {
                            keyIdx += newCapacity;
                        }
                    }
                }
                newkeys[keyIdx] = k;
                newValues[keyIdx] = v;
                newStates[keyIdx] = FULL;
            }
        }
        this._keys = newkeys;
        this._values = newValues;
        this._states = newStates;
        this._used = used;
        this._removed = 0;
    }

    /** Folds the 64-bit key into a non-negative 32-bit hash. */
    private static int keyHash(final long key) {
        return (int) (key ^ (key >>> 32)) & 0x7FFFFFFF;
    }

    /**
     * Writes the threshold, the entry count, the capacity and then every (key, value) pair.
     */
    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        out.writeInt(_threshold);
        out.writeInt(_used);

        out.writeInt(_keys.length);
        IMapIterator i = entries();
        while (i.next() != -1) {
            out.writeLong(i.getKey());
            out.writeInt(i.getValue());
        }
    }

    /**
     * Restores the state written by {@link #writeExternal(ObjectOutput)}, re-inserting every
     * entry into freshly allocated arrays.
     */
    @Override
    public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
        this._threshold = in.readInt();
        this._used = in.readInt();
        this._removed = 0;

        final int keylen = in.readInt();
        final long[] keys = new long[keylen];
        final int[] values = new int[keylen];
        final byte[] states = new byte[keylen];
        for (int i = 0; i < _used; i++) {
            long k = in.readLong();
            int v = in.readInt();
            int hash = keyHash(k);
            int keyIdx = hash % keylen;
            if (states[keyIdx] != FREE) {// second hash
                int decr = 1 + (hash % (keylen - 2));
                for (;;) {
                    keyIdx -= decr;
                    if (keyIdx < 0) {
                        keyIdx += keylen;
                    }
                    if (states[keyIdx] == FREE) {
                        break;
                    }
                }
            }
            states[keyIdx] = FULL;
            keys[keyIdx] = k;
            values[keyIdx] = v;
        }
        this._keys = keys;
        this._values = values;
        this._states = states;
    }

    /** Cursor-style iterator: call {@link #next()} first, then read the current entry. */
    public interface IMapIterator {

        public boolean hasNext();

        /**
         * Advances to the next entry.
         *
         * @return the slot index of the entry, or -1 if there is no further entry
         */
        public int next();

        public long getKey();

        public int getValue();

    }

    private final class MapIterator implements IMapIterator {

        int nextEntry;
        int lastEntry = -1;

        MapIterator() {
            this.nextEntry = nextEntry(0);
        }

        /** find the index of next full entry */
        int nextEntry(int index) {
            while (index < _keys.length && _states[index] != FULL) {
                index++;
            }
            return index;
        }

        @Override
        public boolean hasNext() {
            return nextEntry < _keys.length;
        }

        @Override
        public int next() {
            if (!hasNext()) {
                return -1;
            }
            int curEntry = nextEntry;
            this.lastEntry = curEntry;
            this.nextEntry = nextEntry(curEntry + 1);
            return curEntry;
        }

        @Override
        public long getKey() {
            if (lastEntry == -1) {
                throw new IllegalStateException();
            }
            return _keys[lastEntry];
        }

        @Override
        public int getValue() {
            if (lastEntry == -1) {
                throw new IllegalStateException();
            }
            return _values[lastEntry];
        }
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java b/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java new file mode 100644 index 0000000..152447a --- /dev/null +++ b/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */
//
// Copyright (C) 2010 catchpole.net
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package hivemall.utils.collections.maps;

import hivemall.utils.collections.IMapIterator;
import hivemall.utils.lang.Copyable;
import hivemall.utils.math.MathUtils;

import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * An optimized Hashed Map implementation.
 * <p/>
 * <p>
 * This Hashmap does not allow nulls to be used as keys or values.
 * <p/>
 * <p>
 * It uses single open hashing arrays sized to binary powers (256, 512 etc) rather than those
 * divisible by prime numbers. This allows the hash offset calculation to be a simple binary
 * masking operation.
 * <p>
 * A key is stored inside the "sweep zone" selected by its hash code: a run of {@code sweep}
 * consecutive slots that is searched linearly. When a zone is full the whole table is resized.
 */
public final class OpenHashMap<K, V> implements Map<K, V>, Externalizable {
    private K[] keys;
    private V[] values;

    // total number of entries in this table
    private int size;
    // number of bits for the value table (eg. 8 bits = 256 entries)
    private int bits;
    // the number of bits in each sweep zone.
    private int sweepbits;
    // the size of a sweep (2 to the power of sweepbits)
    private int sweep;
    // the sweepmask used to create sweep zone offsets
    private int sweepmask;

    /**
     * For Externalizable only: the arrays are not allocated until
     * {@link #readExternal(ObjectInput)} is called.
     */
    public OpenHashMap() {}

    /**
     * @param size expected number of entries; values below 256 are treated as 256
     */
    public OpenHashMap(int size) {
        resize(MathUtils.bitsRequired(size < 256 ? 256 : size));
    }

    /**
     * Associates {@code value} with {@code key}, growing the table when the sweep zone of the key
     * is full.
     *
     * @return the previous value, or null if the key was not present
     * @throws NullPointerException if {@code key} is null
     */
    @Override
    public V put(K key, V value) {
        if (key == null) {
            throw new NullPointerException(this.getClass().getName() + " key");
        }

        for (;;) {
            int off = getBucketOffset(key);
            int end = off + sweep;
            for (; off < end; off++) {
                K searchKey = keys[off];
                if (searchKey == null) {
                    // insert
                    keys[off] = key;
                    size++;

                    V previous = values[off];
                    values[off] = value;
                    return previous;
                } else if (compare(searchKey, key)) {
                    // replace
                    V previous = values[off];
                    values[off] = value;
                    return previous;
                }
            }
            // no room left in this sweep zone: grow and retry
            resize(this.bits + 1);
        }
    }

    /**
     * @return the value of {@code key}, or null if not found
     */
    @Override
    public V get(Object key) {
        int off = getBucketOffset(key);
        int end = sweep + off;
        for (; off < end; off++) {
            if (keys[off] != null && compare(keys[off], key)) {
                return values[off];
            }
        }
        return null;
    }

    /**
     * @return the removed value, or null if the key was not present
     */
    @Override
    public V remove(Object key) {
        int off = getBucketOffset(key);
        int end = sweep + off;
        for (; off < end; off++) {
            if (keys[off] != null && compare(keys[off], key)) {
                keys[off] = null;
                V previous = values[off];
                values[off] = null;
                size--;
                return previous;
            }
        }
        return null;
    }

    @Override
    public int size() {
        return size;
    }

    @Override
    public void putAll(Map<? extends K, ? extends V> m) {
        // iterate over the entries to avoid a second lookup per key
        for (Map.Entry<? extends K, ? extends V> e : m.entrySet()) {
            put(e.getKey(), e.getValue());
        }
    }

    @Override
    public boolean isEmpty() {
        return size == 0;
    }

    @Override
    public boolean containsKey(Object key) {
        return get(key) != null;
    }

    @Override
    public boolean containsValue(Object value) {
        for (V v : values) {
            if (v != null && compare(v, value)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public void clear() {
        Arrays.fill(keys, null);
        Arrays.fill(values, null);
        size = 0;
    }

    /**
     * @return a snapshot of the keys; it is not backed by this map
     */
    @Override
    public Set<K> keySet() {
        Set<K> set = new HashSet<K>();
        for (K key : keys) {
            if (key != null) {
                set.add(key);
            }
        }
        return set;
    }

    /**
     * @return a snapshot of the non-null values; it is not backed by this map
     */
    @Override
    public Collection<V> values() {
        Collection<V> list = new ArrayList<V>();
        for (V value : values) {
            if (value != null) {
                list.add(value);
            }
        }
        return list;
    }

    /**
     * @return a snapshot set with one entry per key; each entry reads and writes its value
     *         through this map
     */
    @Override
    public Set<Entry<K, V>> entrySet() {
        Set<Entry<K, V>> set = new HashSet<Entry<K, V>>();
        for (K key : keys) {
            if (key != null) {
                set.add(new MapEntry<K, V>(this, key));
            }
        }
        return set;
    }

    /** An entry that holds only the key and delegates value access to the owning map. */
    private static final class MapEntry<K, V> implements Map.Entry<K, V> {
        private final Map<K, V> map;
        private final K key;

        public MapEntry(Map<K, V> map, K key) {
            this.map = map;
            this.key = key;
        }

        @Override
        public K getKey() {
            return key;
        }

        @Override
        public V getValue() {
            return map.get(key);
        }

        @Override
        public V setValue(V value) {
            return map.put(key, value);
        }
    }

    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        // remember the number of bits
        out.writeInt(this.bits);
        // remember the total number of entries
        out.writeInt(this.size);
        // write all entries
        for (int x = 0; x < this.keys.length; x++) {
            if (keys[x] != null) {
                out.writeObject(keys[x]);
                out.writeObject(values[x]);
            }
        }
    }

    @Override
    @SuppressWarnings("unchecked")
    public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
        // resize to old bit size
        int bitSize = in.readInt();
        if (bitSize != bits) {
            resize(bitSize);
        }
        // read all entries
        int size = in.readInt();
        for (int x = 0; x < size; x++) {
            this.put((K) in.readObject(), (V) in.readObject());
        }
    }

    @Override
    public String toString() {
        return this.getClass().getSimpleName() + ' ' + this.size;
    }

    /**
     * Re-dimensions the table for the given number of bits and re-inserts the existing entries.
     */
    @SuppressWarnings("unchecked")
    private void resize(int bits) {
        this.bits = bits;
        this.sweepbits = bits / 4;
        this.sweep = MathUtils.powerOf(2, sweepbits) * 4;
        this.sweepmask = MathUtils.bitMask(bits - this.sweepbits) << sweepbits;

        // remember old values so we can recreate the entries
        K[] existingKeys = this.keys;
        V[] existingValues = this.values;

        // create the arrays
        this.values = (V[]) new Object[MathUtils.powerOf(2, bits) + sweep];
        this.keys = (K[]) new Object[values.length];
        this.size = 0;

        // re-add the previous entries if resizing
        if (existingKeys != null) {
            for (int x = 0; x < existingKeys.length; x++) {
                if (existingKeys[x] != null) {
                    put(existingKeys[x], existingValues[x]);
                }
            }
        }
    }

    /** @return the first slot of the sweep zone selected by the hash code of {@code key} */
    private int getBucketOffset(Object key) {
        return (key.hashCode() << this.sweepbits) & this.sweepmask;
    }

    /** @param v1 must not be null */
    private static boolean compare(final Object v1, final Object v2) {
        return v1 == v2 || v1.equals(v2);
    }

    /**
     * Returns a <b>destructive</b> iterator: each call to {@code next()} removes the entry
     * returned by the previous call from this map.
     */
    public IMapIterator<K, V> entries() {
        return new MapIterator();
    }

    private final class MapIterator implements IMapIterator<K, V> {

        int nextEntry;
        int lastEntry = -1;

        MapIterator() {
            this.nextEntry = nextEntry(0);
        }

        /** find the index of next full entry */
        int nextEntry(int index) {
            while (index < keys.length && keys[index] == null) {
                index++;
            }
            return index;
        }

        @Override
        public boolean hasNext() {
            return nextEntry < keys.length;
        }

        @Override
        public int next() {
            // the entry handed out by the previous call is released first
            free(lastEntry);
            if (!hasNext()) {
                return -1;
            }
            int curEntry = nextEntry;
            this.lastEntry = curEntry;
            this.nextEntry = nextEntry(curEntry + 1);
            return curEntry;
        }

        @Override
        public K getKey() {
            return keys[lastEntry];
        }

        @Override
        public V getValue() {
            return values[lastEntry];
        }

        @Override
        public <T extends Copyable<V>> void getValue(T probe) {
            probe.copyFrom(getValue());
        }

        /**
         * Removes the entry at {@code index}, if there still is one. The entry count is
         * adjusted as well so that size(), isEmpty() and writeExternal() stay consistent with
         * the array contents; the null check keeps repeated calls from counting twice.
         */
        private void free(int index) {
            if (index >= 0 && keys[index] != null) {
                keys[index] = null;
                values[index] = null;
                size--;
            }
        }

    }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java new file mode 100644 index 0000000..7fec9b0 --- /dev/null +++ b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java @@ -0,0 +1,413 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.utils.collections.maps; + +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.lang.Copyable; +import hivemall.utils.math.Primes; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.Arrays; +import java.util.HashMap; + +import javax.annotation.Nonnull; + +/** + * An open-addressing hash table with double-hashing that requires less memory to {@link HashMap}. + */ +public final class OpenHashTable<K, V> implements Externalizable { + + public static final float DEFAULT_LOAD_FACTOR = 0.7f; + public static final float DEFAULT_GROW_FACTOR = 2.0f; + + protected static final byte FREE = 0; + protected static final byte FULL = 1; + protected static final byte REMOVED = 2; + + protected/* final */float _loadFactor; + protected/* final */float _growFactor; + + protected int _used = 0; + protected int _threshold; + + protected K[] _keys; + protected V[] _values; + protected byte[] _states; + + public OpenHashTable() {} // for Externalizable + + public OpenHashTable(int size) { + this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR); + } + + @SuppressWarnings("unchecked") + public OpenHashTable(int size, float loadFactor, float growFactor) { + if (size < 1) { + throw new IllegalArgumentException(); + } + this._loadFactor = loadFactor; + this._growFactor = growFactor; + int actualSize = Primes.findLeastPrimeNumber(size); + this._keys = (K[]) new Object[actualSize]; + this._values = (V[]) new Object[actualSize]; + this._states = new byte[actualSize]; + this._threshold = Math.round(actualSize * _loadFactor); + } + + public OpenHashTable(@Nonnull K[] keys, @Nonnull V[] values, @Nonnull byte[] states, int used) { + this._used = used; + this._threshold = keys.length; + this._keys = keys; + this._values = values; + this._states = states; + } + + public Object[] getKeys() { + return _keys; + } + + public Object[] getValues() { + return _values; + } + + public 
byte[] getStates() { + return _states; + } + + public boolean containsKey(final K key) { + return findKey(key) >= 0; + } + + public V get(final K key) { + final int i = findKey(key); + if (i < 0) { + return null; + } + return _values[i]; + } + + public V put(final K key, final V value) { + int hash = keyHash(key); + int keyLength = _keys.length; + int keyIdx = hash % keyLength; + + boolean expanded = preAddEntry(keyIdx); + if (expanded) { + keyLength = _keys.length; + keyIdx = hash % keyLength; + } + + K[] keys = _keys; + V[] values = _values; + byte[] states = _states; + + if (states[keyIdx] == FULL) { + if (equals(keys[keyIdx], key)) { + V old = values[keyIdx]; + values[keyIdx] = value; + return old; + } + // try second hash + int decr = 1 + (hash % (keyLength - 2)); + for (;;) { + keyIdx -= decr; + if (keyIdx < 0) { + keyIdx += keyLength; + } + if (isFree(keyIdx, key)) { + break; + } + if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) { + V old = values[keyIdx]; + values[keyIdx] = value; + return old; + } + } + } + keys[keyIdx] = key; + values[keyIdx] = value; + states[keyIdx] = FULL; + ++_used; + return null; + } + + private static boolean equals(final Object k1, final Object k2) { + return k1 == k2 || k1.equals(k2); + } + + /** Return weather the required slot is free for new entry */ + protected boolean isFree(int index, K key) { + byte stat = _states[index]; + if (stat == FREE) { + return true; + } + if (stat == REMOVED && equals(_keys[index], key)) { + return true; + } + return false; + } + + /** @return expanded or not */ + protected boolean preAddEntry(int index) { + if ((_used + 1) >= _threshold) {// filled enough + int newCapacity = Math.round(_keys.length * _growFactor); + ensureCapacity(newCapacity); + return true; + } + return false; + } + + protected int findKey(final K key) { + K[] keys = _keys; + byte[] states = _states; + int keyLength = keys.length; + + int hash = keyHash(key); + int keyIdx = hash % keyLength; + if (states[keyIdx] != 
FREE) { + if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) { + return keyIdx; + } + // try second hash + int decr = 1 + (hash % (keyLength - 2)); + for (;;) { + keyIdx -= decr; + if (keyIdx < 0) { + keyIdx += keyLength; + } + if (isFree(keyIdx, key)) { + return -1; + } + if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) { + return keyIdx; + } + } + } + return -1; + } + + public V remove(final K key) { + K[] keys = _keys; + V[] values = _values; + byte[] states = _states; + int keyLength = keys.length; + + int hash = keyHash(key); + int keyIdx = hash % keyLength; + if (states[keyIdx] != FREE) { + if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) { + V old = values[keyIdx]; + states[keyIdx] = REMOVED; + --_used; + return old; + } + // second hash + int decr = 1 + (hash % (keyLength - 2)); + for (;;) { + keyIdx -= decr; + if (keyIdx < 0) { + keyIdx += keyLength; + } + if (states[keyIdx] == FREE) { + return null; + } + if (states[keyIdx] == FULL && equals(keys[keyIdx], key)) { + V old = values[keyIdx]; + states[keyIdx] = REMOVED; + --_used; + return old; + } + } + } + return null; + } + + public int size() { + return _used; + } + + public void clear() { + Arrays.fill(_states, FREE); + this._used = 0; + } + + public IMapIterator<K, V> entries() { + return new MapIterator(); + } + + @Override + public String toString() { + int len = size() * 10 + 2; + final StringBuilder buf = new StringBuilder(len); + buf.append('{'); + final IMapIterator<K, V> i = entries(); + while (i.next() != -1) { + String key = i.getKey().toString(); + buf.append(key); + buf.append('='); + buf.append(i.getValue()); + if (i.hasNext()) { + buf.append(','); + } + } + buf.append('}'); + return buf.toString(); + } + + protected void ensureCapacity(int newCapacity) { + int prime = Primes.findLeastPrimeNumber(newCapacity); + rehash(prime); + this._threshold = Math.round(prime * _loadFactor); + } + + @SuppressWarnings("unchecked") + private void rehash(int newCapacity) { + int 
oldCapacity = _keys.length; + if (newCapacity <= oldCapacity) { + throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity); + } + final K[] newkeys = (K[]) new Object[newCapacity]; + final V[] newValues = (V[]) new Object[newCapacity]; + final byte[] newStates = new byte[newCapacity]; + int used = 0; + for (int i = 0; i < oldCapacity; i++) { + if (_states[i] == FULL) { + used++; + K k = _keys[i]; + V v = _values[i]; + int hash = keyHash(k); + int keyIdx = hash % newCapacity; + if (newStates[keyIdx] == FULL) {// second hashing + int decr = 1 + (hash % (newCapacity - 2)); + while (newStates[keyIdx] != FREE) { + keyIdx -= decr; + if (keyIdx < 0) { + keyIdx += newCapacity; + } + } + } + newStates[keyIdx] = FULL; + newkeys[keyIdx] = k; + newValues[keyIdx] = v; + } + } + this._keys = newkeys; + this._values = newValues; + this._states = newStates; + this._used = used; + } + + private static int keyHash(final Object key) { + int hash = key.hashCode(); + return hash & 0x7fffffff; + } + + private final class MapIterator implements IMapIterator<K, V> { + + int nextEntry; + int lastEntry = -1; + + MapIterator() { + this.nextEntry = nextEntry(0); + } + + /** find the index of next full entry */ + int nextEntry(int index) { + while (index < _keys.length && _states[index] != FULL) { + index++; + } + return index; + } + + public boolean hasNext() { + return nextEntry < _keys.length; + } + + public int next() { + if (!hasNext()) { + return -1; + } + int curEntry = nextEntry; + this.lastEntry = nextEntry; + this.nextEntry = nextEntry(nextEntry + 1); + return curEntry; + } + + public K getKey() { + if (lastEntry == -1) { + throw new IllegalStateException(); + } + return _keys[lastEntry]; + } + + public V getValue() { + if (lastEntry == -1) { + throw new IllegalStateException(); + } + return _values[lastEntry]; + } + + @Override + public <T extends Copyable<V>> void getValue(T probe) { + probe.copyFrom(getValue()); + } + } + + @Override + public void 
writeExternal(ObjectOutput out) throws IOException { + out.writeFloat(_loadFactor); + out.writeFloat(_growFactor); + out.writeInt(_used); + + final int size = _keys.length; + out.writeInt(size); + + for (int i = 0; i < size; i++) { + out.writeObject(_keys[i]); + out.writeObject(_values[i]); + out.writeByte(_states[i]); + } + } + + @SuppressWarnings("unchecked") + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + this._loadFactor = in.readFloat(); + this._growFactor = in.readFloat(); + this._used = in.readInt(); + + final int size = in.readInt(); + final Object[] keys = new Object[size]; + final Object[] values = new Object[size]; + final byte[] states = new byte[size]; + for (int i = 0; i < size; i++) { + keys[i] = in.readObject(); + values[i] = in.readObject(); + states[i] = in.readByte(); + } + this._threshold = size; + this._keys = (K[]) keys; + this._values = (V[]) values; + this._states = states; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/sets/IntArraySet.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/sets/IntArraySet.java b/core/src/main/java/hivemall/utils/collections/sets/IntArraySet.java new file mode 100644 index 0000000..06b6a15 --- /dev/null +++ b/core/src/main/java/hivemall/utils/collections/sets/IntArraySet.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.collections.sets; + +import hivemall.utils.lang.ArrayUtils; + +import java.util.Arrays; + +import javax.annotation.Nonnull; + +public final class IntArraySet implements IntSet { + + @Nonnull + private int[] mKeys; + private int mSize; + + public IntArraySet() { + this(0); + } + + public IntArraySet(int initSize) { + this.mKeys = new int[initSize]; + this.mSize = 0; + } + + @Override + public boolean add(final int k) { + final int i = Arrays.binarySearch(mKeys, 0, mSize, k); + if (i >= 0) { + return false; + } + mKeys = ArrayUtils.insert(mKeys, mSize, ~i, k); + mSize++; + return true; + } + + @Override + public boolean remove(final int k) { + final int i = Arrays.binarySearch(mKeys, 0, mSize, k); + if (i < 0) { + return false; + } + System.arraycopy(mKeys, i + 1, mKeys, i, mSize - (i + 1)); + mSize--; + return true; + } + + @Override + public boolean contains(final int k) { + return Arrays.binarySearch(mKeys, 0, mSize, k) >= 0; + } + + @Override + public int size() { + return mSize; + } + + @Override + public void clear() { + this.mSize = 0; + } + + @Override + public int[] toArray(final boolean copy) { + if (copy == false && mKeys.length == mSize) { + return mKeys; + } + + return Arrays.copyOf(mKeys, mSize); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/collections/sets/IntSet.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/sets/IntSet.java 
b/core/src/main/java/hivemall/utils/collections/sets/IntSet.java new file mode 100644 index 0000000..398955c --- /dev/null +++ b/core/src/main/java/hivemall/utils/collections/sets/IntSet.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.utils.collections.sets; + +import javax.annotation.Nonnull; + +public interface IntSet { + + public boolean add(int k); + + public boolean remove(int k); + + public boolean contains(int k); + + public int size(); + + public void clear(); + + @Nonnull + public int[] toArray(boolean copy); + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index 5423c9d..b3a2de1 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -56,6 +56,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.StandardConstantListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; @@ -256,6 +258,10 @@ public final class HiveUtils { && isNumberListOI(((ListObjectInspector) oi).getListElementObjectInspector()); } + public static boolean isConstString(@Nonnull final ObjectInspector oi) { + return ObjectInspectorUtils.isConstantObjectInspector(oi) && isStringOI(oi); + } + public static boolean isPrimitiveTypeInfo(@Nonnull TypeInfo typeInfo) { return typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE; } @@ 
-308,20 +314,43 @@ public final class HiveUtils { } } - public static boolean isStringTypeInfo(@Nonnull TypeInfo typeInfo) { + public static boolean isIntTypeInfo(@Nonnull TypeInfo typeInfo) { + if (typeInfo.getCategory() != ObjectInspector.Category.PRIMITIVE) { + return false; + } + return ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory() == PrimitiveCategory.INT; + } + + public static boolean isFloatingPointTypeInfo(@Nonnull TypeInfo typeInfo) { if (typeInfo.getCategory() != ObjectInspector.Category.PRIMITIVE) { return false; } switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) { - case STRING: + case DOUBLE: + case FLOAT: return true; default: return false; } } - public static boolean isConstString(@Nonnull final ObjectInspector oi) { - return ObjectInspectorUtils.isConstantObjectInspector(oi) && isStringOI(oi); + public static boolean isStringTypeInfo(@Nonnull TypeInfo typeInfo) { + if (typeInfo.getCategory() != ObjectInspector.Category.PRIMITIVE) { + return false; + } + return ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory() == PrimitiveCategory.STRING; + } + + public static boolean isListTypeInfo(@Nonnull TypeInfo typeInfo) { + return typeInfo.getCategory() == Category.LIST; + } + + public static boolean isFloatingPointListTypeInfo(@Nonnull TypeInfo typeInfo) { + if (typeInfo.getCategory() != Category.LIST) { + return false; + } + TypeInfo elemTypeInfo = ((ListTypeInfo) typeInfo).getListElementTypeInfo(); + return isFloatingPointTypeInfo(elemTypeInfo); } @Nonnull @@ -387,6 +416,38 @@ public final class HiveUtils { return ary; } + @Nullable + public static double[] getConstDoubleArray(@Nonnull final ObjectInspector oi) + throws UDFArgumentException { + if (!ObjectInspectorUtils.isConstantObjectInspector(oi)) { + throw new UDFArgumentException("argument must be a constant value: " + + TypeInfoUtils.getTypeInfoFromObjectInspector(oi)); + } + ConstantObjectInspector constOI = (ConstantObjectInspector) oi; + if (constOI.getCategory() != 
Category.LIST) { + throw new UDFArgumentException("argument must be an array: " + + TypeInfoUtils.getTypeInfoFromObjectInspector(oi)); + } + StandardConstantListObjectInspector listOI = (StandardConstantListObjectInspector) constOI; + PrimitiveObjectInspector elemOI = HiveUtils.asDoubleCompatibleOI(listOI.getListElementObjectInspector()); + + final List<?> lst = listOI.getWritableConstantValue(); + if (lst == null) { + return null; + } + final int size = lst.size(); + final double[] ary = new double[size]; + for (int i = 0; i < size; i++) { + Object o = lst.get(i); + if (o == null) { + ary[i] = Double.NaN; + } else { + ary[i] = PrimitiveObjectInspectorUtils.getDouble(o, elemOI); + } + } + return ary; + } + public static String getConstString(@Nonnull final ObjectInspector oi) throws UDFArgumentException { if (!isStringOI(oi)) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/lang/ArrayUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java index 24ed7fc..e8e337d 100644 --- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java +++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java @@ -170,6 +170,24 @@ public final class ArrayUtils { arr[j] = tmp; } + public static void swap(@Nonnull final long[] arr, final int i, final int j) { + long tmp = arr[i]; + arr[i] = arr[j]; + arr[j] = tmp; + } + + public static void swap(@Nonnull final float[] arr, final int i, final int j) { + float tmp = arr[i]; + arr[i] = arr[j]; + arr[j] = tmp; + } + + public static void swap(@Nonnull final double[] arr, final int i, final int j) { + double tmp = arr[i]; + arr[i] = arr[j]; + arr[j] = tmp; + } + @Nullable public static Object[] subarray(@Nullable final Object[] array, int startIndexInclusive, int endIndexExclusive) { @@ -198,7 +216,7 @@ public final class ArrayUtils { } } 
- public static int indexOf(@Nonnull final int[] array, final int valueToFind, + public static int indexOf(@Nullable final int[] array, final int valueToFind, final int startIndex, final int endIndex) { if (array == null) { return INDEX_NOT_FOUND; @@ -215,6 +233,36 @@ public final class ArrayUtils { return INDEX_NOT_FOUND; } + public static int lastIndexOf(@Nullable final int[] array, final int valueToFind, int startIndex) { + if (array == null) { + return INDEX_NOT_FOUND; + } + return lastIndexOf(array, valueToFind, startIndex, array.length); + } + + /** + * @param startIndex inclusive start index + * @param endIndex exclusive end index + */ + public static int lastIndexOf(@Nullable final int[] array, final int valueToFind, + int startIndex, int endIndex) { + if (array == null) { + return INDEX_NOT_FOUND; + } + if (startIndex < 0) { + throw new IllegalArgumentException("startIndex out of bound: " + startIndex); + } + if (endIndex >= array.length) { + throw new IllegalArgumentException("endIndex out of bound: " + endIndex); + } + for (int i = endIndex - 1; i >= startIndex; i--) { + if (valueToFind == array[i]) { + return i; + } + } + return INDEX_NOT_FOUND; + } + @Nonnull public static byte[] copyOf(@Nonnull final byte[] original, final int newLength) { final byte[] copy = new byte[newLength]; @@ -249,6 +297,17 @@ public final class ArrayUtils { } @Nonnull + public static float[] append(@Nonnull float[] array, final int currentSize, final float element) { + if (currentSize + 1 > array.length) { + float[] newArray = new float[currentSize * 2]; + System.arraycopy(array, 0, newArray, 0, currentSize); + array = newArray; + } + array[currentSize] = element; + return array; + } + + @Nonnull public static double[] append(@Nonnull double[] array, final int currentSize, final double element) { if (currentSize + 1 > array.length) { @@ -268,7 +327,22 @@ public final class ArrayUtils { array[index] = element; return array; } - int[] newArray = new int[currentSize * 2]; + final 
int[] newArray = new int[currentSize * 2]; + System.arraycopy(array, 0, newArray, 0, index); + newArray[index] = element; + System.arraycopy(array, index, newArray, index + 1, array.length - index); + return newArray; + } + + @Nonnull + public static float[] insert(@Nonnull final float[] array, final int currentSize, + final int index, final float element) { + if (currentSize + 1 <= array.length) { + System.arraycopy(array, index, array, index + 1, currentSize - index); + array[index] = element; + return array; + } + final float[] newArray = new float[currentSize * 2]; System.arraycopy(array, 0, newArray, 0, index); newArray[index] = element; System.arraycopy(array, index, newArray, index + 1, array.length - index); @@ -283,7 +357,7 @@ public final class ArrayUtils { array[index] = element; return array; } - double[] newArray = new double[currentSize * 2]; + final double[] newArray = new double[currentSize * 2]; System.arraycopy(array, 0, newArray, 0, index); newArray[index] = element; System.arraycopy(array, index, newArray, index + 1, array.length - index); @@ -314,4 +388,331 @@ public final class ArrayUtils { return true; } + public static void copy(@Nonnull final float[] src, @Nonnull final double[] dst) { + final int size = Math.min(src.length, dst.length); + for (int i = 0; i < size; i++) { + dst[i] = src[i]; + } + } + + public static void sort(final long[] arr, final double[] brr) { + sort(arr, brr, arr.length); + } + + public static void sort(final long[] arr, final double[] brr, final int n) { + final int NSTACK = 64; + final int M = 7; + final int[] istack = new int[NSTACK]; + + int jstack = -1; + int l = 0; + int ir = n - 1; + + int i, j, k; + long a; + double b; + for (;;) { + if (ir - l < M) { + for (j = l + 1; j <= ir; j++) { + a = arr[j]; + b = brr[j]; + for (i = j - 1; i >= l; i--) { + if (arr[i] <= a) { + break; + } + arr[i + 1] = arr[i]; + brr[i + 1] = brr[i]; + } + arr[i + 1] = a; + brr[i + 1] = b; + } + if (jstack < 0) { + break; + } + ir = 
istack[jstack--]; + l = istack[jstack--]; + } else { + k = (l + ir) >> 1; + swap(arr, k, l + 1); + swap(brr, k, l + 1); + if (arr[l] > arr[ir]) { + swap(arr, l, ir); + swap(brr, l, ir); + } + if (arr[l + 1] > arr[ir]) { + swap(arr, l + 1, ir); + swap(brr, l + 1, ir); + } + if (arr[l] > arr[l + 1]) { + swap(arr, l, l + 1); + swap(brr, l, l + 1); + } + i = l + 1; + j = ir; + a = arr[l + 1]; + b = brr[l + 1]; + for (;;) { + do { + i++; + } while (arr[i] < a); + do { + j--; + } while (arr[j] > a); + if (j < i) { + break; + } + swap(arr, i, j); + swap(brr, i, j); + } + arr[l + 1] = arr[j]; + arr[j] = a; + brr[l + 1] = brr[j]; + brr[j] = b; + jstack += 2; + + if (jstack >= NSTACK) { + throw new IllegalStateException("NSTACK too small in sort."); + } + + if (ir - i + 1 >= j - l) { + istack[jstack] = ir; + istack[jstack - 1] = i; + ir = j - 1; + } else { + istack[jstack] = j - 1; + istack[jstack - 1] = l; + l = i; + } + } + } + } + + public static void sort(@Nonnull final int[] arr, @Nonnull final int[] brr, + @Nonnull final double[] crr) { + sort(arr, brr, crr, arr.length); + } + + public static void sort(@Nonnull final int[] arr, @Nonnull final int[] brr, + @Nonnull final double[] crr, final int n) { + Preconditions.checkArgument(arr.length >= n); + Preconditions.checkArgument(brr.length >= n); + Preconditions.checkArgument(crr.length >= n); + + final int NSTACK = 64; + final int M = 7; + final int[] istack = new int[NSTACK]; + + int jstack = -1; + int l = 0; + int ir = n - 1; + + int i, j, k; + int a, b; + double c; + for (;;) { + if (ir - l < M) { + for (j = l + 1; j <= ir; j++) { + a = arr[j]; + b = brr[j]; + c = crr[j]; + for (i = j - 1; i >= l; i--) { + if (arr[i] <= a) { + break; + } + arr[i + 1] = arr[i]; + brr[i + 1] = brr[i]; + crr[i + 1] = crr[i]; + } + arr[i + 1] = a; + brr[i + 1] = b; + crr[i + 1] = c; + } + if (jstack < 0) { + break; + } + ir = istack[jstack--]; + l = istack[jstack--]; + } else { + k = (l + ir) >> 1; + swap(arr, k, l + 1); + swap(brr, k, l + 
1); + swap(crr, k, l + 1); + if (arr[l] > arr[ir]) { + swap(arr, l, ir); + swap(brr, l, ir); + swap(crr, l, ir); + } + if (arr[l + 1] > arr[ir]) { + swap(arr, l + 1, ir); + swap(brr, l + 1, ir); + swap(crr, l + 1, ir); + } + if (arr[l] > arr[l + 1]) { + swap(arr, l, l + 1); + swap(brr, l, l + 1); + swap(crr, l, l + 1); + } + i = l + 1; + j = ir; + a = arr[l + 1]; + b = brr[l + 1]; + c = crr[l + 1]; + for (;;) { + do { + i++; + } while (arr[i] < a); + do { + j--; + } while (arr[j] > a); + if (j < i) { + break; + } + swap(arr, i, j); + swap(brr, i, j); + swap(crr, i, j); + } + arr[l + 1] = arr[j]; + arr[j] = a; + brr[l + 1] = brr[j]; + brr[j] = b; + crr[l + 1] = crr[j]; + crr[j] = c; + jstack += 2; + + if (jstack >= NSTACK) { + throw new IllegalStateException("NSTACK too small in sort."); + } + + if (ir - i + 1 >= j - l) { + istack[jstack] = ir; + istack[jstack - 1] = i; + ir = j - 1; + } else { + istack[jstack] = j - 1; + istack[jstack - 1] = l; + l = i; + } + } + } + } + + public static void sort(@Nonnull final int[] arr, @Nonnull final int[] brr, + @Nonnull final float[] crr) { + sort(arr, brr, crr, arr.length); + } + + public static void sort(@Nonnull final int[] arr, @Nonnull final int[] brr, + @Nonnull final float[] crr, final int n) { + Preconditions.checkArgument(arr.length >= n); + Preconditions.checkArgument(brr.length >= n); + Preconditions.checkArgument(crr.length >= n); + + final int NSTACK = 64; + final int M = 7; + final int[] istack = new int[NSTACK]; + + int jstack = -1; + int l = 0; + int ir = n - 1; + + int i, j, k; + int a, b; + float c; + for (;;) { + if (ir - l < M) { + for (j = l + 1; j <= ir; j++) { + a = arr[j]; + b = brr[j]; + c = crr[j]; + for (i = j - 1; i >= l; i--) { + if (arr[i] <= a) { + break; + } + arr[i + 1] = arr[i]; + brr[i + 1] = brr[i]; + crr[i + 1] = crr[i]; + } + arr[i + 1] = a; + brr[i + 1] = b; + crr[i + 1] = c; + } + if (jstack < 0) { + break; + } + ir = istack[jstack--]; + l = istack[jstack--]; + } else { + k = (l + ir) >> 
1; + swap(arr, k, l + 1); + swap(brr, k, l + 1); + swap(crr, k, l + 1); + if (arr[l] > arr[ir]) { + swap(arr, l, ir); + swap(brr, l, ir); + swap(crr, l, ir); + } + if (arr[l + 1] > arr[ir]) { + swap(arr, l + 1, ir); + swap(brr, l + 1, ir); + swap(crr, l + 1, ir); + } + if (arr[l] > arr[l + 1]) { + swap(arr, l, l + 1); + swap(brr, l, l + 1); + swap(crr, l, l + 1); + } + i = l + 1; + j = ir; + a = arr[l + 1]; + b = brr[l + 1]; + c = crr[l + 1]; + for (;;) { + do { + i++; + } while (arr[i] < a); + do { + j--; + } while (arr[j] > a); + if (j < i) { + break; + } + swap(arr, i, j); + swap(brr, i, j); + swap(crr, i, j); + } + arr[l + 1] = arr[j]; + arr[j] = a; + brr[l + 1] = brr[j]; + brr[j] = b; + crr[l + 1] = crr[j]; + crr[j] = c; + jstack += 2; + + if (jstack >= NSTACK) { + throw new IllegalStateException("NSTACK too small in sort."); + } + + if (ir - i + 1 >= j - l) { + istack[jstack] = ir; + istack[jstack - 1] = i; + ir = j - 1; + } else { + istack[jstack] = j - 1; + istack[jstack - 1] = l; + l = i; + } + } + } + } + + public static int count(@Nonnull final int[] values, final int valueToFind) { + int cnt = 0; + for (int i = 0; i < values.length; i++) { + if (values[i] == valueToFind) { + cnt++; + } + } + return cnt; + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/lang/Primitives.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/Primitives.java b/core/src/main/java/hivemall/utils/lang/Primitives.java index 8f018f0..31cd8a8 100644 --- a/core/src/main/java/hivemall/utils/lang/Primitives.java +++ b/core/src/main/java/hivemall/utils/lang/Primitives.java @@ -18,6 +18,8 @@ */ package hivemall.utils.lang; +import javax.annotation.Nonnull; + public final class Primitives { public static final int INT_BYTES = Integer.SIZE / Byte.SIZE; public static final int DOUBLE_BYTES = Double.SIZE / Byte.SIZE; @@ -99,4 +101,30 @@ public final 
class Primitives { return result; } + public static long toLong(final int high, final int low) { + return ((long) high << 32) | ((long) low & 0xffffffffL); + } + + public static int getHigh(final long key) { + return (int) (key >>> 32) & 0xffffffff; + } + + public static int getLow(final long key) { + return (int) key & 0xffffffff; + } + + @Nonnull + public static byte[] toBytes(long l) { + final byte[] retVal = new byte[8]; + for (int i = 0; i < 8; i++) { + retVal[i] = (byte) l; + l >>= 8; + } + return retVal; + } + + public static int hashCode(final long value) { + return (int) (value ^ (value >>> 32)); + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/math/MathUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java index 252ccf6..b71d165 100644 --- a/core/src/main/java/hivemall/utils/math/MathUtils.java +++ b/core/src/main/java/hivemall/utils/math/MathUtils.java @@ -36,6 +36,7 @@ package hivemall.utils.math; import java.util.Random; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; public final class MathUtils { @@ -250,6 +251,9 @@ public final class MathUtils { } public static boolean equals(@Nonnull final float value, final float expected, final float delta) { + if (Double.isNaN(value)) { + return false; + } if (Math.abs(expected - value) > delta) { return false; } @@ -258,19 +262,20 @@ public final class MathUtils { public static boolean equals(@Nonnull final double value, final double expected, final double delta) { + if (Double.isNaN(value)) { + return false; + } if (Math.abs(expected - value) > delta) { return false; } return true; } - public static boolean almostEquals(@Nonnull final float value, final float expected, - final float delta) { + public static boolean almostEquals(@Nonnull final float value, final float expected) 
{ return equals(value, expected, 1E-15f); } - public static boolean almostEquals(@Nonnull final double value, final double expected, - final double delta) { + public static boolean almostEquals(@Nonnull final double value, final double expected) { return equals(value, expected, 1E-15d); } @@ -297,4 +302,13 @@ public final class MathUtils { return 0; // 0 or NaN } + @Nonnull + public static int[] permutation(@Nonnegative final int size) { + final int[] perm = new int[size]; + for (int i = 0; i < size; i++) { + perm[i] = i; + } + return perm; + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/math/MatrixUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/MatrixUtils.java b/core/src/main/java/hivemall/utils/math/MatrixUtils.java index 66d6e8c..a0e5fc7 100644 --- a/core/src/main/java/hivemall/utils/math/MatrixUtils.java +++ b/core/src/main/java/hivemall/utils/math/MatrixUtils.java @@ -18,7 +18,7 @@ */ package hivemall.utils.math; -import hivemall.utils.collections.DoubleArrayList; +import hivemall.utils.collections.lists.DoubleArrayList; import hivemall.utils.lang.Preconditions; import java.util.Arrays; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/utils/sampling/IntReservoirSampler.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/sampling/IntReservoirSampler.java b/core/src/main/java/hivemall/utils/sampling/IntReservoirSampler.java new file mode 100644 index 0000000..f86a788 --- /dev/null +++ b/core/src/main/java/hivemall/utils/sampling/IntReservoirSampler.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.sampling; + +import java.util.Arrays; +import java.util.Random; + +import javax.annotation.Nonnull; + +/** + * Vitter's reservoir sampling implementation that randomly chooses k items from a list containing n items. + * + * @link http://en.wikipedia.org/wiki/Reservoir_sampling + * @link http://portal.acm.org/citation.cfm?id=3165 + */ +public final class IntReservoirSampler { + + private final int[] samples; + private final int numSamples; + private int position; + + private final Random rand; + + public IntReservoirSampler(int sampleSize) { + if (sampleSize <= 0) { + throw new IllegalArgumentException("sampleSize must be greater than 1: " + sampleSize); + } + this.samples = new int[sampleSize]; + this.numSamples = sampleSize; + this.position = 0; + this.rand = new Random(); + } + + public IntReservoirSampler(int sampleSize, long seed) { + this.samples = new int[sampleSize]; + this.numSamples = sampleSize; + this.position = 0; + this.rand = new Random(seed); + } + + public IntReservoirSampler(int[] samples) { + this.samples = samples; + this.numSamples = samples.length; + this.position = 0; + this.rand = new Random(); + } + + public IntReservoirSampler(int[] samples, long seed) { + this.samples = samples; + this.numSamples = samples.length; + this.position = 0; + this.rand = new Random(seed); + } + + public int size() { + return position; + } + + @Nonnull + 
public int[] getSample() { + if (position >= numSamples) { + return samples; + } + return Arrays.copyOf(samples, position); + } + + public void add(final int item) { + if (position < numSamples) {// reservoir not yet full, just append + samples[position] = item; + } else {// find a item to replace + int replaceIndex = rand.nextInt(position + 1); + if (replaceIndex < numSamples) { + samples[replaceIndex] = item; + } + } + position++; + } + + public void clear() { + Arrays.fill(samples, 0); + this.position = 0; + } +}