http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/storage/SparseMatrixStorage.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/storage/SparseMatrixStorage.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/storage/SparseMatrixStorage.java new file mode 100644 index 0000000..fb2f3d9 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/storage/SparseMatrixStorage.java @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.math.primitives.matrix.storage; + +import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap; +import it.unimi.dsi.fastutil.ints.Int2DoubleRBTreeMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap; +import it.unimi.dsi.fastutil.ints.IntSet; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.math.primitives.matrix.MatrixStorage; +import org.apache.ignite.ml.math.StorageConstants; +import org.apache.ignite.ml.math.functions.IgniteTriFunction; + +/** + * Storage for sparse, local, on-heap matrix. + */ +public class SparseMatrixStorage implements MatrixStorage, StorageConstants { + /** Default zero value. */ + private static final double DEFAULT_VALUE = 0.0; + /** */ + private int rows; + /** */ + private int cols; + /** */ + private int acsMode; + /** */ + private int stoMode; + + /** Actual map storage. */ + private Map<Integer, Map<Integer, Double>> sto; + + /** */ + public SparseMatrixStorage() { + // No-op. + } + + /** */ + public SparseMatrixStorage(int rows, int cols, int acsMode, int stoMode) { + assert rows > 0; + assert cols > 0; + assertAccessMode(acsMode); + assertStorageMode(stoMode); + + this.rows = rows; + this.cols = cols; + this.acsMode = acsMode; + this.stoMode = stoMode; + + sto = new HashMap<>(); + } + + /** + * @return Matrix elements storage mode. 
+ */ + public int storageMode() { + return stoMode; + } + + /** {@inheritDoc} */ + @Override public int accessMode() { + return acsMode; + } + + /** {@inheritDoc} */ + @Override public double get(int x, int y) { + if (stoMode == ROW_STORAGE_MODE) { + Map<Integer, Double> row = sto.get(x); + + if (row != null) { + Double val = row.get(y); + + if (val != null) + return val; + } + + return DEFAULT_VALUE; + } + else { + Map<Integer, Double> col = sto.get(y); + + if (col != null) { + Double val = col.get(x); + + if (val != null) + return val; + } + + return DEFAULT_VALUE; + } + } + + /** {@inheritDoc} */ + @Override public void set(int x, int y, double v) { + // Ignore default values (currently 0.0). + if (v != DEFAULT_VALUE) { + if (stoMode == ROW_STORAGE_MODE) { + Map<Integer, Double> row = sto.computeIfAbsent(x, k -> + acsMode == SEQUENTIAL_ACCESS_MODE ? new Int2DoubleRBTreeMap() : new Int2DoubleOpenHashMap()); + + row.put(y, v); + } + else { + Map<Integer, Double> col = sto.computeIfAbsent(y, k -> + acsMode == SEQUENTIAL_ACCESS_MODE ? 
new Int2DoubleRBTreeMap() : new Int2DoubleOpenHashMap()); + + col.put(x, v); + } + } + else { + if (stoMode == ROW_STORAGE_MODE) { + if (sto.containsKey(x)) { + Map<Integer, Double> row = sto.get(x); + + if (row.containsKey(y)) + row.remove(y); + } + + } + else { + if (sto.containsKey(y)) { + Map<Integer, Double> col = sto.get(y); + + if (col.containsKey(x)) + col.remove(x); + } + } + } + } + + /** {@inheritDoc} */ + @Override public int columnSize() { + return cols; + } + + /** {@inheritDoc} */ + @Override public int rowSize() { + return rows; + } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + out.writeInt(rows); + out.writeInt(cols); + out.writeInt(acsMode); + out.writeInt(stoMode); + out.writeObject(sto); + } + + /** {@inheritDoc} */ + @SuppressWarnings({"unchecked"}) + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + rows = in.readInt(); + cols = in.readInt(); + acsMode = in.readInt(); + stoMode = in.readInt(); + sto = (Map<Integer, Map<Integer, Double>>)in.readObject(); + } + + /** {@inheritDoc} */ + @Override public boolean isSequentialAccess() { + return acsMode == SEQUENTIAL_ACCESS_MODE; + } + + /** {@inheritDoc} */ + @Override public boolean isDense() { + return false; + } + + /** {@inheritDoc} */ + @Override public boolean isRandomAccess() { + return acsMode == RANDOM_ACCESS_MODE; + } + + /** {@inheritDoc} */ + @Override public boolean isDistributed() { + return false; + } + + /** {@inheritDoc} */ + @Override public boolean isArrayBased() { + return false; + } + + // TODO: IGNITE-5777, optimize this + + /** {@inheritDoc} */ + @Override public double[] data() { + double[] res = new double[rows * cols]; + + boolean isRowStorage = stoMode == ROW_STORAGE_MODE; + + sto.forEach((fstIdx, map) -> + map.forEach((sndIdx, val) -> { + if (isRowStorage) + res[sndIdx * rows + fstIdx] = val; + else + res[fstIdx * cols + sndIdx] = val; + + })); + + return res; + } + 
+ /** {@inheritDoc} */ + @Override public int hashCode() { + int res = 1; + + res = res * 37 + rows; + res = res * 37 + cols; + res = res * 37 + sto.hashCode(); + + return res; + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + + if (o == null || getClass() != o.getClass()) + return false; + + SparseMatrixStorage that = (SparseMatrixStorage)o; + + return rows == that.rows && cols == that.cols && acsMode == that.acsMode && stoMode == that.stoMode + && (sto != null ? sto.equals(that.sto) : that.sto == null); + } + + /** */ + public void compute(int row, int col, IgniteTriFunction<Integer, Integer, Double, Double> f) { + sto.get(row).compute(col, (c, val) -> f.apply(row, c, val)); + } + + /** */ + public Int2ObjectArrayMap<IntSet> indexesMap() { + Int2ObjectArrayMap<IntSet> res = new Int2ObjectArrayMap<>(); + + for (Integer row : sto.keySet()) + res.put(row.intValue(), (IntSet)sto.get(row).keySet()); + + return res; + } +}
http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/storage/ViewMatrixStorage.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/storage/ViewMatrixStorage.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/storage/ViewMatrixStorage.java new file mode 100644 index 0000000..2c3ba9a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/storage/ViewMatrixStorage.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.primitives.matrix.storage; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import org.apache.ignite.ml.math.primitives.matrix.MatrixStorage; + +/** + * {@link MatrixStorage} implementation that delegates to parent matrix. + */ +public class ViewMatrixStorage implements MatrixStorage { + /** Parent matrix storage. */ + private MatrixStorage dlg; + + /** Row offset in the parent matrix. */ + private int rowOff; + /** Column offset in the parent matrix. 
*/ + private int colOff; + + /** Amount of rows in the matrix. */ + private int rows; + /** Amount of columns in the matrix. */ + private int cols; + + /** + * + */ + public ViewMatrixStorage() { + // No-op. + } + + /** + * @param dlg Backing parent storage. + * @param rowOff Row offset to parent matrix. + * @param colOff Column offset to parent matrix. + * @param rows Amount of rows in the view. + * @param cols Amount of columns in the view. + */ + public ViewMatrixStorage(MatrixStorage dlg, int rowOff, int colOff, int rows, int cols) { + assert dlg != null; + assert rowOff >= 0; + assert colOff >= 0; + assert rows > 0; + assert cols > 0; + + this.dlg = dlg; + + this.rowOff = rowOff; + this.colOff = colOff; + + this.rows = rows; + this.cols = cols; + } + + /** + * + */ + public MatrixStorage delegate() { + return dlg; + } + + /** + * + */ + public int rowOffset() { + return rowOff; + } + + /** + * + */ + public int columnOffset() { + return colOff; + } + + /** + * + */ + public int rowsLength() { + return rows; + } + + /** + * + */ + public int columnsLength() { + return cols; + } + + /** {@inheritDoc} */ + @Override public double get(int x, int y) { + return dlg.get(rowOff + x, colOff + y); + } + + /** {@inheritDoc} */ + @Override public void set(int x, int y, double v) { + dlg.set(rowOff + x, colOff + y, v); + } + + /** {@inheritDoc} */ + @Override public int columnSize() { + return cols; + } + + /** {@inheritDoc} */ + @Override public int rowSize() { + return rows; + } + + /** {@inheritDoc} */ + @Override public int storageMode() { + return dlg.storageMode(); + } + + /** {@inheritDoc} */ + @Override public int accessMode() { + return dlg.accessMode(); + } + + /** {@inheritDoc} */ + @Override public boolean isArrayBased() { + return dlg.isArrayBased() && rowOff == 0 && colOff == 0; + } + + /** {@inheritDoc} */ + @Override public boolean isSequentialAccess() { + return dlg.isSequentialAccess(); + } + + /** {@inheritDoc} */ + @Override public boolean isDense() { + 
return dlg.isDense(); + } + + /** {@inheritDoc} */ + @Override public boolean isRandomAccess() { + return dlg.isRandomAccess(); + } + + /** {@inheritDoc} */ + @Override public boolean isDistributed() { + return dlg.isDistributed(); + } + + /** {@inheritDoc} */ + @Override public double[] data() { + return dlg.data(); + } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(dlg); + + out.writeInt(rowOff); + out.writeInt(colOff); + + out.writeInt(rows); + out.writeInt(cols); + } + + /** {@inheritDoc} */ + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + dlg = (MatrixStorage)in.readObject(); + + rowOff = in.readInt(); + colOff = in.readInt(); + + rows = in.readInt(); + cols = in.readInt(); + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + int res = 1; + + res = res * 37 + rows; + res = res * 37 + cols; + res = res * 37 + rowOff; + res = res * 37 + colOff; + res = res * 37 + dlg.hashCode(); + + return res; + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + + if (o == null || getClass() != o.getClass()) + return false; + + ViewMatrixStorage that = (ViewMatrixStorage)o; + + return rows == that.rows && cols == that.cols && rowOff == that.rowOff && colOff == that.colOff && + (dlg != null ? 
dlg.equals(that.dlg) : that.dlg == null); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/storage/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/storage/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/storage/package-info.java new file mode 100644 index 0000000..6e12073 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/matrix/storage/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains several matrix storages. 
+ */ +package org.apache.ignite.ml.math.primitives.matrix.storage; http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/package-info.java new file mode 100644 index 0000000..fbcaa2e --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains classes for vector/matrix algebra. 
+ */ +package org.apache.ignite.ml.math.primitives; http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/AbstractVector.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/AbstractVector.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/AbstractVector.java new file mode 100644 index 0000000..01b630e --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/AbstractVector.java @@ -0,0 +1,914 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.math.primitives.vector; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Spliterator; +import java.util.function.Consumer; +import java.util.function.IntToDoubleFunction; +import org.apache.ignite.lang.IgniteUuid; +import org.apache.ignite.ml.math.primitives.matrix.Matrix; +import org.apache.ignite.ml.math.exceptions.CardinalityException; +import org.apache.ignite.ml.math.exceptions.IndexException; +import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException; +import org.apache.ignite.ml.math.functions.Functions; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteDoubleFunction; +import org.apache.ignite.ml.math.functions.IgniteIntDoubleToDoubleBiFunction; +import org.apache.ignite.ml.math.primitives.matrix.impl.ViewMatrix; +import org.apache.ignite.ml.math.primitives.vector.impl.VectorView; +import org.jetbrains.annotations.NotNull; + +/** + * This class provides a helper implementation of the {@link Vector} + * interface to minimize the effort required to implement it. + * Subclasses may override some of the implemented methods if a more + * specific or optimized implementation is desirable. + */ +public abstract class AbstractVector implements Vector { + /** Vector storage implementation. */ + private VectorStorage sto; + + /** Meta attribute storage. */ + private Map<String, Object> meta = new HashMap<>(); + + /** Vector's GUID. */ + private IgniteUuid guid = IgniteUuid.randomUuid(); + + /** Cached value for length squared. */ + private double lenSq = 0.0; + + /** Maximum cached element. */ + private Element maxElm = null; + /** Minimum cached element. */ + private Element minElm = null; + + /** Readonly flag (false by default). 
*/ + private boolean readOnly = false; + + /** Read-only error message. */ + private static final String RO_MSG = "Vector is read-only."; + + /** */ + private void ensureReadOnly() { + if (readOnly) + throw new UnsupportedOperationException(RO_MSG); + } + + /** + * @param sto Storage. + */ + public AbstractVector(VectorStorage sto) { + this(false, sto); + } + + /** + * @param readOnly Is read only. + * @param sto Storage. + */ + public AbstractVector(boolean readOnly, VectorStorage sto) { + assert sto != null; + + this.readOnly = readOnly; + this.sto = sto; + } + + /** + * + */ + public AbstractVector() { + // No-op. + } + + /** + * Set storage. + * + * @param sto Storage. + */ + protected void setStorage(VectorStorage sto) { + this.sto = sto; + } + + /** + * @param i Index. + * @param v Value. + */ + protected void storageSet(int i, double v) { + ensureReadOnly(); + + sto.set(i, v); + + // Reset cached values. + lenSq = 0.0; + maxElm = minElm = null; + } + + /** + * @param i Index. + * @return Value. + */ + protected double storageGet(int i) { + return sto.get(i); + } + + /** {@inheritDoc} */ + @Override public int size() { + return sto.size(); + } + + /** + * Check index bounds. + * + * @param idx Index to check. 
+ */ + protected void checkIndex(int idx) { + if (idx < 0 || idx >= sto.size()) + throw new IndexException(idx); + } + + /** {@inheritDoc} */ + @Override public double get(int idx) { + checkIndex(idx); + + return storageGet(idx); + } + + /** {@inheritDoc} */ + @Override public double getX(int idx) { + return storageGet(idx); + } + + /** {@inheritDoc} */ + @Override public boolean isArrayBased() { + return sto.isArrayBased(); + } + + /** {@inheritDoc} */ + @Override public Vector sort() { + if (isArrayBased()) + Arrays.parallelSort(sto.data()); + else + throw new UnsupportedOperationException(); + + return this; + } + + /** {@inheritDoc} */ + @Override public Vector map(IgniteDoubleFunction<Double> fun) { + if (sto.isArrayBased()) { + double[] data = sto.data(); + + Arrays.setAll(data, (idx) -> fun.apply(data[idx])); + } + else { + int len = size(); + + for (int i = 0; i < len; i++) + storageSet(i, fun.apply(storageGet(i))); + } + + return this; + } + + /** {@inheritDoc} */ + @Override public Vector map(Vector vec, IgniteBiFunction<Double, Double, Double> fun) { + checkCardinality(vec); + + int len = size(); + + for (int i = 0; i < len; i++) + storageSet(i, fun.apply(storageGet(i), vec.get(i))); + + return this; + } + + /** {@inheritDoc} */ + @Override public Vector map(IgniteBiFunction<Double, Double, Double> fun, double y) { + int len = size(); + + for (int i = 0; i < len; i++) + storageSet(i, fun.apply(storageGet(i), y)); + + return this; + } + + /** + * @param idx Index. + * @return Value. 
+ */ + protected Element makeElement(int idx) { + checkIndex(idx); + + return new Element() { + /** {@inheritDoc} */ + @Override public double get() { + return storageGet(idx); + } + + /** {@inheritDoc} */ + @Override public int index() { + return idx; + } + + /** {@inheritDoc} */ + @Override public void set(double val) { + storageSet(idx, val); + } + }; + } + + /** {@inheritDoc} */ + @Override public Element minElement() { + if (minElm == null) { + int minIdx = 0; + int len = size(); + + for (int i = 0; i < len; i++) + if (storageGet(i) < storageGet(minIdx)) + minIdx = i; + + minElm = makeElement(minIdx); + } + + return minElm; + } + + /** {@inheritDoc} */ + @Override public Element maxElement() { + if (maxElm == null) { + int maxIdx = 0; + int len = size(); + + for (int i = 0; i < len; i++) + if (storageGet(i) > storageGet(maxIdx)) + maxIdx = i; + + maxElm = makeElement(maxIdx); + } + + return maxElm; + } + + /** {@inheritDoc} */ + @Override public double minValue() { + return minElement().get(); + } + + /** {@inheritDoc} */ + @Override public double maxValue() { + return maxElement().get(); + } + + /** {@inheritDoc} */ + @Override public Vector set(int idx, double val) { + checkIndex(idx); + + storageSet(idx, val); + + return this; + } + + /** {@inheritDoc} */ + @Override public Vector setX(int idx, double val) { + storageSet(idx, val); + + return this; + } + + /** {@inheritDoc} */ + @Override public Vector increment(int idx, double val) { + checkIndex(idx); + + storageSet(idx, storageGet(idx) + val); + + return this; + } + + /** {@inheritDoc} */ + @Override public Vector incrementX(int idx, double val) { + storageSet(idx, storageGet(idx) + val); + + return this; + } + + /** + * Tests if given value is considered a zero value. + * + * @param val Value to check. 
+ */ + protected boolean isZero(double val) { + return val == 0.0; + } + + /** {@inheritDoc} */ + @Override public double sum() { + double sum = 0; + int len = size(); + + for (int i = 0; i < len; i++) + sum += storageGet(i); + + return sum; + } + + /** {@inheritDoc} */ + @Override public IgniteUuid guid() { + return guid; + } + + /** {@inheritDoc} */ + @Override public Iterable<Element> all() { + return new Iterable<Element>() { + private int idx = 0; + + /** {@inheritDoc} */ + @NotNull + @Override public Iterator<Element> iterator() { + return new Iterator<Element>() { + /** {@inheritDoc} */ + @Override public boolean hasNext() { + return size() > 0 && idx < size(); + } + + /** {@inheritDoc} */ + @Override public Element next() { + if (hasNext()) + return getElement(idx++); + + throw new NoSuchElementException(); + } + }; + } + }; + } + + /** {@inheritDoc} */ + @Override public int nonZeroElements() { + int cnt = 0; + + for (Element ignored : nonZeroes()) + cnt++; + + return cnt; + } + + /** {@inheritDoc} */ + @Override public <T> T foldMap(IgniteBiFunction<T, Double, T> foldFun, IgniteDoubleFunction<Double> mapFun, + T zeroVal) { + T res = zeroVal; + int len = size(); + + for (int i = 0; i < len; i++) + res = foldFun.apply(res, mapFun.apply(storageGet(i))); + + return res; + } + + /** {@inheritDoc} */ + @Override public <T> T foldMap(Vector vec, IgniteBiFunction<T, Double, T> foldFun, + IgniteBiFunction<Double, Double, Double> combFun, T zeroVal) { + checkCardinality(vec); + + T res = zeroVal; + int len = size(); + + for (int i = 0; i < len; i++) + res = foldFun.apply(res, combFun.apply(storageGet(i), vec.getX(i))); + + return res; + } + + /** {@inheritDoc} */ + @Override public Iterable<Element> nonZeroes() { + return new Iterable<Element>() { + private int idx = 0; + private int idxNext = -1; + + /** {@inheritDoc} */ + @NotNull + @Override public Iterator<Element> iterator() { + return new Iterator<Element>() { + @Override public boolean hasNext() { + 
findNext(); + + return !over(); + } + + @Override public Element next() { + if (hasNext()) { + idx = idxNext; + + return getElement(idxNext); + } + + throw new NoSuchElementException(); + } + + private void findNext() { + if (over()) + return; + + if (idxNextInitialized() && idx != idxNext) + return; + + if (idxNextInitialized()) + idx = idxNext + 1; + + while (idx < size() && isZero(get(idx))) + idx++; + + idxNext = idx++; + } + + private boolean over() { + return idxNext >= size(); + } + + private boolean idxNextInitialized() { + return idxNext != -1; + } + }; + } + }; + } + + /** {@inheritDoc} */ + @Override public Map<String, Object> getMetaStorage() { + return meta; + } + + /** {@inheritDoc} */ + @Override public Vector assign(double val) { + if (sto.isArrayBased()) { + ensureReadOnly(); + + Arrays.fill(sto.data(), val); + } + else { + int len = size(); + + for (int i = 0; i < len; i++) + storageSet(i, val); + } + + return this; + } + + /** {@inheritDoc} */ + @Override public Vector assign(double[] vals) { + checkCardinality(vals); + + if (sto.isArrayBased()) { + ensureReadOnly(); + + System.arraycopy(vals, 0, sto.data(), 0, vals.length); + + lenSq = 0.0; + } + else { + int len = size(); + + for (int i = 0; i < len; i++) + storageSet(i, vals[i]); + } + + return this; + } + + /** {@inheritDoc} */ + @Override public Vector assign(Vector vec) { + checkCardinality(vec); + + for (Vector.Element x : vec.all()) + storageSet(x.index(), x.get()); + + return this; + } + + /** {@inheritDoc} */ + @Override public Vector assign(IntToDoubleFunction fun) { + assert fun != null; + + if (sto.isArrayBased()) { + ensureReadOnly(); + + Arrays.setAll(sto.data(), fun); + } + else { + int len = size(); + + for (int i = 0; i < len; i++) + storageSet(i, fun.applyAsDouble(i)); + } + + return this; + } + + /** {@inheritDoc} */ + @Override public Spliterator<Double> allSpliterator() { + return new Spliterator<Double>() { + /** {@inheritDoc} */ + @Override public boolean 
tryAdvance(Consumer<? super Double> act) { + int len = size(); + + for (int i = 0; i < len; i++) + act.accept(storageGet(i)); + + return true; + } + + /** {@inheritDoc} */ + @Override public Spliterator<Double> trySplit() { + return null; // No Splitting. + } + + /** {@inheritDoc} */ + @Override public long estimateSize() { + return size(); + } + + /** {@inheritDoc} */ + @Override public int characteristics() { + return ORDERED | SIZED; + } + }; + } + + /** {@inheritDoc} */ + @Override public Spliterator<Double> nonZeroSpliterator() { + return new Spliterator<Double>() { + /** {@inheritDoc} */ + @Override public boolean tryAdvance(Consumer<? super Double> act) { + int len = size(); + + for (int i = 0; i < len; i++) { + double val = storageGet(i); + + if (!isZero(val)) + act.accept(val); + } + + return true; + } + + /** {@inheritDoc} */ + @Override public Spliterator<Double> trySplit() { + return null; // No Splitting. + } + + /** {@inheritDoc} */ + @Override public long estimateSize() { + return nonZeroElements(); + } + + /** {@inheritDoc} */ + @Override public int characteristics() { + return ORDERED | SIZED; + } + }; + } + + /** {@inheritDoc} */ + @Override public double dot(Vector vec) { + checkCardinality(vec); + + double sum = 0.0; + int len = size(); + + for (int i = 0; i < len; i++) + sum += storageGet(i) * vec.getX(i); + + return sum; + } + + /** {@inheritDoc} */ + @Override public double getLengthSquared() { + if (lenSq == 0.0) + lenSq = dotSelf(); + + return lenSq; + } + + /** {@inheritDoc} */ + @Override public boolean isDense() { + return sto.isDense(); + } + + /** {@inheritDoc} */ + @Override public boolean isSequentialAccess() { + return sto.isSequentialAccess(); + } + + /** {@inheritDoc} */ + @Override public boolean isRandomAccess() { + return sto.isRandomAccess(); + } + + /** {@inheritDoc} */ + @Override public boolean isDistributed() { + return sto.isDistributed(); + } + + /** {@inheritDoc} */ + @Override public VectorStorage getStorage() { + 
return sto; + } + + /** {@inheritDoc} */ + @Override public Vector viewPart(int off, int len) { + return new VectorView(this, off, len); + } + + /** {@inheritDoc} */ + @Override public Matrix cross(Vector vec) { + Matrix res = likeMatrix(size(), vec.size()); + + if (res == null) + return null; + + for (Element e : nonZeroes()) { + int row = e.index(); + + res.assignRow(row, vec.times(getX(row))); + } + + return res; + } + + /** {@inheritDoc} */ + @Override public Matrix toMatrix(boolean rowLike) { + Matrix res = likeMatrix(rowLike ? 1 : size(), rowLike ? size() : 1); + + if (res == null) + return null; + + if (rowLike) + res.assignRow(0, this); + else + res.assignColumn(0, this); + + return res; + } + + /** {@inheritDoc} */ + @Override public Matrix toMatrixPlusOne(boolean rowLike, double zeroVal) { + Matrix res = likeMatrix(rowLike ? 1 : size() + 1, rowLike ? size() + 1 : 1); + + if (res == null) + return null; + + res.set(0, 0, zeroVal); + + if (rowLike) + new ViewMatrix(res, 0, 1, 1, size()).assignRow(0, this); + else + new ViewMatrix(res, 1, 0, size(), 1).assignColumn(0, this); + + return res; + } + + /** {@inheritDoc} */ + @Override public double getDistanceSquared(Vector vec) { + checkCardinality(vec); + + double thisLenSq = getLengthSquared(); + double thatLenSq = vec.getLengthSquared(); + double dot = dot(vec); + double distEst = thisLenSq + thatLenSq - 2 * dot; + + if (distEst > 1.0e-3 * (thisLenSq + thatLenSq)) + // The vectors are far enough from each other that the formula is accurate. + return Math.max(distEst, 0); + else + return foldMap(vec, Functions.PLUS, Functions.MINUS_SQUARED, 0d); + } + + /** + * @param vec Vector to check for valid cardinality. + */ + protected void checkCardinality(Vector vec) { + if (vec.size() != size()) + throw new CardinalityException(size(), vec.size()); + } + + /** + * @param vec Array to check for valid cardinality. 
+ */ + protected void checkCardinality(double[] vec) { + if (vec.length != size()) + throw new CardinalityException(size(), vec.length); + } + + /** + * @param arr Array to check for valid cardinality. + */ + protected void checkCardinality(int[] arr) { + if (arr.length != size()) + throw new CardinalityException(size(), arr.length); + } + + /** {@inheritDoc} */ + @Override public Vector minus(Vector vec) { + checkCardinality(vec); + + Vector cp = copy(); + + return cp.map(vec, Functions.MINUS); + } + + /** {@inheritDoc} */ + @Override public Vector plus(double x) { + Vector cp = copy(); + + return x != 0.0 ? cp.map(Functions.plus(x)) : cp; + } + + /** {@inheritDoc} */ + @Override public Vector divide(double x) { + Vector cp = copy(); + + if (x != 1.0) + for (Element element : cp.all()) + element.set(element.get() / x); + + return cp; + } + + /** {@inheritDoc} */ + @Override public Vector times(double x) { + if (x == 0.0) + return like(size()); + else + return copy().map(Functions.mult(x)); + } + + /** {@inheritDoc} */ + @Override public Vector times(Vector vec) { + checkCardinality(vec); + + return copy().map(vec, Functions.MULT); + } + + /** {@inheritDoc} */ + @Override public Vector plus(Vector vec) { + checkCardinality(vec); + + Vector cp = copy(); + + return cp.map(vec, Functions.PLUS); + } + + /** {@inheritDoc} */ + @Override public Vector logNormalize() { + return logNormalize(2.0, Math.sqrt(getLengthSquared())); + } + + /** {@inheritDoc} */ + @Override public Vector logNormalize(double power) { + return logNormalize(power, kNorm(power)); + } + + /** + * @param power Power. + * @param normLen Normalized length. + * @return logNormalized value. 
+ */ + private Vector logNormalize(double power, double normLen) { + assert !(Double.isInfinite(power) || power <= 1.0); + + double denominator = normLen * Math.log(power); + + Vector cp = copy(); + + for (Element element : cp.all()) + element.set(Math.log1p(element.get()) / denominator); + + return cp; + } + + /** {@inheritDoc} */ + @Override public double kNorm(double power) { + assert power >= 0.0; + + // Special cases. + if (Double.isInfinite(power)) + return foldMap(Math::max, Math::abs, 0d); + else if (power == 2.0) + return Math.sqrt(getLengthSquared()); + else if (power == 1.0) + return foldMap(Functions.PLUS, Math::abs, 0d); + else if (power == 0.0) + return nonZeroElements(); + else + // Default case. + return Math.pow(foldMap(Functions.PLUS, Functions.pow(power), 0d), 1.0 / power); + } + + /** {@inheritDoc} */ + @Override public Vector normalize() { + return divide(Math.sqrt(getLengthSquared())); + } + + /** {@inheritDoc} */ + @Override public Vector normalize(double power) { + return divide(kNorm(power)); + } + + /** {@inheritDoc} */ + @Override public Vector copy() { + return like(size()).assign(this); + } + + /** + * @return Result of dot with self. 
+ */ + protected double dotSelf() { + double sum = 0.0; + int len = size(); + + for (int i = 0; i < len; i++) { + double v = storageGet(i); + + sum += v * v; + } + + return sum; + } + + /** {@inheritDoc} */ + @Override public Element getElement(int idx) { + return makeElement(idx); + } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(sto); + out.writeObject(meta); + out.writeObject(guid); + out.writeBoolean(readOnly); + } + + /** {@inheritDoc} */ + @SuppressWarnings("unchecked") + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + sto = (VectorStorage)in.readObject(); + meta = (Map<String, Object>)in.readObject(); + guid = (IgniteUuid)in.readObject(); + readOnly = in.readBoolean(); + } + + /** {@inheritDoc} */ + @Override public void destroy() { + sto.destroy(); + } + + /** {@inheritDoc} Only the storage is hashed, because {@link #equals(Object)} compares only the storage. */ + @Override public int hashCode() { + int res = 1; + // The GUID is left out on purpose: hashing it would give equal vectors different hash codes. + res = res * 37 + (sto == null ? 0 : sto.hashCode()); + return res; + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object obj) { + if (this == obj) + return true; + + if (obj == null || getClass() != obj.getClass()) + return false; + + AbstractVector that = (AbstractVector)obj; + + return (sto != null ?
sto.equals(that.sto) : that.sto == null); + } + + /** {@inheritDoc} */ + @Override public void compute(int idx, IgniteIntDoubleToDoubleBiFunction f) { + storageSet(idx, f.apply(idx, storageGet(idx))); + lenSq = 0.0; + maxElm = minElm = null; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/Vector.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/Vector.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/Vector.java new file mode 100644 index 0000000..41ce988 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/Vector.java @@ -0,0 +1,521 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.math.primitives.vector; + +import java.io.Externalizable; +import java.util.Spliterator; +import java.util.function.IntToDoubleFunction; +import org.apache.ignite.lang.IgniteUuid; +import org.apache.ignite.ml.math.Destroyable; +import org.apache.ignite.ml.math.primitives.matrix.Matrix; +import org.apache.ignite.ml.math.MetaAttributes; +import org.apache.ignite.ml.math.StorageOpsMetrics; +import org.apache.ignite.ml.math.exceptions.CardinalityException; +import org.apache.ignite.ml.math.exceptions.IndexException; +import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteDoubleFunction; +import org.apache.ignite.ml.math.functions.IgniteIntDoubleToDoubleBiFunction; + +/** + * A vector interface. + * + * Based on its flavor it can have vastly different implementations tailored + * for different types of data (e.g. dense vs. sparse), different sizes of data or different operation + * optimizations. + * + * Note also that not all operations can be supported by all underlying implementations. If an operation is not + * supported an {@link UnsupportedOperationException} is thrown. This exception can also be thrown in partial cases + * where an operation is unsupported only in special cases, e.g. where a given operation cannot be deterministically + * completed in polynomial time. + * + * Based on ideas from <a href="http://mahout.apache.org/">Apache Mahout</a>. + */ +public interface Vector extends MetaAttributes, Externalizable, StorageOpsMetrics, Destroyable { + /** + * Holder for vector's element. + */ + interface Element { + /** + * Gets element's value. + * + * @return The value of this vector element. + */ + double get(); + + /** + * Gets element's index in the vector. + * + * @return The index of this vector element. + */ + int index(); + + /** + * Sets element's value.
+ * + * @param val Value to set. + */ + void set(double val); + } + + /** + * Gets cardinality of this vector (maximum number of the elements). + * + * @return This vector's cardinality. + */ + public int size(); + + /** + * Creates new copy of this vector. + * + * @return New copy vector. + */ + public Vector copy(); + + /** + * Gets iterator over all elements in this vector. + * + * NOTE: implementation can choose to reuse {@link Element} instance so you need to copy it + * if you want to retain it outside of iteration. + * + * @return Iterator. + */ + public Iterable<Element> all(); + + /** + * Iterates over all non-zero elements in this vector. + * + * NOTE: implementation can choose to reuse {@link Element} instance so you need to copy it + * if you want to retain it outside of iteration. + * + * @return Iterator. + */ + public Iterable<Element> nonZeroes(); + + /** + * Gets spliterator for all values in this vector. + * + * @return Spliterator for all values. + */ + public Spliterator<Double> allSpliterator(); + + /** + * Gets spliterator for all non-zero values in this vector. + * + * @return Spliterator for all non-zero values. + */ + public Spliterator<Double> nonZeroSpliterator(); + + /** + * Sorts this vector in ascending order. + */ + public Vector sort(); + + /** + * Gets element at the given index. + * + * NOTE: implementation can choose to reuse {@link Element} instance so you need to copy it + * if you want to retain it outside of iteration. + * + * @param idx Element's index. + * @return Vector's element at the given index. + * @throws IndexException Throw if index is out of bounds. + */ + public Element getElement(int idx); + + /** + * Assigns given value to all elements of this vector. + * + * @param val Value to assign. + * @return This vector. + */ + public Vector assign(double val); + + /** + * Assigns values from given array to this vector. + * + * @param vals Values to assign. + * @return This vector.
+ * @throws CardinalityException Thrown if cardinalities mismatch. + */ + public Vector assign(double[] vals); + + /** + * Copies values from the argument vector to this one. + * + * @param vec Argument vector. + * @return This vector. + * @throws CardinalityException Thrown if cardinalities mismatch. + */ + public Vector assign(Vector vec); + + /** + * Assigns each vector element to the value generated by given function. + * + * @param fun Function that takes the index and returns value. + * @return This vector. + */ + public Vector assign(IntToDoubleFunction fun); + + /** + * Maps all values in this vector through a given function. + * + * @param fun Mapping function. + * @return This vector. + */ + public Vector map(IgniteDoubleFunction<Double> fun); + + /** + * Maps all values in this vector through a given function. + * + * For this vector <code>A</code>, argument vector <code>B</code> and the + * function <code>F</code> this method maps every element <code>x</code> as: + * <code>A(x) = F(A(x), B(x))</code> + * + * @param vec Argument vector. + * @param fun Mapping function. + * @return This vector. + * @throws CardinalityException Thrown if cardinalities mismatch. + */ + public Vector map(Vector vec, IgniteBiFunction<Double, Double, Double> fun); + + /** + * Maps all elements of this vector by applying given function to each element with a constant + * second parameter <code>y</code>. + * + * @param fun Mapping function. + * @param y Second parameter for mapping function. + * @return This vector. + */ + public Vector map(IgniteBiFunction<Double, Double, Double> fun, double y); + + /** + * Creates new vector containing values from this vector divided by the argument. + * + * @param x Division argument. + * @return New vector. + */ + public Vector divide(double x); + + /** + * Gets dot product of two vectors. + * + * @param vec Argument vector. + * @return Dot product of two vectors.
+ */ + public double dot(Vector vec); + + /** + * Gets the value at specified index. + * + * @param idx Vector index. + * @return Vector value. + * @throws IndexException Throw if index is out of bounds. + */ + public double get(int idx); + + /** + * Gets the value at specified index without checking for index boundaries. + * + * @param idx Vector index. + * @return Vector value. + */ + public double getX(int idx); + + /** + * Creates new empty vector of the same underlying class but of different cardinality. + * + * @param crd Cardinality for new vector. + * @return New vector. + */ + public Vector like(int crd); + + /** + * Creates new matrix of compatible flavor with given size. + * + * @param rows Number of rows. + * @param cols Number of columns. + * @return New matrix. + */ + public Matrix likeMatrix(int rows, int cols); + + /** + * Converts this vector into [1 x N] or [N x 1] matrix (rows x columns) where N is this vector cardinality. + * + * @param rowLike {@code true} for row-like [1 x N], or {@code false} for column-like [N x 1] matrix. + * @return Newly created matrix. + */ + public Matrix toMatrix(boolean rowLike); + + /** + * Converts this vector into [1 x N+1] or [N+1 x 1] matrix (rows x columns) where N is this vector cardinality. + * (0,0) element of this matrix will be {@code zeroVal} parameter. + * + * @param rowLike {@code true} for row-like [1 x N+1], or {@code false} for column-like [N+1 x 1] matrix. + * @return Newly created matrix. + */ + public Matrix toMatrixPlusOne(boolean rowLike, double zeroVal); + + /** + * Creates new vector containing element by element difference between this vector and the argument one. + * + * @param vec Argument vector. + * @return New vector. + * @throws CardinalityException Thrown if cardinalities mismatch. + */ + public Vector minus(Vector vec); + + /** + * Creates new vector containing the normalized (L_2 norm) values of this vector. + * + * @return New vector.
+ */ + public Vector normalize(); + + /** + * Creates new vector containing the normalized (L_power norm) values of this vector. + * See http://en.wikipedia.org/wiki/Lp_space for details. + * + * @param power The power to use. Must be >= 0. May also be {@link Double#POSITIVE_INFINITY}. + * @return New vector {@code x} such that {@code norm(x, power) == 1} + */ + public Vector normalize(double power); + + /** + * Creates new vector containing the {@code log_2(1 + entry) / L_2 norm} values of this vector. + * + * @return New vector. + */ + public Vector logNormalize(); + + /** + * Creates new vector with a normalized value calculated as {@code log_power(1 + entry) / L_power norm}. + * + * @param power The power to use. Must be > 1. Cannot be {@link Double#POSITIVE_INFINITY}. + * @return New vector + */ + public Vector logNormalize(double power); + + /** + * Gets the k-norm of the vector. See http://en.wikipedia.org/wiki/Lp_space for more details. + * + * @param power The power to use. + * @see #normalize(double) + */ + public double kNorm(double power); + + /** + * Gets minimal value in this vector. + * + * @return Minimal value. + */ + public double minValue(); + + /** + * Gets maximum value in this vector. + * + * @return Maximum value. + */ + public double maxValue(); + + /** + * Gets minimal element in this vector. + * + * @return Minimal element. + */ + public Element minElement(); + + /** + * Gets maximum element in this vector. + * + * @return Maximum element. + */ + public Element maxElement(); + + /** + * Creates new vector containing sum of each element in this vector and argument. + * + * @param x Argument value. + * @return New vector. + */ + public Vector plus(double x); + + /** + * Creates new vector containing element by element sum from both vectors. + * + * @param vec Other argument vector to add. + * @return New vector. + * @throws CardinalityException Thrown if cardinalities mismatch. + */ + public Vector plus(Vector vec); + + /** + * Sets value.
+ * + * @param idx Vector index to set value at. + * @param val Value to set. + * @return This vector. + * @throws IndexException Throw if index is out of bounds. + */ + public Vector set(int idx, double val); + + /** + * Sets value without checking for index boundaries. + * + * @param idx Vector index to set value at. + * @param val Value to set. + * @return This vector. + */ + public Vector setX(int idx, double val); + + /** + * Increments value at given index without checking for index boundaries. + * + * @param idx Vector index. + * @param val Increment value. + * @return This vector. + */ + public Vector incrementX(int idx, double val); + + /** + * Increments value at given index. + * + * @param idx Vector index. + * @param val Increment value. + * @return This vector. + * @throws IndexException Throw if index is out of bounds. + */ + public Vector increment(int idx, double val); + + /** + * Gets number of non-zero elements in this vector. + * + * @return Number of non-zero elements in this vector. + */ + public int nonZeroElements(); + + /** + * Gets a new vector that contains product of each element and the argument. + * + * @param x Multiply argument. + * @return New vector. + */ + public Vector times(double x); + + /** + * Gets a new vector that is an element-wise product of this vector and the argument. + * + * @param vec Vector to multiply by. + * @return New vector. + * @throws CardinalityException Thrown if cardinalities mismatch. + */ + public Vector times(Vector vec); + + /** + * @param off Offset into parent vector. + * @param len Length of the view. + */ + public Vector viewPart(int off, int len); + + /** + * Gets vector storage model. + */ + public VectorStorage getStorage(); + + /** + * Gets the sum of all elements in this vector. + * + * @return Vector's sum + */ + public double sum(); + + /** + * Gets the cross product of this vector and the other vector. + * + * @param vec Second vector. + * @return New matrix as a cross product of two vectors.
+ */ + public Matrix cross(Vector vec); + + /** + * Folds this vector into a single value. + * + * @param foldFun Folding function that takes two parameters: accumulator and the current value. + * @param mapFun Mapping function that is called on each vector element before it is passed to the accumulator (as its + * second parameter). + * @param <T> Type of the folded value. + * @param zeroVal Zero value for fold operation. + * @return Folded value of this vector. + */ + public <T> T foldMap(IgniteBiFunction<T, Double, T> foldFun, IgniteDoubleFunction<Double> mapFun, T zeroVal); + + /** + * Combines & maps two vectors and folds them into a single value. + * + * @param vec Another vector to combine with. + * @param foldFun Folding function. + * @param combFun Combine function. + * @param <T> Type of the folded value. + * @param zeroVal Zero value for fold operation. + * @return Folded value of these vectors. + * @throws CardinalityException Thrown when cardinalities mismatch. + */ + public <T> T foldMap(Vector vec, IgniteBiFunction<T, Double, T> foldFun, + IgniteBiFunction<Double, Double, Double> combFun, + T zeroVal); + + /** + * Gets the sum of squares of all elements in this vector. + * + * @return Length squared value. + */ + public double getLengthSquared(); + + /** + * Get the square of the distance between this vector and the argument vector. + * + * @param vec Another vector. + * @return Distance squared. + * @throws CardinalityException Thrown if cardinalities mismatch. + */ + public double getDistanceSquared(Vector vec); + + /** + * Auto-generated globally unique vector ID. + * + * @return Vector GUID. + */ + public IgniteUuid guid(); + + /** + * Replace vector entry with value oldVal at i with result of computing f(i, oldVal). + * + * @param i Position. + * @param f Function used for replacing. + **/ + public void compute(int i, IgniteIntDoubleToDoubleBiFunction f); + + + /** + * Returns array of doubles corresponds to vector components.
+ * @return Array of doubles. + */ + public default double[] asArray() { + return getStorage().data(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorStorage.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorStorage.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorStorage.java new file mode 100644 index 0000000..eed7c17 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorStorage.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.primitives.vector; + +import java.io.Externalizable; +import org.apache.ignite.ml.math.Destroyable; +import org.apache.ignite.ml.math.StorageOpsMetrics; + +/** + * Data storage support for {@link Vector}. + */ +public interface VectorStorage extends Externalizable, StorageOpsMetrics, Destroyable { + /** + * + * + */ + public int size(); + + /** + * @param i Vector element index. + * @return Value obtained for given element index. 
+ */ + public double get(int i); + + /** + * @param i Vector element index. + * @param v Value to set at given index. + */ + public void set(int i, double v); + + /** + * Gets underlying array if {@link StorageOpsMetrics#isArrayBased()} returns {@code true}. + * Returns {@code null} if in other cases. + * + * @see StorageOpsMetrics#isArrayBased() + */ + public default double[] data() { + return null; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java new file mode 100644 index 0000000..5e1341b --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.math.primitives.vector; + +import java.util.Arrays; +import java.util.Objects; +import org.apache.ignite.internal.util.typedef.internal.A; +import org.apache.ignite.ml.math.StorageConstants; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; +import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector; + +/** + * Some utils for {@link Vector}. + */ +public class VectorUtils { + /** Create new vector like given vector initialized by zeroes. */ + public static Vector zeroesLike(Vector v) { + return v.like(v.size()).assign(0.0); + } + + /** Create new dense vector of size {@code n} initialized by zeroes. */ + public static DenseVector zeroes(int n) { + return (DenseVector)new DenseVector(n).assign(0.0); + } + + /** + * Turn number into a local Vector of given size with one-hot encoding. + * + * @param num Number to turn into vector. + * @param vecSize Vector size of output vector. + * @return One-hot encoded number. + */ + public static Vector num2Vec(int num, int vecSize) { + return num2Vec(num, vecSize, false); + } + + /** + * Turn number into Vector of given size with one-hot encoding. + * + * @param num Number to turn into vector. + * @param vecSize Vector size of output vector. + * @param isDistributed Flag indicating if distributed vector should be created (currently ignored: a {@link DenseVector} is always created). + * @return One-hot encoded number. + */ + public static Vector num2Vec(int num, int vecSize, boolean isDistributed) { + Vector res = new DenseVector(vecSize); + return res.setX(num, 1); + } + + /** + * Turn Vector into number by looking at index of maximal element in vector. + * + * @param vec Vector to be turned into number. + * @return Number.
+ */ + public static double vec2Num(Vector vec) { + int max = 0; + double maxVal = Double.NEGATIVE_INFINITY; + + for (int i = 0; i < vec.size(); i++) { + double curVal = vec.getX(i); + if (curVal > maxVal) { + max = i; + maxVal = curVal; + } + } + + return max; + } + + /** + * Performs in-place vector multiplication. + * + * @param vec1 Operand to be changed and first multiplication operand. + * @param vec2 Second multiplication operand. + * @return Updated first operand. + */ + public static Vector elementWiseTimes(Vector vec1, Vector vec2) { + vec1.map(vec2, (a, b) -> a * b); + + return vec1; + } + + /** + * Performs in-place vector subtraction. + * + * @param vec1 Operand to be changed and subtracted from. + * @param vec2 Operand to subtract. + * @return Updated first operand. + */ + public static Vector elementWiseMinus(Vector vec1, Vector vec2) { + vec1.map(vec2, (a, b) -> a - b); + + return vec1; + } + + /** + * Zip two vectors with given binary function + * (i.e. apply binary function to both vector elementwise and construct vector from results). + * + * Example zipWith({0, 2, 4}, {1, 3, 5}, plus) = {0 + 1, 2 + 3, 4 + 5}. + * Length of result is length of shortest of vectors. + * + * @param v1 First vector. + * @param v2 Second vector. + * @param f Function to zip with. + * @return Result of zipping. + */ + public static Vector zipWith(Vector v1, Vector v2, IgniteBiFunction<Double, Double, Double> f) { + int size = Math.min(v1.size(), v2.size()); + + Vector res = v1.like(size); + + for (int row = 0; row < size; row++) + res.setX(row, f.apply(v1.getX(row), v2.getX(row))); + + return res; + } + + /** + * Get copy of part of given length of given vector starting from given offset. + * + * @param v Vector to copy part from. + * @param off Offset. + * @param len Length. + * @return Copy of part of given length of given vector starting from given offset. 
+ */ + public static Vector copyPart(Vector v, int off, int len) { + assert off >= 0; + assert len <= v.size() - off; // Whole window [off, off + len) must fit inside v. + + Vector res = v.like(len); + + for (int i = 0; i < len; i++) + res.setX(i, v.getX(off + i)); + + return res; + } + + /** + * Creates dense local on heap vector based on array of doubles. + * + * @param values Values. + */ + public static Vector of(double ... values) { + A.notNull(values, "values"); + + return new DenseVector(values); + } + + /** + * Creates vector based on array of Doubles. If array contains null-elements then + * method returns sparse local on heap vector. In other case method returns + * dense local on heap vector. + * + * @param values Values. + */ + public static Vector of(Double[] values) { + A.notNull(values, "values"); + + Vector answer; + if (Arrays.stream(values).anyMatch(Objects::isNull)) + answer = new SparseVector(values.length, StorageConstants.RANDOM_ACCESS_MODE); + else + answer = new DenseVector(values.length); + + for (int i = 0; i < values.length; i++) + if (values[i] != null) + answer.set(i, values[i]); + + return answer; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/impl/DelegatingVector.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/impl/DelegatingVector.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/impl/DelegatingVector.java new file mode 100644 index 0000000..b31f6fb --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/impl/DelegatingVector.java @@ -0,0 +1,402 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.primitives.vector.impl; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.HashMap; +import java.util.Map; +import java.util.Spliterator; +import java.util.function.IntToDoubleFunction; +import org.apache.ignite.lang.IgniteUuid; +import org.apache.ignite.ml.math.primitives.matrix.Matrix; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorStorage; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteDoubleFunction; +import org.apache.ignite.ml.math.functions.IgniteIntDoubleToDoubleBiFunction; + +/** + * Convenient class that can be used to add decorations to an existing vector. Subclasses + * can add weights, indices, etc. while maintaining full vector functionality. + */ +public class DelegatingVector implements Vector { + /** Delegating vector. */ + private Vector dlg; + + /** Meta attribute storage. */ + private Map<String, Object> meta = new HashMap<>(); + + /** GUID. */ + private IgniteUuid guid = IgniteUuid.randomUuid(); + + /** */ + public DelegatingVector() { + // No-op. + } + + /** + * @param dlg Parent vector. 
+ */ + public DelegatingVector(Vector dlg) { + assert dlg != null; + + this.dlg = dlg; + } + + /** @return The wrapped (delegate) vector. */ + public Vector getVector() { + return dlg; + } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(dlg); + out.writeObject(meta); + out.writeObject(guid); + } + + /** {@inheritDoc} */ + @SuppressWarnings("unchecked") + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + dlg = (Vector)in.readObject(); + meta = (Map<String, Object>)in.readObject(); + guid = (IgniteUuid)in.readObject(); + } + + /** {@inheritDoc} */ + @Override public Map<String, Object> getMetaStorage() { + return meta; + } + + /** {@inheritDoc} */ + @Override public Matrix likeMatrix(int rows, int cols) { + return dlg.likeMatrix(rows, cols); + } + + /** {@inheritDoc} */ + @Override public Matrix toMatrix(boolean rowLike) { + return dlg.toMatrix(rowLike); + } + + /** {@inheritDoc} */ + @Override public Matrix toMatrixPlusOne(boolean rowLike, double zeroVal) { + return dlg.toMatrixPlusOne(rowLike, zeroVal); + } + + /** {@inheritDoc} */ + @Override public int size() { + return dlg.size(); + } + + /** {@inheritDoc} */ + @Override public boolean isDense() { + return dlg.isDense(); + } + + /** {@inheritDoc} */ + @Override public double minValue() { + return dlg.minValue(); + } + + /** {@inheritDoc} */ + @Override public double maxValue() { + return dlg.maxValue(); + } + + /** {@inheritDoc} */ + @Override public boolean isSequentialAccess() { + return dlg.isSequentialAccess(); + } + + /** {@inheritDoc} */ + @Override public boolean isArrayBased() { + return dlg.isArrayBased(); + } + + /** {@inheritDoc} The delegate is copied as well, so the result shares no element state with this vector. */ + @Override public Vector copy() { + return new DelegatingVector(dlg.copy()); + } + + /** {@inheritDoc} */ + @Override public Iterable<Element> all() { + return dlg.all(); + } + + /** {@inheritDoc} */ + @Override public Iterable<Element> nonZeroes() { + return
dlg.nonZeroes(); + } + + /** {@inheritDoc} */ + @Override public Vector sort() { + return dlg.sort(); + } + + /** {@inheritDoc} */ + @Override public Spliterator<Double> allSpliterator() { + return dlg.allSpliterator(); + } + + /** {@inheritDoc} */ + @Override public Spliterator<Double> nonZeroSpliterator() { + return dlg.nonZeroSpliterator(); + } + + /** {@inheritDoc} */ + @Override public Element getElement(int idx) { + return dlg.getElement(idx); + } + + /** {@inheritDoc} */ + @Override public Vector assign(double val) { + return dlg.assign(val); + } + + /** {@inheritDoc} */ + @Override public Vector assign(double[] vals) { + return dlg.assign(vals); + } + + /** {@inheritDoc} */ + @Override public Vector assign(Vector vec) { + return dlg.assign(vec); + } + + /** {@inheritDoc} */ + @Override public Vector assign(IntToDoubleFunction fun) { + return dlg.assign(fun); + } + + /** {@inheritDoc} */ + @Override public Vector map(IgniteDoubleFunction<Double> fun) { + return dlg.map(fun); + } + + /** {@inheritDoc} */ + @Override public Vector map(Vector vec, IgniteBiFunction<Double, Double, Double> fun) { + return dlg.map(vec, fun); + } + + /** {@inheritDoc} */ + @Override public Vector map(IgniteBiFunction<Double, Double, Double> fun, double y) { + return dlg.map(fun, y); + } + + /** {@inheritDoc} */ + @Override public Vector divide(double x) { + return dlg.divide(x); + } + + /** {@inheritDoc} */ + @Override public double dot(Vector vec) { + return dlg.dot(vec); + } + + /** {@inheritDoc} */ + @Override public double get(int idx) { + return dlg.get(idx); + } + + /** {@inheritDoc} */ + @Override public double getX(int idx) { + return dlg.getX(idx); + } + + /** {@inheritDoc} */ + @Override public Vector like(int crd) { + return dlg.like(crd); + } + + /** {@inheritDoc} */ + @Override public Vector minus(Vector vec) { + return dlg.minus(vec); + } + + /** {@inheritDoc} */ + @Override public Vector normalize() { + return dlg.normalize(); + } + + /** {@inheritDoc} */ + @Override 
public Vector normalize(double power) { + return dlg.normalize(power); + } + + /** {@inheritDoc} */ + @Override public Vector logNormalize() { + return dlg.logNormalize(); + } + + /** {@inheritDoc} */ + @Override public Vector logNormalize(double power) { + return dlg.logNormalize(power); + } + + /** {@inheritDoc} */ + @Override public double kNorm(double power) { + return dlg.kNorm(power); + } + + /** {@inheritDoc} */ + @Override public Element minElement() { + return dlg.minElement(); + } + + /** {@inheritDoc} */ + @Override public Element maxElement() { + return dlg.maxElement(); + } + + /** {@inheritDoc} */ + @Override public Vector plus(double x) { + return dlg.plus(x); + } + + /** {@inheritDoc} */ + @Override public Vector plus(Vector vec) { + return dlg.plus(vec); + } + + /** {@inheritDoc} */ + @Override public Vector set(int idx, double val) { + return dlg.set(idx, val); + } + + /** {@inheritDoc} */ + @Override public Vector setX(int idx, double val) { + return dlg.setX(idx, val); + } + + /** {@inheritDoc} */ + @Override public Vector incrementX(int idx, double val) { + return dlg.incrementX(idx, val); + } + + /** {@inheritDoc} */ + @Override public Vector increment(int idx, double val) { + return dlg.increment(idx, val); + } + + /** {@inheritDoc} */ + @Override public int nonZeroElements() { + return dlg.nonZeroElements(); + } + + /** {@inheritDoc} */ + @Override public Vector times(double x) { + return dlg.times(x); + } + + /** {@inheritDoc} */ + @Override public Vector times(Vector vec) { + return dlg.times(vec); + } + + /** {@inheritDoc} */ + @Override public Vector viewPart(int off, int len) { + return dlg.viewPart(off, len); + } + + /** {@inheritDoc} */ + @Override public VectorStorage getStorage() { + return dlg.getStorage(); + } + + /** {@inheritDoc} */ + @Override public double sum() { + return dlg.sum(); + } + + /** {@inheritDoc} */ + @Override public Matrix cross(Vector vec) { + return dlg.cross(vec); + } + + /** {@inheritDoc} */ + @Override 
public <T> T foldMap(IgniteBiFunction<T, Double, T> foldFun, IgniteDoubleFunction<Double> mapFun, + T zeroVal) { + return dlg.foldMap(foldFun, mapFun, zeroVal); + } + + /** {@inheritDoc} */ + @Override public <T> T foldMap(Vector vec, IgniteBiFunction<T, Double, T> foldFun, + IgniteBiFunction<Double, Double, Double> combFun, T zeroVal) { + return dlg.foldMap(vec, foldFun, combFun, zeroVal); + } + + /** {@inheritDoc} */ + @Override public double getLengthSquared() { + return dlg.getLengthSquared(); + } + + /** {@inheritDoc} */ + @Override public double getDistanceSquared(Vector vec) { + return dlg.getDistanceSquared(vec); + } + + /** {@inheritDoc} */ + @Override public boolean isRandomAccess() { + return dlg.isRandomAccess(); + } + + /** {@inheritDoc} */ + @Override public boolean isDistributed() { + return dlg.isDistributed(); + } + + /** {@inheritDoc} */ + @Override public IgniteUuid guid() { + return guid; + } + + /** {@inheritDoc} */ + @Override public void compute(int i, IgniteIntDoubleToDoubleBiFunction f) { + dlg.compute(i, f); + } + + /** {@inheritDoc} */ + @Override public void destroy() { + dlg.destroy(); + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + int res = 1; + + res = res * 37 + meta.hashCode(); + res = res * 37 + dlg.hashCode(); + + return res; + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + + if (o == null || getClass() != o.getClass()) + return false; + + DelegatingVector that = (DelegatingVector)o; + + return meta.equals(that.meta) && dlg.equals(that.dlg); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/impl/DenseVector.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/impl/DenseVector.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/impl/DenseVector.java new file mode 100644 index 0000000..48b9212 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/impl/DenseVector.java @@ -0,0 +1,105 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.math.primitives.vector.impl;

import java.util.Map;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
import org.apache.ignite.ml.math.primitives.vector.storage.DenseVectorStorage;
import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorStorage;
import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
import org.apache.ignite.ml.math.primitives.vector.AbstractVector;

/**
 * Basic implementation for vector.
 * <p>
 * This is a trivial implementation for vector assuming dense logic, local on-heap JVM storage
 * based on {@code double[]} array. It is only suitable for data sets where
 * local, non-distributed execution is satisfactory and on-heap JVM storage is enough
 * to keep the entire data set.
 */
public class DenseVector extends AbstractVector {
    /** Creates a vector without installing any storage. */
    public DenseVector() {
        // No-op.
    }

    /**
     * Creates a vector whose storage is built for the given cardinality.
     *
     * @param size Vector cardinality.
     */
    public DenseVector(int size) {
        setStorage(mkStorage(size));
    }

    /**
     * Creates a vector from an array; equivalent to {@code new DenseVector(arr, false)}, i.e. the array
     * is handed to the storage as is, without cloning.
     *
     * @param arr Source array.
     */
    public DenseVector(double[] arr) {
        this(arr, false);
    }

    /**
     * Creates a vector from an array.
     *
     * @param arr Source array.
     * @param shallowCp {@code true} to hand a clone of {@code arr} to the storage,
     *      {@code false} to hand over {@code arr} itself.
     */
    public DenseVector(double[] arr, boolean shallowCp) {
        setStorage(mkStorage(arr, shallowCp));
    }

    /**
     * Creates a vector from a parameter map. Either a {@code "size"} entry, or both an {@code "arr"} and a
     * {@code "copy"} entry, must be present; {@code "size"} takes precedence when all of them are given.
     *
     * @param args Parameters for new Vector.
     * @throws UnsupportedOperationException If the map holds neither of the supported key combinations.
     */
    public DenseVector(Map<String, Object> args) {
        assert args != null;

        // Cardinality-only construction wins over the array form.
        if (args.containsKey("size")) {
            setStorage(mkStorage((int)args.get("size")));

            return;
        }

        if (!args.containsKey("arr") || !args.containsKey("copy"))
            throw new UnsupportedOperationException("Invalid constructor argument(s).");

        double[] src = (double[])args.get("arr");
        boolean cloneSrc = (boolean)args.get("copy");

        setStorage(mkStorage(src, cloneSrc));
    }

    /** {@inheritDoc} */
    @Override public Vector like(int crd) {
        return new DenseVector(crd);
    }

    /** {@inheritDoc} */
    @Override public Matrix likeMatrix(int rows, int cols) {
        return new DenseMatrix(rows, cols);
    }

    /**
     * Builds dense storage of the requested cardinality.
     *
     * @param size Vector cardinality.
     * @return New storage instance.
     */
    private VectorStorage mkStorage(int size) {
        return new DenseVectorStorage(size);
    }

    /**
     * Builds dense storage from an array.
     *
     * @param arr Source array.
     * @param cp {@code true} to clone array, reuse it otherwise.
     * @return New storage instance.
     */
    private VectorStorage mkStorage(double[] arr, boolean cp) {
        assert arr != null;

        double[] data = arr;

        if (cp)
            data = arr.clone();

        return new DenseVectorStorage(data);
    }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/26e40528/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/impl/SparseVector.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/impl/SparseVector.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/impl/SparseVector.java new file mode 100644 index 0000000..8b3a274 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/impl/SparseVector.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */

package org.apache.ignite.ml.math.primitives.vector.impl;

import it.unimi.dsi.fastutil.ints.IntSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.Spliterator;
import java.util.function.Consumer;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
import org.apache.ignite.ml.math.StorageConstants;
import org.apache.ignite.ml.math.primitives.matrix.impl.SparseMatrix;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.AbstractVector;
import org.apache.ignite.ml.math.primitives.vector.storage.SparseVectorStorage;

/**
 * Local on-heap sparse vector based on hash map storage.
 */
public class SparseVector extends AbstractVector implements StorageConstants {
    /** Creates a vector without installing any storage. */
    public SparseVector() {
        // No-op.
    }

    /**
     * Creates a vector whose storage is built from the given map.
     *
     * @param map Underlying map.
     * @param cp Should given map be copied.
     */
    public SparseVector(Map<Integer, Double> map, boolean cp) {
        setStorage(new SparseVectorStorage(map, cp));
    }

    /**
     * Creates a vector of the given size after validating the access mode.
     *
     * @param size Vector size.
     * @param acsMode Vector elements access mode.
     */
    public SparseVector(int size, int acsMode) {
        assertAccessMode(acsMode);

        setStorage(new SparseVectorStorage(size, acsMode));
    }

    /**
     * @return This vector's storage, cast to its sparse implementation.
     */
    private SparseVectorStorage storage() {
        return (SparseVectorStorage)getStorage();
    }

    /** {@inheritDoc} */
    @Override public Vector like(int crd) {
        SparseVectorStorage sto = storage();

        // The new vector inherits this vector's access mode.
        return new SparseVector(crd, sto.getAccessMode());
    }

    /** {@inheritDoc} */
    @Override public Matrix likeMatrix(int rows, int cols) {
        return new SparseMatrix(rows, cols);
    }

    /**
     * {@inheritDoc}
     * <p>
     * Multiplication by zero is short-circuited to {@code assign(0)}; every other factor is handled by the
     * superclass.
     * <p>
     * NOTE(review): {@code assign(0)} is invoked on this vector itself; confirm that zeroing the receiver
     * (rather than producing a separate result vector) is the intended behaviour.
     */
    @Override public Vector times(double x) {
        if (x == 0.0)
            return assign(0);
        else
            return super.times(x);
    }

    /**
     * @return Indexes of non-default elements, as reported by the storage.
     */
    public IntSet indexes() {
        return storage().indexes();
    }

    /**
     * {@inheritDoc}
     * <p>
     * The returned spliterator walks the indexes reported by the storage and passes the value stored at each
     * index to the consumer, exactly one element per {@code tryAdvance} call. It does not support splitting.
     */
    @Override public Spliterator<Double> nonZeroSpliterator() {
        return new Spliterator<Double>() {
            /** Iterator over the stored indexes; created on the first {@code tryAdvance} call. */
            private Iterator<Integer> idxIter;

            /** Number of indexes not yet handed out; meaningful only once {@code idxIter} exists. */
            private long remaining;

            /** {@inheritDoc} */
            @Override public boolean tryAdvance(Consumer<? super Double> act) {
                if (idxIter == null) {
                    Set<Integer> idxs = storage().indexes();

                    idxIter = idxs.iterator();
                    remaining = idxs.size();
                }

                // Report exhaustion so that forEachRemaining() and streams terminate.
                if (!idxIter.hasNext())
                    return false;

                int idx = idxIter.next();

                remaining--;

                act.accept(storageGet(idx));

                return true;
            }

            /** {@inheritDoc} */
            @Override public Spliterator<Double> trySplit() {
                return null; // No Splitting.
            }

            /** {@inheritDoc} */
            @Override public long estimateSize() {
                // Before traversal starts the size is read from the storage; afterwards it is the
                // count of indexes still to be visited, as SIZED requires.
                return idxIter == null ? storage().indexes().size() : remaining;
            }

            /** {@inheritDoc} */
            @Override public int characteristics() {
                return ORDERED | SIZED;
            }
        };
    }
}
