http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java deleted file mode 100644 index fd5a624..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java +++ /dev/null @@ -1,269 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.regressions.logistic.multiclass; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.apache.ignite.ml.dataset.Dataset; -import org.apache.ignite.ml.dataset.DatasetBuilder; -import org.apache.ignite.ml.dataset.PartitionDataBuilder; -import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.nn.UpdatesStrategy; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; -import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap; -import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap; -import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; - -/** - * All common parameters are shared with bunch of binary classification trainers. - */ -public class LogRegressionMultiClassTrainer<P extends Serializable> - extends SingleLabelDatasetTrainer<LogRegressionMultiClassModel> { - /** Update strategy. */ - private UpdatesStrategy updatesStgy = new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - ); - - /** Max number of iteration. */ - private int amountOfIterations = 100; - - /** Batch size. */ - private int batchSize = 100; - - /** Number of local iterations. */ - private int amountOfLocIterations = 100; - - /** Seed for random generator. 
*/ - private long seed = 1234L; - - /** - * Trains model based on the specified data. - * - * @param datasetBuilder Dataset builder. - * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. - * @return Model. - */ - @Override public <K, V> LogRegressionMultiClassModel fit(DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, Double> lbExtractor) { - List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor); - - return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); - } - - /** {@inheritDoc} */ - @Override public <K, V> LogRegressionMultiClassModel updateModel(LogRegressionMultiClassModel newMdl, - DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, Double> lbExtractor) { - - List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor); - - if(classes.isEmpty()) - return getLastTrainedModelOrThrowEmptyDatasetException(newMdl); - - LogRegressionMultiClassModel multiClsMdl = new LogRegressionMultiClassModel(); - - classes.forEach(clsLb -> { - LogisticRegressionSGDTrainer<?> trainer = - new LogisticRegressionSGDTrainer<>() - .withBatchSize(batchSize) - .withLocIterations(amountOfLocIterations) - .withMaxIterations(amountOfIterations) - .withSeed(seed); - - IgniteBiFunction<K, V, Double> lbTransformer = (k, v) -> { - Double lb = lbExtractor.apply(k, v); - - if (lb.equals(clsLb)) - return 1.0; - else - return 0.0; - }; - - LogisticRegressionModel mdl = Optional.ofNullable(newMdl) - .flatMap(multiClassModel -> multiClassModel.getModel(clsLb)) - .map(learnedModel -> trainer.update(learnedModel, datasetBuilder, featureExtractor, lbTransformer)) - .orElseGet(() -> trainer.fit(datasetBuilder, featureExtractor, lbTransformer)); - - multiClsMdl.add(clsLb, mdl); - }); - - return multiClsMdl; - } - - /** {@inheritDoc} */ - @Override protected boolean checkState(LogRegressionMultiClassModel mdl) 
{ - return true; - } - - /** Iterates among dataset and collects class labels. */ - private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Double> lbExtractor) { - assert datasetBuilder != null; - - PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor); - - List<Double> res = new ArrayList<>(); - - try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build( - envBuilder, - (env, upstream, upstreamSize) -> new EmptyContext(), - partDataBuilder - )) { - final Set<Double> clsLabels = dataset.compute(data -> { - final Set<Double> locClsLabels = new HashSet<>(); - - final double[] lbs = data.getY(); - - for (double lb : lbs) - locClsLabels.add(lb); - - return locClsLabels; - }, (a, b) -> { - if (a == null) - return b == null ? new HashSet<>() : b; - if (b == null) - return a; - return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet()); - }); - - if (clsLabels != null) - res.addAll(clsLabels); - - } - catch (Exception e) { - throw new RuntimeException(e); - } - return res; - } - - /** - * Set up the regularization parameter. - * - * @param batchSize The size of learning batch. - * @return Trainer with new batch size parameter value. - */ - public LogRegressionMultiClassTrainer withBatchSize(int batchSize) { - this.batchSize = batchSize; - return this; - } - - /** - * Get the batch size. - * - * @return The parameter value. - */ - public double getBatchSize() { - return batchSize; - } - - /** - * Get the amount of outer iterations of SGD algorithm. - * - * @return The parameter value. - */ - public int getAmountOfIterations() { - return amountOfIterations; - } - - /** - * Set up the amount of outer iterations. - * - * @param amountOfIterations The parameter value. - * @return Trainer with new amountOfIterations parameter value. 
- */ - public LogRegressionMultiClassTrainer withAmountOfIterations(int amountOfIterations) { - this.amountOfIterations = amountOfIterations; - return this; - } - - /** - * Get the amount of local iterations. - * - * @return The parameter value. - */ - public int getAmountOfLocIterations() { - return amountOfLocIterations; - } - - /** - * Set up the amount of local iterations of SGD algorithm. - * - * @param amountOfLocIterations The parameter value. - * @return Trainer with new amountOfLocIterations parameter value. - */ - public LogRegressionMultiClassTrainer withAmountOfLocIterations(int amountOfLocIterations) { - this.amountOfLocIterations = amountOfLocIterations; - return this; - } - - /** - * Set up the random seed parameter. - * - * @param seed Seed for random generator. - * @return Trainer with new seed parameter value. - */ - public LogRegressionMultiClassTrainer withSeed(long seed) { - this.seed = seed; - return this; - } - - /** - * Get the seed for random generator. - * - * @return The parameter value. - */ - public long seed() { - return seed; - } - - /** - * Set up the updates strategy. - * - * @param updatesStgy Update strategy. - * @return Trainer with new update strategy parameter value. - */ - public LogRegressionMultiClassTrainer withUpdatesStgy(UpdatesStrategy updatesStgy) { - this.updatesStgy = updatesStgy; - return this; - } - - /** - * Get the update strategy. - * - * @return The parameter value. - */ - public UpdatesStrategy getUpdatesStgy() { - return updatesStgy; - } -}
http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java deleted file mode 100644 index 2e7b947..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * <!-- Package description. --> - * Contains multi-class logistic regression. 
- */ -package org.apache.ignite.ml.regressions.logistic.multiclass; http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java deleted file mode 100644 index f5d2b28..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.svm; - -import java.io.Serializable; -import java.util.Objects; -import org.apache.ignite.ml.Exportable; -import org.apache.ignite.ml.Exporter; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.math.primitives.vector.Vector; - -/** - * Base class for SVM linear classification model. 
- */ -public class SVMLinearBinaryClassificationModel implements Model<Vector, Double>, Exportable<SVMLinearBinaryClassificationModel>, Serializable { - /** */ - private static final long serialVersionUID = -996984622291440226L; - - /** Output label format. '0' and '1' for false value and raw distances from the separating hyperplane otherwise. */ - private boolean isKeepingRawLabels = false; - - /** Threshold to assign '1' label to the observation if raw value more than this threshold. */ - private double threshold = 0.5; - - /** Multiplier of the objects's vector required to make prediction. */ - private Vector weights; - - /** Intercept of the linear regression model. */ - private double intercept; - - /** */ - public SVMLinearBinaryClassificationModel(Vector weights, double intercept) { - this.weights = weights; - this.intercept = intercept; - } - - /** - * Set up the output label format. - * - * @param isKeepingRawLabels The parameter value. - * @return Model with new isKeepingRawLabels parameter value. - */ - public SVMLinearBinaryClassificationModel withRawLabels(boolean isKeepingRawLabels) { - this.isKeepingRawLabels = isKeepingRawLabels; - return this; - } - - /** - * Set up the threshold. - * - * @param threshold The parameter value. - * @return Model with new threshold parameter value. - */ - public SVMLinearBinaryClassificationModel withThreshold(double threshold) { - this.threshold = threshold; - return this; - } - - /** - * Set up the weights. - * - * @param weights The parameter value. - * @return Model with new weights parameter value. - */ - public SVMLinearBinaryClassificationModel withWeights(Vector weights) { - this.weights = weights; - return this; - } - - /** - * Set up the intercept. - * - * @param intercept The parameter value. - * @return Model with new intercept parameter value. 
- */ - public SVMLinearBinaryClassificationModel withIntercept(double intercept) { - this.intercept = intercept; - return this; - } - - /** {@inheritDoc} */ - @Override public Double apply(Vector input) { - final double res = input.dot(weights) + intercept; - if (isKeepingRawLabels) - return res; - else - return res - threshold > 0 ? 1.0 : 0; - } - - /** - * Gets the output label format mode. - * - * @return The parameter value. - */ - public boolean isKeepingRawLabels() { - return isKeepingRawLabels; - } - - /** - * Gets the threshold. - * - * @return The parameter value. - */ - public double threshold() { - return threshold; - } - - /** - * Gets the weights. - * - * @return The parameter value. - */ - public Vector weights() { - return weights; - } - - /** - * Gets the intercept. - * - * @return The parameter value. - */ - public double intercept() { - return intercept; - } - - /** {@inheritDoc} */ - @Override public <P> void saveModel(Exporter<SVMLinearBinaryClassificationModel, P> exporter, P path) { - exporter.save(this, path); - } - - /** {@inheritDoc} */ - @Override public boolean equals(Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; - - SVMLinearBinaryClassificationModel mdl = (SVMLinearBinaryClassificationModel)o; - - return Double.compare(mdl.intercept, intercept) == 0 - && Double.compare(mdl.threshold, threshold) == 0 - && Boolean.compare(mdl.isKeepingRawLabels, isKeepingRawLabels) == 0 - && Objects.equals(weights, mdl.weights); - } - - /** {@inheritDoc} */ - @Override public int hashCode() { - return Objects.hash(weights, intercept, isKeepingRawLabels, threshold); - } - - /** {@inheritDoc} */ - @Override public String toString() { - if (weights.size() < 20) { - StringBuilder builder = new StringBuilder(); - - for (int i = 0; i < weights.size(); i++) { - double nextItem = i == weights.size() - 1 ? 
intercept : weights.get(i + 1); - - builder.append(String.format("%.4f", Math.abs(weights.get(i)))) - .append("*x") - .append(i) - .append(nextItem > 0 ? " + " : " - "); - } - - builder.append(String.format("%.4f", Math.abs(intercept))); - return builder.toString(); - } - - return "SVMModel [" + - "weights=" + weights + - ", intercept=" + intercept + - ']'; - } - - /** {@inheritDoc} */ - @Override public String toString(boolean pretty) { - return toString(); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java deleted file mode 100644 index 7ceb53b..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java +++ /dev/null @@ -1,356 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.svm; - -import java.util.Random; -import org.apache.ignite.ml.dataset.Dataset; -import org.apache.ignite.ml.dataset.DatasetBuilder; -import org.apache.ignite.ml.dataset.PartitionDataBuilder; -import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; -import org.apache.ignite.ml.math.StorageConstants; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; -import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector; -import org.apache.ignite.ml.structures.LabeledVector; -import org.apache.ignite.ml.structures.LabeledVectorSet; -import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap; -import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; -import org.jetbrains.annotations.NotNull; - -/** - * Base class for a soft-margin SVM linear classification trainer based on the communication-efficient distributed dual - * coordinate ascent algorithm (CoCoA) with hinge-loss function. <p> This trainer takes input as Labeled Dataset with 0 - * and 1 labels for two classes and makes binary classification. </p> The paper about this algorithm could be found - * here https://arxiv.org/abs/1409.1458. - */ -public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrainer<SVMLinearBinaryClassificationModel> { - /** Amount of outer SDCA algorithm iterations. */ - private int amountOfIterations = 200; - - /** Amount of local SDCA algorithm iterations. */ - private int amountOfLocIterations = 100; - - /** Regularization parameter. */ - private double lambda = 0.4; - - /** The seed number. */ - private long seed = 1234L; - - /** - * Trains model based on the specified data. - * - * @param datasetBuilder Dataset builder. - * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. - * @return Model. 
- */ - @Override public <K, V> SVMLinearBinaryClassificationModel fit(DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - - return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); - } - - /** {@inheritDoc} */ - @Override protected <K, V> SVMLinearBinaryClassificationModel updateModel(SVMLinearBinaryClassificationModel mdl, - DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, Double> lbExtractor) { - - assert datasetBuilder != null; - - IgniteBiFunction<K, V, Double> patchedLbExtractor = (k, v) -> { - final Double lb = lbExtractor.apply(k, v); - if (lb == 0.0) - return -1.0; - else - return lb; - }; - - PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>( - featureExtractor, - patchedLbExtractor - ); - - Vector weights; - - try (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = datasetBuilder.build( - envBuilder, - (env, upstream, upstreamSize) -> new EmptyContext(), - partDataBuilder - )) { - if (mdl == null) { - final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> { - if (a == null) - return b == null ? 
0 : b; - if (b == null) - return a; - return b; - }); - - final int weightVectorSizeWithIntercept = cols + 1; - weights = initializeWeightsWithZeros(weightVectorSizeWithIntercept); - } else - weights = getStateVector(mdl); - - for (int i = 0; i < this.getAmountOfIterations(); i++) { - Vector deltaWeights = calculateUpdates(weights, dataset); - if (deltaWeights == null) - return getLastTrainedModelOrThrowEmptyDatasetException(mdl); - - weights = weights.plus(deltaWeights); // creates new vector - } - } catch (Exception e) { - throw new RuntimeException(e); - } - return new SVMLinearBinaryClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0)); - } - - /** {@inheritDoc} */ - @Override protected boolean checkState(SVMLinearBinaryClassificationModel mdl) { - return true; - } - - /** - * @param mdl Model. - * @return vector of model weights with intercept. - */ - private Vector getStateVector(SVMLinearBinaryClassificationModel mdl) { - double intercept = mdl.intercept(); - Vector weights = mdl.weights(); - - int stateVectorSize = weights.size() + 1; - Vector res = weights.isDense() ? 
- new DenseVector(stateVectorSize) : - new SparseVector(stateVectorSize, StorageConstants.RANDOM_ACCESS_MODE); - - res.set(0, intercept); - weights.nonZeroes().forEach(ith -> res.set(ith.index(), ith.get())); - return res; - } - - /** */ - @NotNull private Vector initializeWeightsWithZeros(int vectorSize) { - return new DenseVector(vectorSize); - } - - /** */ - private Vector calculateUpdates(Vector weights, - Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) { - return dataset.compute(data -> { - Vector copiedWeights = weights.copy(); - Vector deltaWeights = initializeWeightsWithZeros(weights.size()); - final int amountOfObservation = data.rowSize(); - - Vector tmpAlphas = initializeWeightsWithZeros(amountOfObservation); - Vector deltaAlphas = initializeWeightsWithZeros(amountOfObservation); - - Random random = new Random(seed); - - for (int i = 0; i < this.getAmountOfLocIterations(); i++) { - int randomIdx = random.nextInt(amountOfObservation); - - Deltas deltas = getDeltas(data, copiedWeights, amountOfObservation, tmpAlphas, randomIdx); - - copiedWeights = copiedWeights.plus(deltas.deltaWeights); // creates new vector - deltaWeights = deltaWeights.plus(deltas.deltaWeights); // creates new vector - - tmpAlphas.set(randomIdx, tmpAlphas.get(randomIdx) + deltas.deltaAlpha); - deltaAlphas.set(randomIdx, deltaAlphas.get(randomIdx) + deltas.deltaAlpha); - } - return deltaWeights; - }, (a, b) -> { - if (a == null) - return b == null ? 
new DenseVector() : b; - if (b == null) - return a; - return a.plus(b); - }); - } - - /** */ - private Deltas getDeltas(LabeledVectorSet data, Vector copiedWeights, int amountOfObservation, Vector tmpAlphas, - int randomIdx) { - LabeledVector row = (LabeledVector)data.getRow(randomIdx); - Double lb = (Double)row.label(); - Vector v = makeVectorWithInterceptElement(row); - - double alpha = tmpAlphas.get(randomIdx); - - return maximize(lb, v, alpha, copiedWeights, amountOfObservation); - } - - /** */ - private Vector makeVectorWithInterceptElement(LabeledVector row) { - Vector vec = row.features().like(row.features().size() + 1); - - vec.set(0, 1); // set intercept element - - for (int j = 0; j < row.features().size(); j++) - vec.set(j + 1, row.features().get(j)); - - return vec; - } - - /** */ - private Deltas maximize(double lb, Vector v, double alpha, Vector weights, int amountOfObservation) { - double gradient = calcGradient(lb, v, weights, amountOfObservation); - double prjGrad = calculateProjectionGradient(alpha, gradient); - - return calcDeltas(lb, v, alpha, prjGrad, weights.size(), amountOfObservation); - } - - /** */ - private Deltas calcDeltas(double lb, Vector v, double alpha, double gradient, int vectorSize, - int amountOfObservation) { - if (gradient != 0.0) { - - double qii = v.dot(v); - double newAlpha = calcNewAlpha(alpha, gradient, qii); - - Vector deltaWeights = v.times(lb * (newAlpha - alpha) / (this.getLambda() * amountOfObservation)); - - return new Deltas(newAlpha - alpha, deltaWeights); - } - else - return new Deltas(0.0, initializeWeightsWithZeros(vectorSize)); - } - - /** */ - private double calcNewAlpha(double alpha, double gradient, double qii) { - if (qii != 0.0) - return Math.min(Math.max(alpha - (gradient / qii), 0.0), 1.0); - else - return 1.0; - } - - /** */ - private double calcGradient(double lb, Vector v, Vector weights, int amountOfObservation) { - double dotProduct = v.dot(weights); - return (lb * dotProduct - 1.0) * 
(this.getLambda() * amountOfObservation); - } - - /** */ - private double calculateProjectionGradient(double alpha, double gradient) { - if (alpha <= 0.0) - return Math.min(gradient, 0.0); - - else if (alpha >= 1.0) - return Math.max(gradient, 0.0); - - else - return gradient; - } - - /** - * Set up the regularization parameter. - * - * @param lambda The regularization parameter. Should be more than 0.0. - * @return Trainer with new lambda parameter value. - */ - public SVMLinearBinaryClassificationTrainer withLambda(double lambda) { - assert lambda > 0.0; - this.lambda = lambda; - return this; - } - - /** - * Get the regularization lambda. - * - * @return The property value. - */ - public double getLambda() { - return lambda; - } - - /** - * Get the amount of outer iterations of SCDA algorithm. - * - * @return The property value. - */ - public int getAmountOfIterations() { - return amountOfIterations; - } - - /** - * Set up the amount of outer iterations of SCDA algorithm. - * - * @param amountOfIterations The parameter value. - * @return Trainer with new amountOfIterations parameter value. - */ - public SVMLinearBinaryClassificationTrainer withAmountOfIterations(int amountOfIterations) { - this.amountOfIterations = amountOfIterations; - return this; - } - - /** - * Get the amount of local iterations of SCDA algorithm. - * - * @return The property value. - */ - public int getAmountOfLocIterations() { - return amountOfLocIterations; - } - - /** - * Set up the amount of local iterations of SCDA algorithm. - * - * @param amountOfLocIterations The parameter value. - * @return Trainer with new amountOfLocIterations parameter value. - */ - public SVMLinearBinaryClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) { - this.amountOfLocIterations = amountOfLocIterations; - return this; - } - - /** - * Get the seed number. - * - * @return The property value. - */ - public long getSeed() { - return seed; - } - - /** - * Set up the seed. 
- * - * @param seed The parameter value. - * @return Model with new seed parameter value. - */ - public SVMLinearBinaryClassificationTrainer withSeed(long seed) { - this.seed = seed; - return this; - } -} - -/** This is a helper class to handle pair results which are returned from the calculation method. */ -class Deltas { - /** */ - public double deltaAlpha; - - /** */ - public Vector deltaWeights; - - /** */ - public Deltas(double deltaAlpha, Vector deltaWeights) { - this.deltaAlpha = deltaAlpha; - this.deltaWeights = deltaWeights; - } -} - - http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationModel.java new file mode 100644 index 0000000..579fdb2 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationModel.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.ignite.ml.svm;

import java.io.Serializable;
import java.util.Objects;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.math.primitives.vector.Vector;

/**
 * Base class for SVM linear classification model.
 *
 * <p>Prediction is computed as {@code weights · x + intercept}; in raw-label mode this
 * distance from the separating hyperplane is returned as-is, otherwise it is compared
 * with {@link #threshold} and mapped to a {0.0, 1.0} class label.</p>
 */
public class SVMLinearClassificationModel implements Model<Vector, Double>, Exportable<SVMLinearClassificationModel>, Serializable {
    /** */
    private static final long serialVersionUID = -996984622291440226L;

    /** Output label format. '0' and '1' for false value and raw distances from the separating hyperplane otherwise. */
    private boolean isKeepingRawLabels = false;

    /** Threshold to assign '1' label to the observation if raw value more than this threshold. */
    private double threshold = 0.5;

    /** Multiplier of the objects's vector required to make prediction. */
    private Vector weights;

    /** Intercept of the linear regression model. */
    private double intercept;

    /**
     * @param weights Vector of model weights (one per feature).
     * @param intercept Intercept of the separating hyperplane.
     */
    public SVMLinearClassificationModel(Vector weights, double intercept) {
        this.weights = weights;
        this.intercept = intercept;
    }

    /**
     * Set up the output label format.
     *
     * @param isKeepingRawLabels The parameter value.
     * @return Model with new isKeepingRawLabels parameter value.
     */
    public SVMLinearClassificationModel withRawLabels(boolean isKeepingRawLabels) {
        this.isKeepingRawLabels = isKeepingRawLabels;
        return this;
    }

    /**
     * Set up the threshold.
     *
     * @param threshold The parameter value.
     * @return Model with new threshold parameter value.
     */
    public SVMLinearClassificationModel withThreshold(double threshold) {
        this.threshold = threshold;
        return this;
    }

    /**
     * Set up the weights.
     *
     * @param weights The parameter value.
     * @return Model with new weights parameter value.
     */
    public SVMLinearClassificationModel withWeights(Vector weights) {
        this.weights = weights;
        return this;
    }

    /**
     * Set up the intercept.
     *
     * @param intercept The parameter value.
     * @return Model with new intercept parameter value.
     */
    public SVMLinearClassificationModel withIntercept(double intercept) {
        this.intercept = intercept;
        return this;
    }

    /** {@inheritDoc} */
    @Override public Double apply(Vector input) {
        // Signed distance from the separating hyperplane.
        final double res = input.dot(weights) + intercept;
        if (isKeepingRawLabels)
            return res;
        else
            // Strictly-greater comparison: a value exactly at the threshold maps to 0.
            return res - threshold > 0 ? 1.0 : 0;
    }

    /**
     * Gets the output label format mode.
     *
     * @return The parameter value.
     */
    public boolean isKeepingRawLabels() {
        return isKeepingRawLabels;
    }

    /**
     * Gets the threshold.
     *
     * @return The parameter value.
     */
    public double threshold() {
        return threshold;
    }

    /**
     * Gets the weights.
     *
     * @return The parameter value.
     */
    public Vector weights() {
        return weights;
    }

    /**
     * Gets the intercept.
     *
     * @return The parameter value.
     */
    public double intercept() {
        return intercept;
    }

    /** {@inheritDoc} */
    @Override public <P> void saveModel(Exporter<SVMLinearClassificationModel, P> exporter, P path) {
        exporter.save(this, path);
    }

    /** {@inheritDoc} */
    @Override public boolean equals(Object o) {
        if (this == o)
            return true;
        if (o == null || getClass() != o.getClass())
            return false;

        SVMLinearClassificationModel mdl = (SVMLinearClassificationModel)o;

        return Double.compare(mdl.intercept, intercept) == 0
            && Double.compare(mdl.threshold, threshold) == 0
            && Boolean.compare(mdl.isKeepingRawLabels, isKeepingRawLabels) == 0
            && Objects.equals(weights, mdl.weights);
    }

    /** {@inheritDoc} */
    @Override public int hashCode() {
        return Objects.hash(weights, intercept, isKeepingRawLabels, threshold);
    }

    /** {@inheritDoc} */
    @Override public String toString() {
        // For compact models print the explicit hyperplane formula,
        // e.g. "1.2000*x0 + 0.5000*x1 - 3.0000".
        if (weights.size() < 20) {
            StringBuilder builder = new StringBuilder();

            // FIX: each term's sign is emitted as the separator preceding the NEXT term,
            // so the very first term (or the intercept, when weights is empty) previously
            // lost its sign. Prepend an explicit minus when the leading value is negative.
            double firstVal = weights.size() == 0 ? intercept : weights.get(0);
            if (firstVal < 0)
                builder.append('-');

            for (int i = 0; i < weights.size(); i++) {
                // Sign source for the separator: next coefficient, or the intercept last.
                double nextItem = i == weights.size() - 1 ? intercept : weights.get(i + 1);

                builder.append(String.format("%.4f", Math.abs(weights.get(i))))
                    .append("*x")
                    .append(i)
                    .append(nextItem > 0 ? " + " : " - ");
            }

            builder.append(String.format("%.4f", Math.abs(intercept)));
            return builder.toString();
        }

        return "SVMModel [" +
            "weights=" + weights +
            ", intercept=" + intercept +
            ']';
    }

    /** {@inheritDoc} */
    @Override public String toString(boolean pretty) {
        return toString();
    }
}
package org.apache.ignite.ml.svm;

import java.util.Random;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.StorageConstants;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.jetbrains.annotations.NotNull;

/**
 * Base class for a soft-margin SVM linear classification trainer based on the communication-efficient distributed dual
 * coordinate ascent algorithm (CoCoA) with hinge-loss function. <p> This trainer takes input as Labeled Dataset with 0
 * and 1 labels for two classes and makes binary classification. </p> The paper about this algorithm could be found
 * here https://arxiv.org/abs/1409.1458.
 */
public class SVMLinearClassificationTrainer extends SingleLabelDatasetTrainer<SVMLinearClassificationModel> {
    /** Amount of outer SDCA algorithm iterations. */
    private int amountOfIterations = 200;

    /** Amount of local SDCA algorithm iterations. */
    private int amountOfLocIterations = 100;

    /** Regularization parameter. */
    private double lambda = 0.4;

    /** The seed number. */
    private long seed = 1234L;

    /**
     * Trains model based on the specified data.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @return Model.
     */
    @Override public <K, V> SVMLinearClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {

        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
    }

    /** {@inheritDoc} */
    @Override protected <K, V> SVMLinearClassificationModel updateModel(SVMLinearClassificationModel mdl,
        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor) {

        assert datasetBuilder != null;

        // SDCA works with {-1, +1} labels; the public contract is {0, 1}, so remap here.
        IgniteBiFunction<K, V, Double> patchedLbExtractor = (k, v) -> {
            final Double lb = lbExtractor.apply(k, v);
            if (lb == 0.0)
                return -1.0;
            else
                return lb;
        };

        PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
            featureExtractor,
            patchedLbExtractor
        );

        Vector weights;

        try (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = datasetBuilder.build(
            envBuilder,
            (env, upstream, upstreamSize) -> new EmptyContext(),
            partDataBuilder
        )) {
            if (mdl == null) {
                // All partitions share the same column count; the reducer just picks
                // whichever non-null value arrives.
                final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
                    if (a == null)
                        return b == null ? 0 : b;
                    if (b == null)
                        return a;
                    return b;
                });

                // Slot 0 of the state vector holds the intercept; slots 1..cols the weights.
                final int weightVectorSizeWithIntercept = cols + 1;
                weights = initializeWeightsWithZeros(weightVectorSizeWithIntercept);
            } else
                weights = getStateVector(mdl);

            for (int i = 0; i < this.getAmountOfIterations(); i++) {
                Vector deltaWeights = calculateUpdates(weights, dataset);
                if (deltaWeights == null)
                    return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

                weights = weights.plus(deltaWeights); // creates new vector
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        // Decode the state vector: slot 0 is the intercept, the rest are weights.
        return new SVMLinearClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0));
    }

    /** {@inheritDoc} */
    @Override protected boolean checkState(SVMLinearClassificationModel mdl) {
        // Any previously trained SVM model can be warm-started.
        return true;
    }

    /**
     * Packs a model into the internal state-vector layout: slot 0 is the intercept,
     * slots 1..n hold the weights.
     *
     * @param mdl Model.
     * @return vector of model weights with intercept.
     */
    private Vector getStateVector(SVMLinearClassificationModel mdl) {
        double intercept = mdl.intercept();
        Vector weights = mdl.weights();

        int stateVectorSize = weights.size() + 1;
        Vector res = weights.isDense() ?
            new DenseVector(stateVectorSize) :
            new SparseVector(stateVectorSize, StorageConstants.RANDOM_ACCESS_MODE);

        res.set(0, intercept);
        // FIX: weight i belongs at state-vector position i + 1 (slot 0 is the intercept).
        // Writing to ith.index() directly clobbered the intercept and shifted every
        // weight one slot left, corrupting warm-started training.
        weights.nonZeroes().forEach(ith -> res.set(ith.index() + 1, ith.get()));
        return res;
    }

    /** Creates a zero-filled dense vector of the given size. */
    @NotNull private Vector initializeWeightsWithZeros(int vectorSize) {
        return new DenseVector(vectorSize);
    }

    /**
     * Runs the local SDCA passes on every partition and reduces the per-partition
     * weight deltas. Returns {@code null} when no partition contributed any data.
     */
    private Vector calculateUpdates(Vector weights,
        Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
        return dataset.compute(data -> {
            final int amountOfObservation = data.rowSize();

            // FIX: an empty partition cannot be sampled (Random.nextInt(0) throws
            // IllegalArgumentException). Returning null lets the reducer skip this
            // partition while keeping the empty-dataset detection in updateModel intact.
            if (amountOfObservation == 0)
                return null;

            Vector copiedWeights = weights.copy();
            Vector deltaWeights = initializeWeightsWithZeros(weights.size());

            Vector tmpAlphas = initializeWeightsWithZeros(amountOfObservation);
            Vector deltaAlphas = initializeWeightsWithZeros(amountOfObservation);

            // NOTE(review): a fresh Random(seed) per compute call replays the same
            // sample sequence on every outer iteration — confirm this is intended
            // (it does make the training fully deterministic for a fixed seed).
            Random random = new Random(seed);

            for (int i = 0; i < this.getAmountOfLocIterations(); i++) {
                int randomIdx = random.nextInt(amountOfObservation);

                Deltas deltas = getDeltas(data, copiedWeights, amountOfObservation, tmpAlphas, randomIdx);

                copiedWeights = copiedWeights.plus(deltas.deltaWeights); // creates new vector
                deltaWeights = deltaWeights.plus(deltas.deltaWeights); // creates new vector

                tmpAlphas.set(randomIdx, tmpAlphas.get(randomIdx) + deltas.deltaAlpha);
                deltaAlphas.set(randomIdx, deltaAlphas.get(randomIdx) + deltas.deltaAlpha);
            }
            return deltaWeights;
        }, (a, b) -> {
            // Null means "partition contributed nothing"; only all-null yields null overall.
            if (a == null)
                return b;
            if (b == null)
                return a;
            return a.plus(b);
        });
    }

    /** Computes the dual-coordinate update for one randomly chosen observation. */
    private Deltas getDeltas(LabeledVectorSet data, Vector copiedWeights, int amountOfObservation, Vector tmpAlphas,
        int randomIdx) {
        LabeledVector row = (LabeledVector)data.getRow(randomIdx);
        Double lb = (Double)row.label();
        Vector v = makeVectorWithInterceptElement(row);

        double alpha = tmpAlphas.get(randomIdx);

        return maximize(lb, v, alpha, copiedWeights, amountOfObservation);
    }

    /** Prepends the constant intercept feature (1.0) to the observation's feature vector. */
    private Vector makeVectorWithInterceptElement(LabeledVector row) {
        Vector vec = row.features().like(row.features().size() + 1);

        vec.set(0, 1); // set intercept element

        for (int j = 0; j < row.features().size(); j++)
            vec.set(j + 1, row.features().get(j));

        return vec;
    }

    /** Performs one step of the dual ascent for a single observation. */
    private Deltas maximize(double lb, Vector v, double alpha, Vector weights, int amountOfObservation) {
        double gradient = calcGradient(lb, v, weights, amountOfObservation);
        double prjGrad = calculateProjectionGradient(alpha, gradient);

        return calcDeltas(lb, v, alpha, prjGrad, weights.size(), amountOfObservation);
    }

    /** Converts the projected gradient into alpha/weight deltas (no-op when the gradient is zero). */
    private Deltas calcDeltas(double lb, Vector v, double alpha, double gradient, int vectorSize,
        int amountOfObservation) {
        if (gradient != 0.0) {

            double qii = v.dot(v);
            double newAlpha = calcNewAlpha(alpha, gradient, qii);

            Vector deltaWeights = v.times(lb * (newAlpha - alpha) / (this.getLambda() * amountOfObservation));

            return new Deltas(newAlpha - alpha, deltaWeights);
        }
        else
            return new Deltas(0.0, initializeWeightsWithZeros(vectorSize));
    }

    /** Newton step on alpha, clipped to the box constraint [0, 1]. */
    private double calcNewAlpha(double alpha, double gradient, double qii) {
        if (qii != 0.0)
            return Math.min(Math.max(alpha - (gradient / qii), 0.0), 1.0);
        else
            return 1.0;
    }

    /** Hinge-loss dual gradient for one observation. */
    private double calcGradient(double lb, Vector v, Vector weights, int amountOfObservation) {
        double dotProduct = v.dot(weights);
        return (lb * dotProduct - 1.0) * (this.getLambda() * amountOfObservation);
    }

    /** Projects the gradient onto the feasible region of the box-constrained alpha. */
    private double calculateProjectionGradient(double alpha, double gradient) {
        if (alpha <= 0.0)
            return Math.min(gradient, 0.0);

        else if (alpha >= 1.0)
            return Math.max(gradient, 0.0);

        else
            return gradient;
    }

    /**
     * Set up the regularization parameter.
     *
     * @param lambda The regularization parameter. Should be more than 0.0.
     * @return Trainer with new lambda parameter value.
     */
    public SVMLinearClassificationTrainer withLambda(double lambda) {
        assert lambda > 0.0;
        this.lambda = lambda;
        return this;
    }

    /**
     * Get the regularization lambda.
     *
     * @return The property value.
     */
    public double getLambda() {
        return lambda;
    }

    /**
     * Get the amount of outer iterations of SCDA algorithm.
     *
     * @return The property value.
     */
    public int getAmountOfIterations() {
        return amountOfIterations;
    }

    /**
     * Set up the amount of outer iterations of SCDA algorithm.
     *
     * @param amountOfIterations The parameter value.
     * @return Trainer with new amountOfIterations parameter value.
     */
    public SVMLinearClassificationTrainer withAmountOfIterations(int amountOfIterations) {
        this.amountOfIterations = amountOfIterations;
        return this;
    }

    /**
     * Get the amount of local iterations of SCDA algorithm.
     *
     * @return The property value.
     */
    public int getAmountOfLocIterations() {
        return amountOfLocIterations;
    }

    /**
     * Set up the amount of local iterations of SCDA algorithm.
     *
     * @param amountOfLocIterations The parameter value.
     * @return Trainer with new amountOfLocIterations parameter value.
     */
    public SVMLinearClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) {
        this.amountOfLocIterations = amountOfLocIterations;
        return this;
    }

    /**
     * Get the seed number.
     *
     * @return The property value.
     */
    public long getSeed() {
        return seed;
    }

    /**
     * Set up the seed.
     *
     * @param seed The parameter value.
     * @return Model with new seed parameter value.
     */
    public SVMLinearClassificationTrainer withSeed(long seed) {
        this.seed = seed;
        return this;
    }
}

