Repository: ignite Updated Branches: refs/heads/master ee2a6f7c3 -> 429f9544a
IGNITE-7205: Dataset API this closes #3303 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/429f9544 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/429f9544 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/429f9544 Branch: refs/heads/master Commit: 429f9544a6935d2c087e0ccbfb46f65e8723b57b Parents: ee2a6f7 Author: zaleslaw <[email protected]> Authored: Thu Dec 28 16:31:42 2017 +0300 Committer: Yury Babak <[email protected]> Committed: Thu Dec 28 16:31:42 2017 +0300 ---------------------------------------------------------------------- .../KNNClassificationExample.java | 9 +- .../ml/knn/regression/KNNRegressionExample.java | 13 +- .../apache/ignite/ml/knn/models/KNNModel.java | 2 +- .../ignite/ml/knn/models/Normalization.java | 32 -- .../apache/ignite/ml/structures/Dataset.java | 232 +++++++++++++ .../apache/ignite/ml/structures/DatasetRow.java | 79 +++++ .../ignite/ml/structures/FeatureMetadata.java | 82 +++++ .../ignite/ml/structures/LabeledDataset.java | 338 ++----------------- .../structures/LabeledDatasetTestTrainPair.java | 8 +- .../ignite/ml/structures/LabeledVector.java | 53 +-- .../preprocessing/LabeledDatasetLoader.java | 133 ++++++++ .../preprocessing/LabellingMachine.java | 41 +++ .../ml/structures/preprocessing/Normalizer.java | 78 +++++ .../structures/preprocessing/package-info.java | 22 ++ .../org/apache/ignite/ml/knn/BaseKNNTest.java | 3 +- .../ml/knn/KNNMultipleLinearRegressionTest.java | 6 +- .../ignite/ml/knn/LabeledDatasetTest.java | 25 +- .../ignite/ml/math/ExternalizableTest.java | 67 ++++ 18 files changed, 837 insertions(+), 386 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/examples/src/main/ml/org/apache/ignite/examples/ml/knn/classification/KNNClassificationExample.java ---------------------------------------------------------------------- diff --git 
a/examples/src/main/ml/org/apache/ignite/examples/ml/knn/classification/KNNClassificationExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/knn/classification/KNNClassificationExample.java index fb7eebd..efdacd7 100644 --- a/examples/src/main/ml/org/apache/ignite/examples/ml/knn/classification/KNNClassificationExample.java +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/knn/classification/KNNClassificationExample.java @@ -30,6 +30,8 @@ import org.apache.ignite.ml.knn.models.KNNStrategy; import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.structures.LabeledDataset; import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair; +import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader; +import org.apache.ignite.ml.structures.preprocessing.LabellingMachine; import org.apache.ignite.thread.IgniteThread; /** @@ -71,7 +73,7 @@ public class KNNClassificationExample { Path path = Paths.get(KNNClassificationExample.class.getClassLoader().getResource(KNN_IRIS_TXT).toURI()); // Read dataset from file - LabeledDataset dataset = LabeledDataset.loadTxt(path, SEPARATOR, true, false); + LabeledDataset dataset = LabeledDatasetLoader.loadFromTxtFile(path, SEPARATOR, true, false); // Random splitting of iris data as 70% train and 30% test datasets LabeledDatasetTestTrainPair split = new LabeledDatasetTestTrainPair(dataset, 0.3); @@ -88,10 +90,7 @@ public class KNNClassificationExample { final double[] labels = test.labels(); // Save predicted classes to test dataset - for (int i = 0; i < test.rowSize(); i++) { - double predictedCls = knnMdl.apply(test.getRow(i).features()); - test.setLabel(i, predictedCls); - } + LabellingMachine.assignLabels(test, knnMdl); // Calculate amount of errors on test dataset int amountOfErrors = 0; http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/examples/src/main/ml/org/apache/ignite/examples/ml/knn/regression/KNNRegressionExample.java 
---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/knn/regression/KNNRegressionExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/knn/regression/KNNRegressionExample.java index 6ed0dd6..31f7191 100644 --- a/examples/src/main/ml/org/apache/ignite/examples/ml/knn/regression/KNNRegressionExample.java +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/knn/regression/KNNRegressionExample.java @@ -26,11 +26,13 @@ import org.apache.ignite.Ignition; import org.apache.ignite.examples.ExampleNodeStartup; import org.apache.ignite.examples.ml.knn.classification.KNNClassificationExample; import org.apache.ignite.ml.knn.models.KNNStrategy; -import org.apache.ignite.ml.knn.models.Normalization; import org.apache.ignite.ml.knn.regression.KNNMultipleLinearRegression; import org.apache.ignite.ml.math.distances.ManhattanDistance; import org.apache.ignite.ml.structures.LabeledDataset; import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair; +import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader; +import org.apache.ignite.ml.structures.preprocessing.LabellingMachine; +import org.apache.ignite.ml.structures.preprocessing.Normalizer; import org.apache.ignite.thread.IgniteThread; /** @@ -72,10 +74,10 @@ public class KNNRegressionExample { Path path = Paths.get(KNNClassificationExample.class.getClassLoader().getResource(KNN_CLEARED_MACHINES_TXT).toURI()); // Read dataset from file - LabeledDataset dataset = LabeledDataset.loadTxt(path, SEPARATOR, false, false); + LabeledDataset dataset = LabeledDatasetLoader.loadFromTxtFile(path, SEPARATOR, false, false); // Normalize dataset - dataset.normalizeWith(Normalization.MINIMAX); + Normalizer.normalizeWithMiniMax(dataset); // Random splitting of iris data as 80% train and 20% test datasets LabeledDatasetTestTrainPair split = new LabeledDatasetTestTrainPair(dataset, 0.2); @@ -93,10 +95,7 @@ public class 
KNNRegressionExample { final double[] labels = test.labels(); // Save predicted classes to test dataset - for (int i = 0; i < test.rowSize(); i++) { - double predictedCls = knnMdl.apply(test.getRow(i).features()); - test.setLabel(i, predictedCls); - } + LabellingMachine.assignLabels(test, knnMdl); // Calculate mean squared error (MSE) double mse = 0.0; http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModel.java index d3dff8c..3951be4 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModel.java @@ -96,7 +96,7 @@ public class KNNModel implements Model<Vector, Double>, Exportable<KNNModelForma * @return K-nearest neighbors. */ protected LabeledVector[] findKNearestNeighbors(Vector v, boolean isCashedDistance) { - LabeledVector[] trainingData = training.data(); + LabeledVector[] trainingData = (LabeledVector[])training.data(); TreeMap<Double, Set<Integer>> distanceIdxPairs = getDistances(v, trainingData); http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/Normalization.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/Normalization.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/Normalization.java deleted file mode 100644 index aa4b291..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/Normalization.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.knn.models; - -/** This enum contains names of different normalization approaches. */ -public enum Normalization { - /** Minimax. - * - * x'=(x-MIN[X])/(MAX[X]-MIN[X]) - */ - MINIMAX, - /** Z normalization. - * - * x'=(x-M[X])/\sigma [X] - */ - Z_NORMALIZATION -} http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/main/java/org/apache/ignite/ml/structures/Dataset.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/Dataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/Dataset.java new file mode 100644 index 0000000..cb50516 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/Dataset.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.structures; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.io.Serializable; +import java.util.Arrays; +import org.apache.ignite.ml.math.Vector; + +/** + * Class for set of vectors. This is a base class in hierarchy of datasets. + */ +public class Dataset<Row extends DatasetRow> implements Serializable, Externalizable { + /** Data to keep. */ + protected Row[] data; + + /** Metadata to identify feature. */ + protected FeatureMetadata[] meta; + + /** Amount of instances. */ + protected int rowSize; + + /** Amount of attributes in each vector. */ + protected int colSize; + + /** + * Default constructor (required by Externalizable). + */ + public Dataset(){} + + /** + * Creates new Dataset by given data. + * + * @param data Given data. Should be initialized with one vector at least. + * @param meta Feature's metadata. + */ + public Dataset(Row[] data, FeatureMetadata[] meta) { + this.data = data; + this.meta = meta; + } + + /** + * Creates new Dataset by given data. + * + * @param data Given data. Should be initialized with one vector at least. + * @param featureNames Column names. + * @param colSize Amount of observed attributes in each vector. + */ + public Dataset(Row[] data, String[] featureNames, int colSize) { + this(data.length, colSize, featureNames); + + assert data != null; + + this.data = data; + } + + /** + * Creates new Dataset by given data. + * + * @param data Should be initialized with one vector at least. 
+ * @param colSize Amount of observed attributes in each vector. + */ + public Dataset(Row[] data, int colSize) { + this(data, null, colSize); + } + + /** + * Creates new Dataset and initialized with empty data structure. + * + * @param rowSize Amount of instances. Should be > 0. + * @param colSize Amount of attributes. Should be > 0 + * @param featureNames Column names; if null, names "f_0".."f_{colSize - 1}" are generated. + */ + public Dataset(int rowSize, int colSize, String[] featureNames) { + assert rowSize > 0; + assert colSize > 0; + + this.rowSize = rowSize; + this.colSize = colSize; + + if (featureNames == null) + generateFeatureNames(); + else { + assert colSize == featureNames.length; + convertStringNamesToFeatureMetadata(featureNames); + } + } + + /** */ + protected void convertStringNamesToFeatureMetadata(String[] featureNames) { + this.meta = new FeatureMetadata[featureNames.length]; + for (int i = 0; i < featureNames.length; i++) + this.meta[i] = new FeatureMetadata(featureNames[i]); + } + + /** Generates default feature names "f_i" for each column; reads {@link #colSize}, so it must be set first. */ + protected void generateFeatureNames() { + String[] featureNames = new String[colSize]; + + for (int i = 0; i < colSize; i++) + featureNames[i] = "f_" + i; + + convertStringNamesToFeatureMetadata(featureNames); + } + + /** + * Returns feature name for column with given index. + * + * @param i The given index. + * @return Feature name. + */ + public String getFeatureName(int i) { + return meta[i].name(); + } + + /** */ + public DatasetRow[] data() { + return data; + } + + /** */ + public void setData(Row[] data) { + this.data = data; + } + + /** */ + public FeatureMetadata[] meta() { + return meta; + } + + /** */ + public void setMeta(FeatureMetadata[] meta) { + this.meta = meta; + } + + /** + * Gets amount of attributes. + * + * @return Amount of attributes in each Labeled Vector. + */ + public int colSize() { + return colSize; + } + + /** + * Gets amount of observation. + * + * @return Amount of rows in dataset.
+ */ + public int rowSize() { + return rowSize; + } + + /** + * Retrieves Labeled Vector by given index. + * + * @param idx Index of observation. + * @return Labeled features. + */ + public Row getRow(int idx) { + return data[idx]; + } + + /** + * Get the features. + * + * @param idx Index of observation. + * @return Vector with features. + */ + public Vector features(int idx) { + assert idx < rowSize; + assert data != null; + assert data[idx] != null; + + return data[idx].features(); + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + Dataset that = (Dataset)o; + + return rowSize == that.rowSize && colSize == that.colSize && Arrays.equals(data, that.data) && Arrays.equals(meta, that.meta); + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + int res = Arrays.hashCode(data); + res = 31 * res + Arrays.hashCode(meta); + res = 31 * res + rowSize; + res = 31 * res + colSize; + return res; + } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(data); + out.writeObject(meta); + out.writeInt(rowSize); + out.writeInt(colSize); + } + + /** {@inheritDoc} */ + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + data = (Row[]) in.readObject(); + meta = (FeatureMetadata[]) in.readObject(); + rowSize = in.readInt(); + colSize = in.readInt(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/main/java/org/apache/ignite/ml/structures/DatasetRow.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/DatasetRow.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/DatasetRow.java new file mode 100644 index 0000000..3ba0cf7 --- /dev/null +++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/structures/DatasetRow.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.structures; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.io.Serializable; +import org.apache.ignite.ml.math.Vector; + +/** Class to keep one observation in dataset. This is a base class for labeled and unlabeled rows. */ +public class DatasetRow<V extends Vector> implements Serializable, Externalizable { + /** Vector. */ + protected V vector; + + + /** + * Default constructor (required by Externalizable). + */ + public DatasetRow() { + } + + /** */ + public DatasetRow(V vector) { + this.vector = vector; + } + + /** + * Get the vector. + * + * @return Vector. + */ + public V features() { + return vector; + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + DatasetRow vector1 = (DatasetRow)o; + + return vector != null ? 
vector.equals(vector1.vector) : vector1.vector == null; + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + return vector != null ? vector.hashCode() : 0; + } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(vector); + } + + /** {@inheritDoc} */ + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + vector = (V)in.readObject(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/main/java/org/apache/ignite/ml/structures/FeatureMetadata.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/FeatureMetadata.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/FeatureMetadata.java new file mode 100644 index 0000000..5d07bdb --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/FeatureMetadata.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.ignite.ml.structures; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.io.Serializable; + +/** Class for feature metadata. */ +public class FeatureMetadata implements Serializable, Externalizable { + /** Feature name */ + private String name; + + /** + * Default constructor (required by Externalizable). + */ + public FeatureMetadata() { + } + + /** + * Creates an instance of Feature Metadata class. + * + * @param name Name. + */ + public FeatureMetadata(String name) { + this.name = name; + } + + /** */ + public String name() { + return name; + } + + /** */ + public void setName(String name) { + this.name = name; + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + FeatureMetadata metadata = (FeatureMetadata)o; + + return name != null ? name.equals(metadata.name) : metadata.name == null; + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + return name != null ? 
name.hashCode() : 0; + } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(name); + } + + /** {@inheritDoc} */ + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + name = (String)in.readObject(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java index ee2f442..53f74f3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java @@ -17,21 +17,9 @@ package org.apache.ignite.ml.structures; -import java.io.IOException; -import java.io.Serializable; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Stream; -import org.apache.ignite.ml.knn.models.Normalization; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.exceptions.CardinalityException; import org.apache.ignite.ml.math.exceptions.NoDataException; -import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException; -import org.apache.ignite.ml.math.exceptions.knn.EmptyFileException; -import org.apache.ignite.ml.math.exceptions.knn.FileParsingException; import org.apache.ignite.ml.math.exceptions.knn.NoLabelVectorException; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector; @@ -40,50 +28,12 @@ import org.jetbrains.annotations.NotNull; /** * Class for set of labeled vectors. 
*/ -public class LabeledDataset implements Serializable { - /** Data to keep. */ - private final LabeledVector[] data; - - /** Feature names (one name for each attribute in vector). */ - private String[] featureNames; - - /** Amount of instances. */ - private int rowSize; - - /** Amount of attributes in each vector. */ - private int colSize; - +public class LabeledDataset<L, Row extends LabeledVector> extends Dataset<Row> { /** - * Creates new Labeled Dataset by given data. - * - * @param data Should be initialized with one vector at least. - * @param colSize Amount of observed attributes in each vector. + * Default constructor (required by Externalizable). */ - public LabeledDataset(LabeledVector[] data, int colSize) { - this(data, null, colSize); - } - - /** - * Creates new Labeled Dataset by given data. - * - * @param data Given data. Should be initialized with one vector at least. - * @param featureNames Column names. - * @param colSize Amount of observed attributes in each vector. - */ - public LabeledDataset(LabeledVector[] data, String[] featureNames, int colSize) { - assert data != null; - assert data.length > 0; - - this.data = data; - this.rowSize = data.length; - this.colSize = colSize; - - if(featureNames == null) generateFeatureNames(); - else { - assert colSize == featureNames.length; - this.featureNames = featureNames; - } - + public LabeledDataset() { + super(); } /** @@ -116,22 +66,26 @@ public class LabeledDataset implements Serializable { * @param isDistributed Use distributed data structures to keep data. 
*/ public LabeledDataset(int rowSize, int colSize, String[] featureNames, boolean isDistributed){ - assert rowSize > 0; - assert colSize > 0; - - if(featureNames == null) generateFeatureNames(); - else { - assert colSize == featureNames.length; - this.featureNames = featureNames; - } + super(rowSize, colSize, featureNames); - this.rowSize = rowSize; - this.colSize = colSize; + initializeDataWithLabeledVectors(rowSize, colSize, isDistributed); + } - data = new LabeledVector[rowSize]; + /** */ + private void initializeDataWithLabeledVectors(int rowSize, int colSize, boolean isDistributed) { + data = (Row[])new LabeledVector[rowSize]; for (int i = 0; i < rowSize; i++) - data[i] = new LabeledVector(getVector(colSize, isDistributed), null); + data[i] = (Row)new LabeledVector(emptyVector(colSize, isDistributed), null); + } + /** + * Creates new Labeled Dataset by given data. + * + * @param data Should be initialized with one vector at least. + * @param colSize Amount of observed attributes in each vector. + */ + public LabeledDataset(Row[] data, int colSize) { + super(data, colSize); } @@ -154,6 +108,7 @@ public class LabeledDataset implements Serializable { * @param isDistributed Use distributed data structures to keep data. 
*/ public LabeledDataset(double[][] mtx, double[] lbs, String[] featureNames, boolean isDistributed) { + super(); assert mtx != null; assert lbs != null; @@ -166,14 +121,17 @@ public class LabeledDataset implements Serializable { this.rowSize = lbs.length; this.colSize = mtx[0].length; - if(featureNames == null) generateFeatureNames(); - else this.featureNames = featureNames; - + if(featureNames == null) + generateFeatureNames(); + else { + assert colSize == featureNames.length; + convertStringNamesToFeatureMetadata(featureNames); + } - data = new LabeledVector[rowSize]; + data = (Row[])new LabeledVector[rowSize]; for (int i = 0; i < rowSize; i++){ - data[i] = new LabeledVector(getVector(colSize, isDistributed), lbs[i]); + data[i] = (Row)new LabeledVector(emptyVector(colSize, isDistributed), lbs[i]); for (int j = 0; j < colSize; j++) { try { data[i].features().set(j, mtx[i][j]); @@ -184,76 +142,6 @@ public class LabeledDataset implements Serializable { } } - /** */ - private void generateFeatureNames() { - featureNames = new String[colSize]; - - for (int i = 0; i < colSize; i++) - featureNames[i] = "f_" + i; - } - - - /** - * Get vectors and their labels. - * - * @return Array of Label Vector instances. - */ - public LabeledVector[] data() { - return data; - } - - /** - * Gets amount of observation. - * - * @return Amount of rows in dataset. - */ - public int rowSize(){ - return rowSize; - } - - /** - * Returns feature name for column with given index. - * - * @param i The given index. - * @return Feature name. - */ - public String getFeatureName(int i){ - return featureNames[i]; - } - - /** - * Gets amount of attributes. - * - * @return Amount of attributes in each Labeled Vector. - */ - public int colSize(){ - return colSize; - } - - /** - * Retrieves Labeled Vector by given index. - * - * @param idx Index of observation. - * @return Labeled features. - */ - public LabeledVector getRow(int idx){ - return data[idx]; - } - - /** - * Get the features. 
- * - * @param idx Index of observation. - * @return Vector with features. - */ - public Vector features(int idx){ - assert idx < rowSize; - assert data != null; - assert data[idx] != null; - - return data[idx].features(); - } - /** * Returns label if label is attached or null if label is missed. * @@ -261,7 +149,7 @@ public class LabeledDataset implements Serializable { * @return Label. */ public double label(int idx) { - LabeledVector labeledVector = data[idx]; + LabeledVector labeledVector = (LabeledVector)data[idx]; if(labeledVector!=null) return (double)labeledVector.label(); @@ -302,174 +190,10 @@ public class LabeledDataset implements Serializable { throw new NoLabelVectorException(idx); } - /** - * Datafile should keep class labels in the first column. - * - * @param pathToFile Path to file. - * @param separator Element to tokenize row on separate tokens. - * @param isDistributed Generates distributed dataset if true. - * @param isFallOnBadData Fall on incorrect data if true. - * @return Labeled Dataset parsed from file. 
- */ - public static LabeledDataset loadTxt(Path pathToFile, String separator, boolean isDistributed, boolean isFallOnBadData) throws IOException { - Stream<String> stream = Files.lines(pathToFile); - List<String> list = new ArrayList<>(); - stream.forEach(list::add); - - final int rowSize = list.size(); - - List<Double> labels = new ArrayList<>(); - List<Vector> vectors = new ArrayList<>(); - - if (rowSize > 0) { - - final int colSize = getColumnSize(separator, list) - 1; - - if (colSize > 0) { - - for (int i = 0; i < rowSize; i++) { - Double clsLb; - - String[] rowData = list.get(i).split(separator); - - try { - clsLb = Double.parseDouble(rowData[0]); - Vector vec = parseFeatures(pathToFile, isDistributed, isFallOnBadData, colSize, i, rowData); - labels.add(clsLb); - vectors.add(vec); - } - catch (NumberFormatException e) { - if(isFallOnBadData) - throw new FileParsingException(rowData[0], i, pathToFile); - } - } - - LabeledVector[] data = new LabeledVector[vectors.size()]; - for (int i = 0; i < vectors.size(); i++) - data[i] = new LabeledVector(vectors.get(i), labels.get(i)); - - return new LabeledDataset(data, colSize); - } - else - throw new NoDataException("File should contain first row with data"); - } - else - throw new EmptyFileException(pathToFile.toString()); - } - /** */ - @NotNull private static Vector parseFeatures(Path pathToFile, boolean isDistributed, boolean isFallOnBadData, - int colSize, int rowIdx, String[] rowData) { - final Vector vec = getVector(colSize, isDistributed); - - for (int j = 0; j < colSize; j++) { - - if (rowData.length == colSize + 1) { - double val = fillMissedData(); - - try { - val = Double.parseDouble(rowData[j + 1]); - vec.set(j, val); - } - catch (NumberFormatException e) { - if(isFallOnBadData) - throw new FileParsingException(rowData[j + 1], rowIdx, pathToFile); - else - vec.set(j,val); - } - } - else throw new CardinalityException(colSize + 1, rowData.length); - } - return vec; - } - - // TODO: IGNITE-7025 add filling 
with mean, mode, ignoring and so on - /** */ - private static double fillMissedData() { - return 0.0; - } - - /** */ - @NotNull private static Vector getVector(int size, boolean isDistributed) { + @NotNull public static Vector emptyVector(int size, boolean isDistributed) { if(isDistributed) return new SparseBlockDistributedVector(size); else return new DenseLocalOnHeapVector(size); } - - /** */ - private static int getColumnSize(String separator, List<String> list) { - String[] rowData = list.get(0).split(separator, -1); // assume that all observation has the same length as a first row - - return rowData.length; - } - - /** - * Scales features in dataset. - * - * @param normalization normalization approach - * @return Labeled dataset - */ - public LabeledDataset normalizeWith(Normalization normalization) { - switch (normalization){ - case MINIMAX: minMaxFeatures(); - break; - case Z_NORMALIZATION: throw new UnsupportedOperationException("Z-normalization is not supported yet"); - } - - return this; - } - - /** - * Complexity 2*N^2. Try to optimize. 
- */ - private void minMaxFeatures() { - double[] mins = new double[colSize]; - double[] maxs = new double[colSize]; - - for (int j = 0; j < colSize; j++) { - double maxInCurrCol = Double.MIN_VALUE; - double minInCurrCol = Double.MAX_VALUE; - - for (int i = 0; i < rowSize; i++) { - double e = data[i].features().get(j); - maxInCurrCol = Math.max(e, maxInCurrCol); - minInCurrCol = Math.min(e, minInCurrCol); - } - - mins[j] = minInCurrCol; - maxs[j] = maxInCurrCol; - } - - for (int j = 0; j < colSize; j++) { - double div = maxs[j] - mins[j]; - - for (int i = 0; i < rowSize; i++) { - double oldVal = data[i].features().get(j); - double newVal = (oldVal - mins[j])/div; - // x'=(x-MIN[X])/(MAX[X]-MIN[X]) - data[i].features().set(j, newVal); - } - } - } - - /** */ - @Override public boolean equals(Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; - - LabeledDataset that = (LabeledDataset)o; - - return rowSize == that.rowSize && colSize == that.colSize && Arrays.equals(data, that.data) && Arrays.equals(featureNames, that.featureNames); - } - - /** */ - @Override public int hashCode() { - int res = Arrays.hashCode(data); - res = 31 * res + Arrays.hashCode(featureNames); - res = 31 * res + rowSize; - res = 31 * res + colSize; - return res; - } } http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java index dd3d244..baf72d8 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java @@ -53,21 +53,19 @@ public class 
LabeledDatasetTestTrainPair implements Serializable { final TreeSet<Integer> sortedTestIndices = getSortedIndices(datasetSize, testSize); - LabeledVector[] testVectors = new LabeledVector[testSize]; LabeledVector[] trainVectors = new LabeledVector[trainSize]; - int datasetCntr = 0; int trainCntr = 0; int testCntr = 0; for (Integer idx: sortedTestIndices){ // guarantee order as iterator - testVectors[testCntr] = dataset.getRow(idx); + testVectors[testCntr] = (LabeledVector)dataset.getRow(idx); testCntr++; for (int i = datasetCntr; i < idx; i++) { - trainVectors[trainCntr] = dataset.getRow(i); + trainVectors[trainCntr] = (LabeledVector)dataset.getRow(i); trainCntr++; } @@ -75,7 +73,7 @@ public class LabeledDatasetTestTrainPair implements Serializable { } if(datasetCntr < datasetSize){ for (int i = datasetCntr; i < datasetSize; i++) { - trainVectors[trainCntr] = dataset.getRow(i); + trainVectors[trainCntr] = (LabeledVector)dataset.getRow(i); trainCntr++; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java index a4e218b..9f0a881 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java @@ -17,40 +17,37 @@ package org.apache.ignite.ml.structures; -import java.io.Serializable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; import org.apache.ignite.ml.math.Vector; /** * Class for vector with label. * * @param <V> Some class extending {@link Vector}. - * @param <T> Type of label. + * @param <L> Type of label. */ -public class LabeledVector<V extends Vector, T> implements Serializable { - /** Vector.
*/ - private final V vector; - +public class LabeledVector<V extends Vector, L> extends DatasetRow<V> { /** Label. */ - private T lb; + private L lb; /** - * Construct labeled vector. - * - * @param vector Vector. - * @param lb Label. + * Default constructor. NOTE(review): presumably kept for {@link java.io.Externalizable} deserialization, which needs a public no-arg constructor; confirm that DatasetRow implements Externalizable. */ - public LabeledVector(V vector, T lb) { - this.vector = vector; - this.lb = lb; + public LabeledVector() { + super(); } /** - * Get the vector. + * Construct labeled vector. * - * @return Vector. + * @param vector Vector. + * @param lb Label. */ - public V features() { - return vector; + public LabeledVector(V vector, L lb) { + super(vector); + this.lb = lb; } /** @@ -58,7 +55,7 @@ public class LabeledVector<V extends Vector, T> implements Serializable { * * @return Label. */ - public T label() { + public L label() { return lb; } @@ -67,11 +64,11 @@ public class LabeledVector<V extends Vector, T> implements Serializable { * * @param lb Label. */ - public void setLabel(T lb) { + public void setLabel(L lb) { this.lb = lb; } - /** */ + /** {@inheritDoc} */ @Override public boolean equals(Object o) { if (this == o) return true; @@ -85,10 +82,22 @@ public class LabeledVector<V extends Vector, T> implements Serializable { return lb != null ? lb.equals(vector1.lb) : vector1.lb == null; } - /** */ + /** {@inheritDoc} */ @Override public int hashCode() { int res = vector != null ? vector.hashCode() : 0; res = 31 * res + (lb != null ?
lb.hashCode() : 0); return res; } + + /** {@inheritDoc} Writes the feature vector first, then the label. */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(vector); + out.writeObject(lb); + } + + /** {@inheritDoc} Reads the feature vector, then the label, i.e. in the order {@link #writeExternal(ObjectOutput)} writes them; both casts are unchecked. */ + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + vector = (V)in.readObject(); + lb = (L)in.readObject(); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabeledDatasetLoader.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabeledDatasetLoader.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabeledDatasetLoader.java new file mode 100644 index 0000000..0faa416 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabeledDatasetLoader.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */
+
+package org.apache.ignite.ml.structures.preprocessing;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Stream;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.exceptions.CardinalityException;
+import org.apache.ignite.ml.math.exceptions.NoDataException;
+import org.apache.ignite.ml.math.exceptions.knn.EmptyFileException;
+import org.apache.ignite.ml.math.exceptions.knn.FileParsingException;
+import org.apache.ignite.ml.structures.LabeledDataset;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.jetbrains.annotations.NotNull;
+
+/** Data pre-processing step which loads data from different file types. */
+public class LabeledDatasetLoader {
+    /**
+     * Loads a labeled dataset from a delimited text file. The class label must be in the first column; every
+     * following column is a feature.
+     * <p>
+     * The number of feature columns is taken from the first row. Empty or unparsable feature values are replaced
+     * with the missing-data filler (currently {@code 0.0}) unless {@code isFallOnBadData} is set. Rows whose label
+     * cannot be parsed are skipped unless {@code isFallOnBadData} is set.
+     * </p>
+     *
+     * @param pathToFile Path to file.
+     * @param separator Regular expression used to tokenize a row (passed to {@link String#split(String, int)}).
+     * @param isDistributed Generates distributed dataset if true.
+     * @param isFallOnBadData Fall on incorrect data if true.
+     * @return Labeled Dataset parsed from file.
+     * @throws IOException If the file cannot be read.
+     * @throws EmptyFileException If the file contains no lines.
+     * @throws NoDataException If the first row has no feature columns.
+     * @throws FileParsingException If {@code isFallOnBadData} is true and a label or feature cannot be parsed.
+     * @throws CardinalityException If a row has a different number of columns than the first row.
+     */
+    public static LabeledDataset loadFromTxtFile(Path pathToFile, String separator, boolean isDistributed,
+        boolean isFallOnBadData) throws IOException {
+        List<String> list = new ArrayList<>();
+
+        // Files.lines keeps the file open until the stream is closed.
+        try (Stream<String> stream = Files.lines(pathToFile)) {
+            stream.forEach(list::add);
+        }
+
+        final int rowSize = list.size();
+
+        if (rowSize == 0)
+            throw new EmptyFileException(pathToFile.toString());
+
+        // The first column holds the label, the remaining ones are features.
+        final int colSize = getColumnSize(separator, list) - 1;
+
+        if (colSize <= 0)
+            throw new NoDataException("File should contain first row with data");
+
+        List<LabeledVector> rows = new ArrayList<>(rowSize);
+
+        for (int i = 0; i < rowSize; i++) {
+            // Limit -1 keeps trailing empty fields (as getColumnSize does), so a missing last value goes through
+            // the missing-data path instead of shrinking the row and failing the cardinality check.
+            String[] rowData = list.get(i).split(separator, -1);
+
+            try {
+                Double clsLb = Double.parseDouble(rowData[0]);
+                Vector vec = parseFeatures(pathToFile, isDistributed, isFallOnBadData, colSize, i, rowData);
+                rows.add(new LabeledVector<>(vec, clsLb));
+            }
+            catch (NumberFormatException e) {
+                if (isFallOnBadData)
+                    throw new FileParsingException(rowData[0], i, pathToFile);
+            }
+        }
+
+        return new LabeledDataset(rows.toArray(new LabeledVector[0]), colSize);
+    }
+
+    /**
+     * Parses the feature columns of one row into a new vector.
+     *
+     * @param pathToFile Path to the source file (used in error reports).
+     * @param isDistributed Creates a distributed vector if true.
+     * @param isFallOnBadData Throws on an unparsable value if true, otherwise substitutes the missing-data value.
+     * @param colSize Expected number of feature columns.
+     * @param rowIdx Index of the row (used in error reports).
+     * @param rowData Row tokens; index 0 is the label.
+     * @return Vector with the parsed features.
+     * @throws CardinalityException If the row does not have exactly {@code colSize + 1} tokens.
+     */
+    @NotNull private static Vector parseFeatures(Path pathToFile, boolean isDistributed, boolean isFallOnBadData,
+        int colSize, int rowIdx, String[] rowData) {
+        // Validate before allocating, so no (possibly distributed) vector is created for a rejected row.
+        if (rowData.length != colSize + 1)
+            throw new CardinalityException(colSize + 1, rowData.length);
+
+        final Vector vec = LabeledDataset.emptyVector(colSize, isDistributed);
+
+        for (int j = 0; j < colSize; j++) {
+            double val;
+
+            try {
+                val = Double.parseDouble(rowData[j + 1]);
+            }
+            catch (NumberFormatException e) {
+                if (isFallOnBadData)
+                    throw new FileParsingException(rowData[j + 1], rowIdx, pathToFile);
+
+                val = fillMissedData();
+            }
+
+            vec.set(j, val);
+        }
+
+        return vec;
+    }
+
+    // TODO: IGNITE-7025 add filling with mean, mode, ignoring and so on
+
+    /** Returns the value used in place of a missing or unparsable feature. */
+    private static double fillMissedData() {
+        return 0.0;
+    }
+
+    /**
+     * Counts the columns of the first row; all observations are assumed to have the same length as the first row.
+     *
+     * @param separator Regular expression used to tokenize a row.
+     * @param list File rows; must not be empty.
+     * @return Number of columns, including the label column.
+     */
+    private static int getColumnSize(String separator, List<String> list) {
+        String[] rowData = list.get(0).split(separator, -1);
+
+        return rowData.length;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabellingMachine.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabellingMachine.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabellingMachine.java
new file mode 100644
index 0000000..44719cf
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabellingMachine.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.structures.preprocessing;
+
+import org.apache.ignite.ml.knn.models.KNNModel;
+import org.apache.ignite.ml.structures.LabeledDataset;
+
+/** Data pre-processing step which assigns labels to all observations according model. */
+public class LabellingMachine {
+    /**
+     * Set labels to each observation according passed Model.
+     * <p>
+     * NOTE: In-place operation.
+     * </p>
+     * @param ds The given labeled dataset.
+     * @param knnMdl The given kNN Model.
+     * @return Dataset with predicted labels.
+     */
+    public static LabeledDataset assignLabels(LabeledDataset ds, KNNModel knnMdl) {
+        for (int i = 0; i < ds.rowSize(); i++) {
+            double predictedCls = knnMdl.apply(ds.getRow(i).features());
+            ds.setLabel(i, predictedCls);
+        }
+        return ds;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/Normalizer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/Normalizer.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/Normalizer.java
new file mode 100644
index 0000000..26d8bf9
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/Normalizer.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.structures.preprocessing;
+
+import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
+import org.apache.ignite.ml.structures.Dataset;
+import org.apache.ignite.ml.structures.DatasetRow;
+
+/** Data pre-processing step which scales features according normalization algorithms. */
+public class Normalizer {
+    /**
+     * Scales features in dataset with MiniMax algorithm x'=(x-MIN[X])/(MAX[X]-MIN[X]). This is an in-place operation.
+     * <p>
+     * A feature column whose values are all equal (MAX[X] == MIN[X]) is set to {@code 0.0} instead of {@code NaN}.
+     * </p>
+     * <p>
+     * NOTE: Complexity O(rowSize * colSize): one pass to find the column bounds and one pass to rescale.
+     * </p>
+     * @param ds The given dataset.
+     * @return Transformed dataset.
+     */
+    public static Dataset normalizeWithMiniMax(Dataset ds) {
+        int colSize = ds.colSize();
+        int rowSize = ds.rowSize();
+        DatasetRow[] data = ds.data();
+
+        double[] mins = new double[colSize];
+        double[] maxs = new double[colSize];
+
+        for (int j = 0; j < colSize; j++) {
+            // Start from the infinities: Double.MIN_VALUE is the smallest positive double, so using it as the
+            // initial maximum gives a wrong bound for columns whose values are all non-positive.
+            double maxInCurrCol = Double.NEGATIVE_INFINITY;
+            double minInCurrCol = Double.POSITIVE_INFINITY;
+
+            for (int i = 0; i < rowSize; i++) {
+                double e = data[i].features().get(j);
+                maxInCurrCol = Math.max(e, maxInCurrCol);
+                minInCurrCol = Math.min(e, minInCurrCol);
+            }
+
+            mins[j] = minInCurrCol;
+            maxs[j] = maxInCurrCol;
+        }
+
+        for (int j = 0; j < colSize; j++) {
+            double div = maxs[j] - mins[j];
+
+            for (int i = 0; i < rowSize; i++) {
+                double oldVal = data[i].features().get(j);
+                // x'=(x-MIN[X])/(MAX[X]-MIN[X]); a constant column would otherwise compute 0/0 = NaN.
+                double newVal = div == 0.0 ? 0.0 : (oldVal - mins[j]) / div;
+                data[i].features().set(j, newVal);
+            }
+        }
+
+        return ds;
+    }
+
+    /**
+     * Scales features in dataset with Z-Normalization algorithm x'=(x-M[X])/\sigma [X]. This is an in-place operation.
+     *
+     * @param ds The given dataset.
+     * @return Transformed dataset.
+     * @throws UnsupportedOperationException Always, Z-normalization is not implemented yet.
+     */
+    public static Dataset normalizeWithZNormalization(Dataset ds) {
+        throw new UnsupportedOperationException("Z-normalization is not supported yet");
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/package-info.java
new file mode 100644
index 0000000..c243074
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains main APIs for dataset pre-processing. + */ +package org.apache.ignite.ml.structures.preprocessing; http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java index 9075978..1651588 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java @@ -23,6 +23,7 @@ import java.nio.file.Path; import java.nio.file.Paths; import org.apache.ignite.Ignite; import org.apache.ignite.ml.structures.LabeledDataset; +import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; /** @@ -76,7 +77,7 @@ public class BaseKNNTest extends GridCommonAbstractTest { try { Path path = Paths.get(this.getClass().getClassLoader().getResource(rsrcPath).toURI()); try { - return LabeledDataset.loadTxt(path, SEPARATOR, false, isFallOnBadData); + return LabeledDatasetLoader.loadFromTxtFile(path,
SEPARATOR, false, isFallOnBadData); } catch (IOException e) { e.printStackTrace(); http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java index d973686..e5d9b13 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java @@ -19,13 +19,13 @@ package org.apache.ignite.ml.knn; import org.apache.ignite.internal.util.IgniteUtils; import org.apache.ignite.ml.knn.models.KNNStrategy; -import org.apache.ignite.ml.knn.models.Normalization; import org.apache.ignite.ml.knn.regression.KNNMultipleLinearRegression; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector; import org.apache.ignite.ml.structures.LabeledDataset; +import org.apache.ignite.ml.structures.preprocessing.Normalizer; import org.junit.Assert; /** @@ -115,7 +115,7 @@ public class KNNMultipleLinearRegressionTest extends BaseKNNTest { LabeledDataset training = new LabeledDataset(x, y); - final LabeledDataset normalizedTrainingDataset = training.normalizeWith(Normalization.MINIMAX); + final LabeledDataset normalizedTrainingDataset = (LabeledDataset)Normalizer.normalizeWithMiniMax(training); KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(5, new EuclideanDistance(), KNNStrategy.SIMPLE, normalizedTrainingDataset); Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956}); @@ -147,7 +147,7
@@ public class KNNMultipleLinearRegressionTest extends BaseKNNTest { LabeledDataset training = new LabeledDataset(x, y); - final LabeledDataset normalizedTrainingDataset = training.normalizeWith(Normalization.MINIMAX); + final LabeledDataset normalizedTrainingDataset = (LabeledDataset)Normalizer.normalizeWithMiniMax(training); KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(5, new EuclideanDistance(), KNNStrategy.WEIGHTED, normalizedTrainingDataset); Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956}); http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java index c64a8d8..c4ae70f 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java @@ -22,6 +22,7 @@ import java.net.URISyntaxException; import java.nio.file.Path; import java.nio.file.Paths; import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.math.ExternalizableTest; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.exceptions.CardinalityException; import org.apache.ignite.ml.math.exceptions.NoDataException; @@ -30,9 +31,10 @@ import org.apache.ignite.ml.math.exceptions.knn.FileParsingException; import org.apache.ignite.ml.structures.LabeledDataset; import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair; import org.apache.ignite.ml.structures.LabeledVector; +import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader; /** Tests behaviour of KNNClassificationTest.
*/ -public class LabeledDatasetTest extends BaseKNNTest { +public class LabeledDatasetTest extends BaseKNNTest implements ExternalizableTest<LabeledDataset> { /** */ private static final String KNN_IRIS_TXT = "datasets/knn/iris.txt"; @@ -88,7 +90,7 @@ public class LabeledDatasetTest extends BaseKNNTest { assertEquals(dataset.colSize(), 2); assertEquals(dataset.rowSize(), 6); - final LabeledVector<Vector, Double> row = dataset.getRow(0); + final LabeledVector<Vector, Double> row = (LabeledVector<Vector, Double>)dataset.getRow(0); assertEquals(row.features().get(0), 1.0); assertEquals(row.label(), 1.0); @@ -202,7 +204,7 @@ public class LabeledDatasetTest extends BaseKNNTest { Path path = Paths.get(this.getClass().getClassLoader().getResource(IRIS_MISSED_DATA).toURI()); - LabeledDataset training = LabeledDataset.loadTxt(path, ",", false, false); + LabeledDataset training = LabeledDatasetLoader.loadFromTxtFile(path, ",", false, false); assertEquals(training.features(2).get(1), 0.0); } @@ -263,4 +265,21 @@ public class LabeledDatasetTest extends BaseKNNTest { for (int i = 0; i < lbs.length; i++) assertEquals(lbs[i], labels[i]); } + + /** Checks that a small {@link LabeledDataset} is equal (value and hash code) after a serialization round trip through {@link #externalizeTest}. */ @Override public void testExternalization() { + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + + double[][] mtx = + new double[][] { + {1.0, 1.0}, + {1.0, 2.0}, + {2.0, 1.0}, + {-1.0, -1.0}, + {-1.0, -2.0}, + {-2.0, -1.0}}; + double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0}; + + LabeledDataset dataset = new LabeledDataset(mtx, lbs); + this.externalizeTest(dataset); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/429f9544/modules/ml/src/test/java/org/apache/ignite/ml/math/ExternalizableTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/ExternalizableTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/ExternalizableTest.java new file mode 100644 index 0000000..e4080ff --- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/ExternalizableTest.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import org.apache.ignite.ml.math.impls.MathTestConstants;
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Common test for externalization.
+ *
+ * @param <T> Type of the object under test.
+ */
+public interface ExternalizableTest<T extends Externalizable> {
+    /**
+     * Writes {@code initObj} to an {@link ObjectOutputStream}, reads it back and checks that the restored object
+     * is equal to the original and has the same hash code. A restored {@link Destroyable} is destroyed afterwards.
+     *
+     * @param initObj Object to round-trip.
+     */
+    @SuppressWarnings("unchecked")
+    default void externalizeTest(T initObj) {
+        T objRestored = null;
+
+        try {
+            ByteArrayOutputStream byteArrOutputStream = new ByteArrayOutputStream();
+
+            // Closing the object stream flushes buffered data before the bytes are read back.
+            try (ObjectOutputStream objOutputStream = new ObjectOutputStream(byteArrOutputStream)) {
+                objOutputStream.writeObject(initObj);
+            }
+
+            ByteArrayInputStream byteArrInputStream = new ByteArrayInputStream(byteArrOutputStream.toByteArray());
+
+            try (ObjectInputStream objInputStream = new ObjectInputStream(byteArrInputStream)) {
+                objRestored = (T)objInputStream.readObject();
+            }
+
+            assertTrue(MathTestConstants.VAL_NOT_EQUALS, initObj.equals(objRestored));
+            assertTrue(MathTestConstants.VAL_NOT_EQUALS, initObj.hashCode() == objRestored.hashCode());
+        }
+        catch (ClassNotFoundException | IOException e) {
+            // Keep the exception as the cause so its stack trace is reported with the failure.
+            throw new AssertionError(e + " [" + e.getMessage() + "]", e);
+        }
+        finally {
+            if (objRestored instanceof Destroyable)
+                ((Destroyable)objRestored).destroy();
+        }
+    }
+
+    /** Runs the externalization check for the implementing test class. */
+    @Test
+    void testExternalization();
+}