/** This is a helper class to handle pair results which are returned from the calculation method. */
class Deltas {
    /** Change of the dual variable alpha for the sampled observation. */
    public final double deltaAlpha;

    /** Corresponding change of the (intercept-augmented) weight vector. */
    public final Vector deltaWeights;

    /** */
    public Deltas(double deltaAlpha, Vector deltaWeights) {
        this.deltaAlpha = deltaAlpha;
        this.deltaWeights = deltaWeights;
    }
}
package org.apache.ignite.ml.svm;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.TreeMap;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.math.primitives.vector.Vector;

/** Base class for multi-classification model for set of SVM classifiers. */
public class SVMLinearMultiClassClassificationModel implements Model<Vector, Double>, Exportable<SVMLinearMultiClassClassificationModel>, Serializable {
    /** */
    private static final long serialVersionUID = -667986511191350227L;

    /** List of models associated with each class. */
    private Map<Double, SVMLinearBinaryClassificationModel> models;

    /** */
    public SVMLinearMultiClassClassificationModel() {
        this.models = new HashMap<>();
    }

    /** {@inheritDoc} */
    @Override public Double apply(Vector input) {
        // One-vs-rest voting: the class whose hyperplane yields the largest margin wins.
        // NOTE(review): assumes at least one classifier was added via add(); calling
        // apply() on an empty model throws NPE on lastEntry() — confirm callers guarantee this.
        TreeMap<Double, Double> maxMargins = new TreeMap<>();

        models.forEach((k, v) -> maxMargins.put(input.dot(v.weights()) + v.intercept(), k));

        return maxMargins.lastEntry().getValue();
    }

    /** {@inheritDoc} */
    @Override public <P> void saveModel(Exporter<SVMLinearMultiClassClassificationModel, P> exporter, P path) {
        exporter.save(this, path);
    }

    /** {@inheritDoc} */
    @Override public boolean equals(Object o) {
        if (this == o)
            return true;

        if (o == null || getClass() != o.getClass())
            return false;

        SVMLinearMultiClassClassificationModel mdl = (SVMLinearMultiClassClassificationModel)o;

        return Objects.equals(models, mdl.models);
    }

    /** {@inheritDoc} */
    @Override public int hashCode() {
        return Objects.hash(models);
    }

    /** {@inheritDoc} */
    @Override public String toString() {
        StringBuilder wholeStr = new StringBuilder();

        models.forEach((clsLb, mdl) ->
            wholeStr
                .append("The class with label ")
                .append(clsLb)
                .append(" has classifier: ")
                .append(mdl.toString())
                .append(System.lineSeparator())
        );

        return wholeStr.toString();
    }

    /** {@inheritDoc} */
    @Override public String toString(boolean pretty) {
        return toString();
    }

    /**
     * Adds a specific SVM binary classifier to the bunch of same classifiers.
     *
     * @param clsLb The class label for the added model.
     * @param mdl The model.
     */
    public void add(double clsLb, SVMLinearBinaryClassificationModel mdl) {
        models.put(clsLb, mdl);
    }

    /**
     * @param clsLb Class label.
     * @return model trained for target class if it exists.
     */
    public Optional<SVMLinearBinaryClassificationModel> getModelForClass(double clsLb) {
        // FIX: Optional.of(null) throws NPE when no model was trained for this label,
        // contradicting the "if it exists" contract and breaking the
        // getModelForClass(...).map(...).orElseGet(...) chain in the multiclass trainer.
        return Optional.ofNullable(models.get(clsLb));
    }
}
package org.apache.ignite.ml.svm;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/**
 * Base class for a soft-margin SVM linear multiclass-classification trainer based on the communication-efficient
 * distributed dual coordinate ascent algorithm (CoCoA) with hinge-loss function.
 *
 * All common parameters are shared with bunch of binary classification trainers.
 */
public class SVMLinearMultiClassClassificationTrainer
    extends SingleLabelDatasetTrainer<SVMLinearMultiClassClassificationModel> {
    /** Amount of outer SDCA algorithm iterations. */
    private int amountOfIterations = 20;

    /** Amount of local SDCA algorithm iterations. */
    private int amountOfLocIterations = 50;

    /** Regularization parameter. */
    private double lambda = 0.2;

    /** The seed number. */
    private long seed = 1234L;

    /**
     * Trains model based on the specified data.
     *
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @return Model.
     */
    @Override public <K, V> SVMLinearMultiClassClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor) {
        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
    }

    /** {@inheritDoc} */
    @Override public <K, V> SVMLinearMultiClassClassificationModel updateModel(
        SVMLinearMultiClassClassificationModel mdl,
        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor) {

        List<Double> classLabels = extractClassLabels(datasetBuilder, lbExtractor);
        if (classLabels.isEmpty())
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

        SVMLinearMultiClassClassificationModel multiClsMdl = new SVMLinearMultiClassClassificationModel();

        // One-vs-rest: train (or warm-start) one binary classifier per distinct label.
        for (Double clsLb : classLabels) {
            SVMLinearBinaryClassificationTrainer binTrainer = new SVMLinearBinaryClassificationTrainer()
                .withAmountOfIterations(this.getAmountOfIterations())
                .withAmountOfLocIterations(this.getAmountOfLocIterations())
                .withLambda(this.getLambda())
                .withSeed(this.seed);

            // Binary relabeling: 1.0 for the current class, 0.0 for every other label.
            IgniteBiFunction<K, V, Double> binLbExtractor = (k, v) -> {
                Double lb = lbExtractor.apply(k, v);

                return lb.equals(clsLb) ? 1.0 : 0.0;
            };

            SVMLinearBinaryClassificationModel binMdl = mdl == null
                ? learnNewModel(binTrainer, datasetBuilder, featureExtractor, binLbExtractor)
                : updateModel(mdl, clsLb, binTrainer, datasetBuilder, featureExtractor, binLbExtractor);

            multiClsMdl.add(clsLb, binMdl);
        }

        return multiClsMdl;
    }

    /** {@inheritDoc} */
    @Override protected boolean checkState(SVMLinearMultiClassClassificationModel mdl) {
        // Any previously learned multiclass model is a valid warm-start candidate.
        return true;
    }

    /**
     * Trains model based on the specified data.
     *
     * @param svmTrainer Prepared SVM trainer.
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     */
    private <K, V> SVMLinearBinaryClassificationModel learnNewModel(SVMLinearBinaryClassificationTrainer svmTrainer,
        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, Double> lbExtractor) {

        return svmTrainer.fit(datasetBuilder, featureExtractor, lbExtractor);
    }

    /**
     * Updates already learned model or fit new model if there is no model for current class label.
     *
     * @param multiClsMdl Learning multi-class model.
     * @param clsLb Current class label.
     * @param svmTrainer Prepared SVM trainer.
     * @param datasetBuilder Dataset builder.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     */
    private <K, V> SVMLinearBinaryClassificationModel updateModel(SVMLinearMultiClassClassificationModel multiClsMdl,
        Double clsLb, SVMLinearBinaryClassificationTrainer svmTrainer, DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {

        return multiClsMdl.getModelForClass(clsLb)
            .map(learnedModel -> svmTrainer.update(learnedModel, datasetBuilder, featureExtractor, lbExtractor))
            .orElseGet(() -> svmTrainer.fit(datasetBuilder, featureExtractor, lbExtractor));
    }

    /** Iterates among dataset and collects class labels. */
    private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder,
        IgniteBiFunction<K, V, Double> lbExtractor) {
        assert datasetBuilder != null;

        PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor);

        List<Double> labels = new ArrayList<>();

        try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build(
            envBuilder,
            (env, upstream, upstreamSize) -> new EmptyContext(),
            partDataBuilder
        )) {
            // Map step: distinct labels per partition; reduce step: set union.
            final Set<Double> distinctLbs = dataset.compute(data -> {
                final Set<Double> partLbs = new HashSet<>();

                for (double lb : data.getY())
                    partLbs.add(lb);

                return partLbs;
            }, (a, b) -> {
                if (a == null)
                    return b == null ? new HashSet<>() : b;
                if (b == null)
                    return a;

                Set<Double> merged = new HashSet<>(a);
                merged.addAll(b);
                return merged;
            });

            if (distinctLbs != null)
                labels.addAll(distinctLbs);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return labels;
    }

    /**
     * Set up the regularization parameter.
     *
     * @param lambda The regularization parameter. Should be more than 0.0.
     * @return Trainer with new lambda parameter value.
     */
    public SVMLinearMultiClassClassificationTrainer withLambda(double lambda) {
        assert lambda > 0.0;
        this.lambda = lambda;
        return this;
    }

    /**
     * Get the regularization lambda.
     *
     * @return The property value.
     */
    public double getLambda() {
        return lambda;
    }

    /**
     * Gets the amount of outer iterations of SCDA algorithm.
     *
     * @return The property value.
     */
    public int getAmountOfIterations() {
        return amountOfIterations;
    }

    /**
     * Set up the amount of outer iterations of SCDA algorithm.
     *
     * @param amountOfIterations The parameter value.
     * @return Trainer with new amountOfIterations parameter value.
     */
    public SVMLinearMultiClassClassificationTrainer withAmountOfIterations(int amountOfIterations) {
        this.amountOfIterations = amountOfIterations;
        return this;
    }

    /**
     * Gets the amount of local iterations of SCDA algorithm.
     *
     * @return The property value.
     */
    public int getAmountOfLocIterations() {
        return amountOfLocIterations;
    }

    /**
     * Set up the amount of local iterations of SCDA algorithm.
     *
     * @param amountOfLocIterations The parameter value.
     * @return Trainer with new amountOfLocIterations parameter value.
     */
    public SVMLinearMultiClassClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) {
        this.amountOfLocIterations = amountOfLocIterations;
        return this;
    }

    /**
     * Gets the seed number.
     *
     * @return The property value.
     */
    public long getSeed() {
        return seed;
    }

    /**
     * Set up the seed.
     *
     * @param seed The parameter value.
     * @return Model with new seed parameter value.
     */
    public SVMLinearMultiClassClassificationTrainer withSeed(long seed) {
        this.seed = seed;
        return this;
    }
}
org.apache.ignite.ml.regressions.linear.LinearRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; import org.apache.ignite.ml.structures.Dataset; import org.apache.ignite.ml.structures.DatasetRow; import org.apache.ignite.ml.structures.FeatureMetadata; import org.apache.ignite.ml.structures.LabeledVector; import org.apache.ignite.ml.structures.LabeledVectorSet; -import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel; -import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel; +import org.apache.ignite.ml.svm.SVMLinearClassificationModel; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -92,17 +89,7 @@ public class CollectionsTest { test(new KNNClassificationModel(null).withK(1), new KNNClassificationModel(null).withK(2)); - LogRegressionMultiClassModel mdl = new LogRegressionMultiClassModel(); - mdl.add(1, new LogisticRegressionModel(new DenseVector(), 1.0)); - test(mdl, new LogRegressionMultiClassModel()); - - test(new LinearRegressionModel(null, 1.0), new LinearRegressionModel(null, 0.5)); - - SVMLinearMultiClassClassificationModel mdl1 = new SVMLinearMultiClassClassificationModel(); - mdl1.add(1, new SVMLinearBinaryClassificationModel(new DenseVector(), 1.0)); - test(mdl1, new SVMLinearMultiClassClassificationModel()); - - test(new SVMLinearBinaryClassificationModel(null, 1.0), new SVMLinearBinaryClassificationModel(null, 0.5)); + test(new SVMLinearClassificationModel(null, 1.0), new SVMLinearClassificationModel(null, 0.5)); test(new ANNClassificationModel(new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()), new ANNClassificationModel(new LabeledVectorSet<>(1, 1, true), new ANNClassificationTrainer.CentroidStat())); 
http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java index ca3f0b5..c5b2ffe 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java @@ -43,12 +43,10 @@ import org.apache.ignite.ml.math.distances.ManhattanDistance; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; import org.apache.ignite.ml.structures.LabeledVector; import org.apache.ignite.ml.structures.LabeledVectorSet; -import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel; -import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel; +import org.apache.ignite.ml.svm.SVMLinearClassificationModel; import org.junit.Assert; import org.junit.Test; @@ -99,36 +97,11 @@ public class LocalModelsTest { @Test public void importExportSVMBinaryClassificationModelTest() throws IOException { executeModelTest(mdlFilePath -> { - SVMLinearBinaryClassificationModel mdl = new SVMLinearBinaryClassificationModel(new DenseVector(new double[]{1, 2}), 3); - Exporter<SVMLinearBinaryClassificationModel, String> exporter = new FileExporter<>(); + SVMLinearClassificationModel mdl = new SVMLinearClassificationModel(new DenseVector(new double[] {1, 2}), 3); + Exporter<SVMLinearClassificationModel, String> 
exporter = new FileExporter<>(); mdl.saveModel(exporter, mdlFilePath); - SVMLinearBinaryClassificationModel load = exporter.load(mdlFilePath); - - Assert.assertNotNull(load); - Assert.assertEquals("", mdl, load); - - return null; - }); - } - - /** */ - @Test - public void importExportSVMMultiClassClassificationModelTest() throws IOException { - executeModelTest(mdlFilePath -> { - SVMLinearBinaryClassificationModel binaryMdl1 = new SVMLinearBinaryClassificationModel(new DenseVector(new double[]{1, 2}), 3); - SVMLinearBinaryClassificationModel binaryMdl2 = new SVMLinearBinaryClassificationModel(new DenseVector(new double[]{2, 3}), 4); - SVMLinearBinaryClassificationModel binaryMdl3 = new SVMLinearBinaryClassificationModel(new DenseVector(new double[]{3, 4}), 5); - - SVMLinearMultiClassClassificationModel mdl = new SVMLinearMultiClassClassificationModel(); - mdl.add(1, binaryMdl1); - mdl.add(2, binaryMdl2); - mdl.add(3, binaryMdl3); - - Exporter<SVMLinearMultiClassClassificationModel, String> exporter = new FileExporter<>(); - mdl.saveModel(exporter, mdlFilePath); - - SVMLinearMultiClassClassificationModel load = exporter.load(mdlFilePath); + SVMLinearClassificationModel load = exporter.load(mdlFilePath); Assert.assertNotNull(load); Assert.assertEquals("", mdl, load); @@ -155,23 +128,6 @@ public class LocalModelsTest { } /** */ - @Test - public void importExportLogRegressionMultiClassModelTest() throws IOException { - executeModelTest(mdlFilePath -> { - LogRegressionMultiClassModel mdl = new LogRegressionMultiClassModel(); - Exporter<LogRegressionMultiClassModel, String> exporter = new FileExporter<>(); - mdl.saveModel(exporter, mdlFilePath); - - LogRegressionMultiClassModel load = exporter.load(mdlFilePath); - - Assert.assertNotNull(load); - Assert.assertEquals("", mdl, load); - - return null; - }); - } - - /** */ private void executeModelTest(Function<String, Void> code) throws IOException { Path mdlPath = Files.createTempFile(null, null); 
http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java index 9842d92..61f9fc4 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java @@ -28,8 +28,8 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; import org.junit.Assert; import org.junit.Test; http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java index e59d515..8445900 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineMdlTest.java @@ -20,7 +20,7 @@ package org.apache.ignite.ml.pipeline; import org.apache.ignite.ml.TestUtils; import 
org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; import org.junit.Test; /** http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java index d517ce6..fec6220 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java @@ -29,7 +29,7 @@ import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpda import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; import org.junit.Test; /** http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java index 021b567..2fa69ef 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java +++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java @@ -20,7 +20,6 @@ package org.apache.ignite.ml.regressions; import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainerTest; import org.apache.ignite.ml.regressions.linear.LinearRegressionModelTest; import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainerTest; -import org.apache.ignite.ml.regressions.logistic.LogRegMultiClassTrainerTest; import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModelTest; import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainerTest; import org.junit.runner.RunWith; @@ -35,8 +34,7 @@ import org.junit.runners.Suite; LinearRegressionLSQRTrainerTest.class, LinearRegressionSGDTrainerTest.class, LogisticRegressionModelTest.class, - LogisticRegressionSGDTrainerTest.class, - LogRegMultiClassTrainerTest.class + LogisticRegressionSGDTrainerTest.class }) public class RegressionsTestSuite { // No-op. http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java index 66871b0..36d0fc7 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java @@ -21,8 +21,6 @@ import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.math.exceptions.CardinalityException; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import 
org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; import org.junit.Test; import static org.junit.Assert.assertTrue; @@ -40,9 +38,9 @@ public class LinearRegressionModelTest { Vector weights = new DenseVector(new double[]{2.0, 3.0}); LinearRegressionModel mdl = new LinearRegressionModel(weights, 1.0); - assertTrue(mdl.toString().length() > 0); - assertTrue(mdl.toString(true).length() > 0); - assertTrue(mdl.toString(false).length() > 0); + assertTrue(!mdl.toString().isEmpty()); + assertTrue(!mdl.toString(true).isEmpty()); + assertTrue(!mdl.toString(false).isEmpty()); Vector observation = new DenseVector(new double[]{1.0, 1.0}); TestUtils.assertEquals(1.0 + 2.0 * 1.0 + 3.0 * 1.0, mdl.apply(observation), PRECISION); @@ -61,21 +59,6 @@ public class LinearRegressionModelTest { } /** */ - @Test - public void testPredictWithMultiClasses() { - Vector weights1 = new DenseVector(new double[]{10.0, 0.0}); - Vector weights2 = new DenseVector(new double[]{0.0, 10.0}); - Vector weights3 = new DenseVector(new double[]{-1.0, -1.0}); - LogRegressionMultiClassModel mdl = new LogRegressionMultiClassModel(); - mdl.add(1, new LogisticRegressionModel(weights1, 0.0).withRawLabels(true)); - mdl.add(2, new LogisticRegressionModel(weights2, 0.0).withRawLabels(true)); - mdl.add(2, new LogisticRegressionModel(weights3, 0.0).withRawLabels(true)); - - Vector observation = new DenseVector(new double[]{1.0, 1.0}); - TestUtils.assertEquals( 1.0, mdl.apply(observation), PRECISION); - } - - /** */ @Test(expected = CardinalityException.class) public void testPredictOnAnObservationWithWrongCardinality() { Vector weights = new DenseVector(new double[]{2.0, 3.0}); http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java ---------------------------------------------------------------------- diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java deleted file mode 100644 index c99bf02..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.regressions.logistic; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.common.TrainerTest; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.math.primitives.vector.VectorUtils; -import org.apache.ignite.ml.nn.UpdatesStrategy; -import org.apache.ignite.ml.optimization.SmoothParametrized; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; -import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassTrainer; -import org.junit.Assert; -import org.junit.Test; - -/** - * Tests for {@link LogRegressionMultiClassTrainer}. - */ -public class LogRegMultiClassTrainerTest extends TrainerTest { - /** - * Test trainer on 4 sets grouped around of square vertices. 
- */ - @Test - public void testTrainWithTheLinearlySeparableCase() { - Map<Integer, double[]> cacheMock = new HashMap<>(); - - for (int i = 0; i < fourSetsInSquareVertices.length; i++) - cacheMock.put(i, fourSetsInSquareVertices[i]); - - final UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> stgy = new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - ); - - LogRegressionMultiClassTrainer<?> trainer = new LogRegressionMultiClassTrainer<>() - .withUpdatesStgy(stgy) - .withAmountOfIterations(1000) - .withAmountOfLocIterations(10) - .withBatchSize(100) - .withSeed(123L); - - Assert.assertEquals(trainer.getAmountOfIterations(), 1000); - Assert.assertEquals(trainer.getAmountOfLocIterations(), 10); - Assert.assertEquals(trainer.getBatchSize(), 100, PRECISION); - Assert.assertEquals(trainer.seed(), 123L); - Assert.assertEquals(trainer.getUpdatesStgy(), stgy); - - LogRegressionMultiClassModel mdl = trainer.fit( - cacheMock, - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - Assert.assertTrue(mdl.toString().length() > 0); - Assert.assertTrue(mdl.toString(true).length() > 0); - Assert.assertTrue(mdl.toString(false).length() > 0); - - TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 10)), PRECISION); - TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(-10, 10)), PRECISION); - TestUtils.assertEquals(2, mdl.apply(VectorUtils.of(-10, -10)), PRECISION); - TestUtils.assertEquals(3, mdl.apply(VectorUtils.of(10, -10)), PRECISION); - } - - /** */ - @Test - public void testUpdate() { - Map<Integer, double[]> cacheMock = new HashMap<>(); - - for (int i = 0; i < fourSetsInSquareVertices.length; i++) - cacheMock.put(i, fourSetsInSquareVertices[i]); - - LogRegressionMultiClassTrainer<?> trainer = new LogRegressionMultiClassTrainer<>() - .withUpdatesStgy(new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(0.2), - SimpleGDParameterUpdate::sumLocal, 
- SimpleGDParameterUpdate::avg - )) - .withAmountOfIterations(1000) - .withAmountOfLocIterations(10) - .withBatchSize(100) - .withSeed(123L); - - LogRegressionMultiClassModel originalMdl = trainer.fit( - cacheMock, - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - LogRegressionMultiClassModel updatedOnSameDS = trainer.update( - originalMdl, - cacheMock, - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - LogRegressionMultiClassModel updatedOnEmptyDS = trainer.update( - originalMdl, - new HashMap<Integer, double[]>(), - parts, - (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), - (k, v) -> v[0] - ); - - List<Vector> vectors = Arrays.asList( - VectorUtils.of(10, 10), - VectorUtils.of(-10, 10), - VectorUtils.of(-10, -10), - VectorUtils.of(10, -10) - ); - - for (Vector vec : vectors) { - TestUtils.assertEquals(originalMdl.apply(vec), updatedOnSameDS.apply(vec), PRECISION); - TestUtils.assertEquals(originalMdl.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION); - } - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/098caf44/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java index e8aaacd..4fae638 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java @@ -21,7 +21,6 @@ import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.math.exceptions.CardinalityException; import org.apache.ignite.ml.math.primitives.vector.Vector; import 
org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; import org.junit.Test; import static org.junit.Assert.assertEquals;
