http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java deleted file mode 100644 index aafeae8..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java +++ /dev/null @@ -1,257 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.ignite.ml.regressions; - -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.decompositions.QRDSolver; -import org.apache.ignite.ml.math.decompositions.QRDecomposition; -import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException; -import org.apache.ignite.ml.math.exceptions.SingularMatrixException; -import org.apache.ignite.ml.math.functions.Functions; - -/** - * This class is based on the corresponding class from Apache Common Math lib. - * <p>Implements ordinary least squares (OLS) to estimate the parameters of a - * multiple linear regression model.</p> - * - * <p>The regression coefficients, <code>b</code>, satisfy the normal equations: - * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre></p> - * - * <p>To solve the normal equations, this implementation uses QR decomposition - * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the - * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i> - * has rows corresponding to sample observations and columns corresponding to independent - * variables. When the model is estimated using an intercept term (i.e. when - * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code> - * matrix includes an initial column identically equal to 1. We solve the normal equations - * as follows: - * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y - * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y - * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y - * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y - * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y - * R b = Q<sup>T</sup> y </code></pre></p> - * - * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p> - */ -public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression { - /** Cached QR decomposition of X matrix */ - private QRDSolver solver = null; - - /** Singularity threshold for QR decomposition */ - private final double threshold; - - /** - * Create an empty OLSMultipleLinearRegression instance. - */ - public OLSMultipleLinearRegression() { - this(0d); - } - - /** - * Create an empty OLSMultipleLinearRegression instance, using the given - * singularity threshold for the QR decomposition. - * - * @param threshold the singularity threshold - */ - public OLSMultipleLinearRegression(final double threshold) { - this.threshold = threshold; - } - - /** - * Loads model x and y sample data, overriding any previous sample. - * - * Computes and caches QR decomposition of the X matrix. - * - * @param y the {@code n}-sized vector representing the y sample - * @param x the {@code n x k} matrix representing the x sample - * @throws MathIllegalArgumentException if the x and y array data are not compatible for the regression - */ - public void newSampleData(Vector y, Matrix x) throws MathIllegalArgumentException { - validateSampleData(x, y); - newYSampleData(y); - newXSampleData(x); - } - - /** - * {@inheritDoc} - * <p>This implementation computes and caches the QR decomposition of the X matrix.</p> - */ - @Override public void newSampleData(double[] data, int nobs, int nvars, Matrix like) { - super.newSampleData(data, nobs, nvars, like); - QRDecomposition qr = new QRDecomposition(getX(), threshold); - solver = new QRDSolver(qr.getQ(), qr.getR()); - } - - /** - * <p>Compute the "hat" matrix. - * </p> - * <p>The hat matrix is defined in terms of the design matrix X - * by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup> - * </p> - * <p>The implementation here uses the QR decomposition to compute the - * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the - * p-dimensional identity matrix augmented by 0's. This computational - * formula is from "The Hat Matrix in Regression and ANOVA", - * David C. Hoaglin and Roy E. Welsch, - * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22. - * </p> - * <p>Data for the model must have been successfully loaded using one of - * the {@code newSampleData} methods before invoking this method; otherwise - * a {@code NullPointerException} will be thrown.</p> - * - * @return the hat matrix - * @throws NullPointerException unless method {@code newSampleData} has been called beforehand. - */ - public Matrix calculateHat() { - return solver.calculateHat(); - } - - /** - * <p>Returns the sum of squared deviations of Y from its mean.</p> - * - * <p>If the model has no intercept term, <code>0</code> is used for the - * mean of Y - i.e., what is returned is the sum of the squared Y values.</p> - * - * <p>The value returned by this method is the SSTO value used in - * the {@link #calculateRSquared() R-squared} computation.</p> - * - * @return SSTO - the total sum of squares - * @throws NullPointerException if the sample has not been set - * @see #isNoIntercept() - */ - public double calculateTotalSumOfSquares() { - if (isNoIntercept()) - return getY().foldMap(Functions.PLUS, Functions.SQUARE, 0.0); - else { - // TODO: IGNITE-5826, think about incremental update formula. - final double mean = getY().sum() / getY().size(); - return getY().foldMap(Functions.PLUS, x -> (mean - x) * (mean - x), 0.0); - } - } - - /** - * Returns the sum of squared residuals. - * - * @return residual sum of squares - * @throws SingularMatrixException if the design matrix is singular - * @throws NullPointerException if the data for the model have not been loaded - */ - public double calculateResidualSumOfSquares() { - final Vector residuals = calculateResiduals(); - // No advertised DME, args are valid - return residuals.dot(residuals); - } - - /** - * Returns the R-Squared statistic, defined by the formula <pre> - * R<sup>2</sup> = 1 - SSR / SSTO - * </pre> - * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals} - * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares} - * - * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p> - * - * @return R-square statistic - * @throws NullPointerException if the sample has not been set - * @throws SingularMatrixException if the design matrix is singular - */ - public double calculateRSquared() { - return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares(); - } - - /** - * <p>Returns the adjusted R-squared statistic, defined by the formula <pre> - * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)] - * </pre> - * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}, - * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number - * of observations and p is the number of parameters estimated (including the intercept).</p> - * - * <p>If the regression is estimated without an intercept term, what is returned is <pre> - * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code> - * </pre></p> - * - * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p> - * - * @return adjusted R-Squared statistic - * @throws NullPointerException if the sample has not been set - * @throws SingularMatrixException if the design matrix is singular - * @see #isNoIntercept() - */ - public double calculateAdjustedRSquared() { - final double n = getX().rowSize(); - if (isNoIntercept()) - return 1 - (1 - calculateRSquared()) * (n / (n - getX().columnSize())); - else - return 1 - (calculateResidualSumOfSquares() * (n - 1)) / - (calculateTotalSumOfSquares() * (n - getX().columnSize())); - } - - /** - * {@inheritDoc} - * <p>This implementation computes and caches the QR decomposition of the X matrix - * once it is successfully loaded.</p> - */ - @Override protected void newXSampleData(Matrix x) { - super.newXSampleData(x); - QRDecomposition qr = new QRDecomposition(getX()); - solver = new QRDSolver(qr.getQ(), qr.getR()); - } - - /** - * Calculates the regression coefficients using OLS. - * - * <p>Data for the model must have been successfully loaded using one of - * the {@code newSampleData} methods before invoking this method; otherwise - * a {@code NullPointerException} will be thrown.</p> - * - * @return beta - * @throws SingularMatrixException if the design matrix is singular - * @throws NullPointerException if the data for the model have not been loaded - */ - @Override protected Vector calculateBeta() { - return solver.solve(getY()); - } - - /** - * <p>Calculates the variance-covariance matrix of the regression parameters. - * </p> - * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup> - * </p> - * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup> - * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of - * R included, where p = the length of the beta vector.</p> - * - * <p>Data for the model must have been successfully loaded using one of - * the {@code newSampleData} methods before invoking this method; otherwise - * a {@code NullPointerException} will be thrown.</p> - * - * @return The beta variance-covariance matrix - * @throws SingularMatrixException if the design matrix is singular - * @throws NullPointerException if the data for the model have not been loaded - */ - @Override protected Matrix calculateBetaVariance() { - return solver.calculateBetaVariance(getX().columnSize()); - } - - /** */ - QRDSolver solver() { - return solver; - } -}
http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java deleted file mode 100644 index b95cbf3..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions; - -import org.apache.ignite.ml.Exportable; -import org.apache.ignite.ml.Exporter; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.decompositions.QRDSolver; -import org.apache.ignite.ml.math.decompositions.QRDecomposition; - -/** - * Model for linear regression. - */ -public class OLSMultipleLinearRegressionModel implements Model<Vector, Vector>, - Exportable<OLSMultipleLinearRegressionModelFormat> { - /** */ - private final Matrix xMatrix; - /** */ - private final QRDSolver solver; - - /** - * Construct linear regression model. - * - * @param xMatrix See {@link QRDecomposition#QRDecomposition(Matrix)}. - * @param solver Linear regression solver object. - */ - public OLSMultipleLinearRegressionModel(Matrix xMatrix, QRDSolver solver) { - this.xMatrix = xMatrix; - this.solver = solver; - } - - /** {@inheritDoc} */ - @Override public Vector apply(Vector val) { - return xMatrix.times(solver.solve(val)); - } - - /** {@inheritDoc} */ - @Override public <P> void saveModel(Exporter<OLSMultipleLinearRegressionModelFormat, P> exporter, P path) { - exporter.save(new OLSMultipleLinearRegressionModelFormat(xMatrix, solver), path); - } - - /** {@inheritDoc} */ - @Override public boolean equals(Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; - - OLSMultipleLinearRegressionModel mdl = (OLSMultipleLinearRegressionModel)o; - - return xMatrix.equals(mdl.xMatrix) && solver.equals(mdl.solver); - } - - /** {@inheritDoc} */ - @Override public int hashCode() { - int res = xMatrix.hashCode(); - res = 31 * res + solver.hashCode(); - return res; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java deleted file mode 100644 index fc44968..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions; - -import java.io.Serializable; -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.decompositions.QRDSolver; - -/** - * Linear regression model representation. - * - * @see OLSMultipleLinearRegressionModel - */ -public class OLSMultipleLinearRegressionModelFormat implements Serializable { - /** X sample data. */ - private final Matrix xMatrix; - - /** Whether or not the regression model includes an intercept. True means no intercept. */ - private final QRDSolver solver; - - /** */ - public OLSMultipleLinearRegressionModelFormat(Matrix xMatrix, QRDSolver solver) { - this.xMatrix = xMatrix; - this.solver = solver; - } - - /** */ - public OLSMultipleLinearRegressionModel getOLSMultipleLinearRegressionModel() { - return new OLSMultipleLinearRegressionModel(xMatrix, solver); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java deleted file mode 100644 index dde0aca..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions; - -import org.apache.ignite.ml.Trainer; -import org.apache.ignite.ml.math.Matrix; - -/** - * Trainer for linear regression. - */ -public class OLSMultipleLinearRegressionTrainer implements Trainer<OLSMultipleLinearRegressionModel, double[]> { - /** */ - private final double threshold; - - /** */ - private final int nobs; - - /** */ - private final int nvars; - - /** */ - private final Matrix like; - - /** - * Construct linear regression trainer. - * - * @param threshold the singularity threshold for QR decomposition - * @param nobs number of observations (rows) - * @param nvars number of independent variables (columns, not counting y) - * @param like matrix(maybe empty) indicating how data should be stored - */ - public OLSMultipleLinearRegressionTrainer(double threshold, int nobs, int nvars, Matrix like) { - this.threshold = threshold; - this.nobs = nobs; - this.nvars = nvars; - this.like = like; - } - - /** {@inheritDoc} */ - @Override public OLSMultipleLinearRegressionModel train(double[] data) { - OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(threshold); - - regression.newSampleData(data, nobs, nvars, like); - - return new OLSMultipleLinearRegressionModel(regression.getX(), regression.solver()); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/RegressionsErrorMessages.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/RegressionsErrorMessages.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/RegressionsErrorMessages.java deleted file mode 100644 index 883adca..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/RegressionsErrorMessages.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions; - -/** - * This class contains various messages used in regressions, - */ -public class RegressionsErrorMessages { - /** Constant for string indicating that sample has insufficient observed points. */ - static final String INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE = "Insufficient observed points in sample."; - /** */ - static final String NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS = "Not enough data (%d rows) for this many predictors (%d predictors)"; -} http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModel.java new file mode 100644 index 0000000..6586a81 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModel.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions.linear; + +import java.io.Serializable; +import java.util.Objects; +import org.apache.ignite.ml.Exportable; +import org.apache.ignite.ml.Exporter; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.math.Vector; + +/** + * Simple linear regression model which predicts result value Y as a linear combination of input variables: + * Y = weights * X + intercept. + */ +public class LinearRegressionModel implements Model<Vector, Double>, Exportable<LinearRegressionModel>, Serializable { + /** */ + private static final long serialVersionUID = -105984600091550226L; + + /** Multiplier of the objects's vector required to make prediction. */ + private final Vector weights; + + /** Intercept of the linear regression model */ + private final double intercept; + + /** */ + public LinearRegressionModel(Vector weights, double intercept) { + this.weights = weights; + this.intercept = intercept; + } + + /** */ + public Vector getWeights() { + return weights; + } + + /** */ + public double getIntercept() { + return intercept; + } + + /** {@inheritDoc} */ + @Override public Double apply(Vector input) { + return input.dot(weights) + intercept; + } + + /** {@inheritDoc} */ + @Override public <P> void saveModel(Exporter<LinearRegressionModel, P> exporter, P path) { + exporter.save(this, path); + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + LinearRegressionModel mdl = (LinearRegressionModel)o; + return Double.compare(mdl.intercept, intercept) == 0 && + Objects.equals(weights, mdl.weights); + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + + return Objects.hash(weights, intercept); + } + + /** {@inheritDoc} */ + @Override public String toString() { + if (weights.size() < 10) { + StringBuilder builder = new StringBuilder(); + + for (int i = 0; i < weights.size(); i++) { + double nextItem = i == weights.size() - 1 ? intercept : weights.get(i + 1); + + builder.append(String.format("%.4f", Math.abs(weights.get(i)))) + .append("*x") + .append(i) + .append(nextItem > 0 ? " + " : " - "); + } + + builder.append(String.format("%.4f", Math.abs(intercept))); + return builder.toString(); + } + + return "LinearRegressionModel{" + + "weights=" + weights + + ", intercept=" + intercept + + '}'; + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionQRTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionQRTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionQRTrainer.java new file mode 100644 index 0000000..5de3cda --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionQRTrainer.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions.linear; + +import org.apache.ignite.ml.Trainer; +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.decompositions.QRDSolver; +import org.apache.ignite.ml.math.decompositions.QRDecomposition; +import org.apache.ignite.ml.math.impls.vector.FunctionVector; + +/** + * Linear regression trainer based on least squares loss function and QR decomposition. + */ +public class LinearRegressionQRTrainer implements Trainer<LinearRegressionModel, Matrix> { + /** + * {@inheritDoc} + */ + @Override public LinearRegressionModel train(Matrix data) { + Vector groundTruth = extractGroundTruth(data); + Matrix inputs = extractInputs(data); + + QRDecomposition decomposition = new QRDecomposition(inputs); + QRDSolver solver = new QRDSolver(decomposition.getQ(), decomposition.getR()); + + Vector variables = solver.solve(groundTruth); + Vector weights = variables.viewPart(1, variables.size() - 1); + + double intercept = variables.get(0); + + return new LinearRegressionModel(weights, intercept); + } + + /** + * Extracts first column with ground truth from the data set matrix. + * + * @param data data to build model + * @return Ground truth vector + */ + private Vector extractGroundTruth(Matrix data) { + return data.getCol(0); + } + + /** + * Extracts all inputs from data set matrix and updates matrix so that first column contains value 1.0. + * + * @param data data to build model + * @return Inputs matrix + */ + private Matrix extractInputs(Matrix data) { + data = data.copy(); + + data.assignColumn(0, new FunctionVector(data.rowSize(), row -> 1.0)); + + return data; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java new file mode 100644 index 0000000..aad4c7a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions.linear; + +import org.apache.ignite.ml.Trainer; +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.optimization.BarzilaiBorweinUpdater; +import org.apache.ignite.ml.optimization.GradientDescent; +import org.apache.ignite.ml.optimization.LeastSquaresGradientFunction; +import org.apache.ignite.ml.optimization.SimpleUpdater; + +/** + * Linear regression trainer based on least squares loss function and gradient descent optimization algorithm. + */ +public class LinearRegressionSGDTrainer implements Trainer<LinearRegressionModel, Matrix> { + /** + * Gradient descent optimizer. + */ + private final GradientDescent gradientDescent; + + /** */ + public LinearRegressionSGDTrainer(GradientDescent gradientDescent) { + this.gradientDescent = gradientDescent; + } + + /** */ + public LinearRegressionSGDTrainer(int maxIterations, double convergenceTol) { + this.gradientDescent = new GradientDescent(new LeastSquaresGradientFunction(), new BarzilaiBorweinUpdater()) + .withMaxIterations(maxIterations) + .withConvergenceTol(convergenceTol); + } + + /** */ + public LinearRegressionSGDTrainer(int maxIterations, double convergenceTol, double learningRate) { + this.gradientDescent = new GradientDescent(new LeastSquaresGradientFunction(), new SimpleUpdater(learningRate)) + .withMaxIterations(maxIterations) + .withConvergenceTol(convergenceTol); + } + + /** + * {@inheritDoc} + */ + @Override public LinearRegressionModel train(Matrix data) { + Vector variables = gradientDescent.optimize(data, data.likeVector(data.columnSize())); + Vector weights = variables.viewPart(1, variables.size() - 1); + + double intercept = variables.get(0); + + return new LinearRegressionModel(weights, intercept); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/package-info.java new file mode 100644 index 0000000..086a824 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains various linear regressions. + */ +package org.apache.ignite.ml.regressions.linear; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java index 37dec77..862a9c1 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java @@ -28,9 +28,8 @@ import org.apache.ignite.ml.knn.models.KNNModelFormat; import org.apache.ignite.ml.knn.models.KNNStrategy; import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; -import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionModel; -import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionModelFormat; -import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionTrainer; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; import org.apache.ignite.ml.structures.LabeledDataset; import org.junit.Assert; import org.junit.Test; @@ -63,21 +62,16 @@ public class LocalModelsTest { /** */ @Test - public void importExportOLSMultipleLinearRegressionModelTest() throws IOException { + public void importExportLinearRegressionModelTest() throws IOException { executeModelTest(mdlFilePath -> { - OLSMultipleLinearRegressionModel mdl = getAbstractMultipleLinearRegressionModel(); + LinearRegressionModel model = new LinearRegressionModel(new DenseLocalOnHeapVector(new double[]{1, 2}), 3); + Exporter<LinearRegressionModel, String> exporter = new FileExporter<>(); + model.saveModel(exporter, mdlFilePath); - Exporter<OLSMultipleLinearRegressionModelFormat, String> exporter = new FileExporter<>(); - - mdl.saveModel(exporter, mdlFilePath); - - OLSMultipleLinearRegressionModelFormat load = exporter.load(mdlFilePath); + LinearRegressionModel load = exporter.load(mdlFilePath); Assert.assertNotNull(load); - - OLSMultipleLinearRegressionModel importedMdl = load.getOLSMultipleLinearRegressionModel(); - - Assert.assertTrue("", mdl.equals(importedMdl)); + Assert.assertEquals("", model, load); return null; }); @@ -114,24 +108,6 @@ public class LocalModelsTest { } /** */ - private OLSMultipleLinearRegressionModel getAbstractMultipleLinearRegressionModel() { - double[] data = new double[] { - 0, 0, 0, 0, 0, 0, // IMPL NOTE values in this row are later replaced (with 1.0) - 0, 2.0, 0, 0, 0, 0, - 0, 0, 3.0, 0, 0, 0, - 0, 0, 0, 4.0, 0, 0, - 0, 0, 0, 0, 5.0, 0, - 0, 0, 0, 0, 0, 6.0}; - - final int nobs = 6, nvars = 5; - - OLSMultipleLinearRegressionTrainer trainer - = new OLSMultipleLinearRegressionTrainer(0, nobs, nvars, new DenseLocalOnHeapMatrix(1, 1)); - - return trainer.train(data); - } - - /** */ @Test public void importExportKNNModelTest() throws IOException { executeModelTest(mdlFilePath -> { http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/optimization/GradientDescentTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/optimization/GradientDescentTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/optimization/GradientDescentTest.java new file mode 100644 index 0000000..f6f4775 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/optimization/GradientDescentTest.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.optimization; + +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.junit.Test; + +/** + * Tests for {@link GradientDescent}. + */ +public class GradientDescentTest { + /** */ + private static final double PRECISION = 1e-6; + + /** + * Test gradient descent optimization on function y = x^2 with gradient function 2 * x. + */ + @Test + public void testOptimize() { + GradientDescent gradientDescent = new GradientDescent( + (inputs, groundTruth, point) -> point.times(2), + new SimpleUpdater(0.01) + ); + + Vector res = gradientDescent.optimize(new DenseLocalOnHeapMatrix(new double[1][1]), + new DenseLocalOnHeapVector(new double[]{ 2.0 })); + + TestUtils.assertEquals(0, res.get(0), PRECISION); + } + + /** + * Test gradient descent optimization on function y = (x - 2)^2 with gradient function 2 * (x - 2). + */ + @Test + public void testOptimizeWithOffset() { + GradientDescent gradientDescent = new GradientDescent( + (inputs, groundTruth, point) -> point.minus(new DenseLocalOnHeapVector(new double[]{ 2.0 })).times(2.0), + new SimpleUpdater(0.01) + ); + + Vector res = gradientDescent.optimize(new DenseLocalOnHeapMatrix(new double[1][1]), + new DenseLocalOnHeapVector(new double[]{ 2.0 })); + + TestUtils.assertEquals(2, res.get(0), PRECISION); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducerTest.java new file mode 100644 index 0000000..9017c43 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducerTest.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.optimization.util; + +import org.apache.ignite.Ignite; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; +import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; + +/** + * Tests for {@link SparseDistributedMatrixMapReducer}. + */ +public class SparseDistributedMatrixMapReducerTest extends GridCommonAbstractTest { + /** Number of nodes in grid */ + private static final int NODE_COUNT = 2; + + /** */ + private Ignite ignite; + + /** {@inheritDoc} */ + @Override protected void beforeTestsStarted() throws Exception { + for (int i = 1; i <= NODE_COUNT; i++) + startGrid(i); + } + + /** {@inheritDoc} */ + @Override protected void afterTestsStopped() { + stopAllGrids(); + } + + /** + * {@inheritDoc} + */ + @Override protected void beforeTest() throws Exception { + /* Grid instance. */ + ignite = grid(NODE_COUNT); + ignite.configuration().setPeerClassLoadingEnabled(true); + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + } + + /** + * Tests that matrix 100x100 filled by "1.0" and distributed across nodes successfully processed (calculate sum of + * all elements) via {@link SparseDistributedMatrixMapReducer}. + */ + public void testMapReduce() { + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(100, 100); + for (int i = 0; i < 100; i++) + for (int j = 0; j < 100; j++) + distributedMatrix.set(i, j, 1); + SparseDistributedMatrixMapReducer mapReducer = new SparseDistributedMatrixMapReducer(distributedMatrix); + double total = mapReducer.mapReduce( + (matrix, args) -> { + double partialSum = 0.0; + for (int i = 0; i < matrix.rowSize(); i++) + for (int j = 0; j < matrix.columnSize(); j++) + partialSum += matrix.get(i, j); + return partialSum; + }, + sums -> { + double totalSum = 0; + for (Double partialSum : sums) + if (partialSum != null) + totalSum += partialSum; + return totalSum; + }, 0.0); + assertEquals(100.0 * 100.0, total, 1e-18); + } + + /** + * Tests that matrix 100x100 filled by "1.0" and distributed across nodes successfully processed via + * {@link SparseDistributedMatrixMapReducer} even when mapping function returns {@code null}. + */ + public void testMapReduceWithNullValues() { + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(100, 100); + for (int i = 0; i < 100; i++) + for (int j = 0; j < 100; j++) + distributedMatrix.set(i, j, 1); + SparseDistributedMatrixMapReducer mapReducer = new SparseDistributedMatrixMapReducer(distributedMatrix); + double total = mapReducer.mapReduce( + (matrix, args) -> null, + sums -> { + double totalSum = 0; + for (Double partialSum : sums) + if (partialSum != null) + totalSum += partialSum; + return totalSum; + }, 0.0); + assertEquals(0, total, 1e-18); + } + + /** + * Tests that matrix 1x100 filled by "1.0" and distributed across nodes successfully processed (calculate sum of + * all elements) via {@link SparseDistributedMatrixMapReducer} even when not all nodes contains data. + */ + public void testMapReduceWithOneEmptyNode() { + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(1, 100); + for (int j = 0; j < 100; j++) + distributedMatrix.set(0, j, 1); + SparseDistributedMatrixMapReducer mapReducer = new SparseDistributedMatrixMapReducer(distributedMatrix); + double total = mapReducer.mapReduce( + (matrix, args) -> { + double partialSum = 0.0; + for (int i = 0; i < matrix.rowSize(); i++) + for (int j = 0; j < matrix.columnSize(); j++) + partialSum += matrix.get(i, j); + return partialSum; + }, + sums -> { + double totalSum = 0; + for (Double partialSum : sums) + if (partialSum != null) + totalSum += partialSum; + return totalSum; + }, 0.0); + assertEquals(100.0, total, 1e-18); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/b2060855/modules/ml/src/test/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegressionTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegressionTest.java deleted file mode 100644 index 6ad56a5..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegressionTest.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions; - -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException; -import org.apache.ignite.ml.math.exceptions.NullArgumentException; -import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -/** - * This class is based on the corresponding class from Apache Common Math lib. - * Abstract base class for implementations of {@link MultipleLinearRegression}. - */ -public abstract class AbstractMultipleLinearRegressionTest { - /** */ - protected AbstractMultipleLinearRegression regression; - - /** */ - @Before - public void setUp() { - regression = createRegression(); - } - - /** */ - protected abstract AbstractMultipleLinearRegression createRegression(); - - /** */ - protected abstract int getNumberOfRegressors(); - - /** */ - protected abstract int getSampleSize(); - - /** */ - @Test - public void canEstimateRegressionParameters() { - double[] beta = regression.estimateRegressionParameters(); - Assert.assertEquals(getNumberOfRegressors(), beta.length); - } - - /** */ - @Test - public void canEstimateResiduals() { - double[] e = regression.estimateResiduals(); - Assert.assertEquals(getSampleSize(), e.length); - } - - /** */ - @Test - public void canEstimateRegressionParametersVariance() { - Matrix var = regression.estimateRegressionParametersVariance(); - Assert.assertEquals(getNumberOfRegressors(), var.rowSize()); - } - - /** */ - @Test - public void canEstimateRegressandVariance() { - if (getSampleSize() > getNumberOfRegressors()) { - double variance = regression.estimateRegressandVariance(); - Assert.assertTrue(variance > 0.0); - } - } - - /** - * Verifies that newSampleData methods consistently insert unitary columns - * in design matrix. Confirms the fix for MATH-411. - */ - @Test - public void testNewSample() { - double[] design = new double[] { - 1, 19, 22, 33, - 2, 20, 30, 40, - 3, 25, 35, 45, - 4, 27, 37, 47 - }; - - double[] y = new double[] {1, 2, 3, 4}; - - double[][] x = new double[][] { - {19, 22, 33}, - {20, 30, 40}, - {25, 35, 45}, - {27, 37, 47} - }; - - AbstractMultipleLinearRegression regression = createRegression(); - regression.newSampleData(design, 4, 3, new DenseLocalOnHeapMatrix()); - - Matrix flatX = regression.getX().copy(); - Vector flatY = regression.getY().copy(); - - regression.newXSampleData(new DenseLocalOnHeapMatrix(x)); - regression.newYSampleData(new DenseLocalOnHeapVector(y)); - - Assert.assertEquals(flatX, regression.getX()); - Assert.assertEquals(flatY, regression.getY()); - - // No intercept - regression.setNoIntercept(true); - regression.newSampleData(design, 4, 3, new DenseLocalOnHeapMatrix()); - - flatX = regression.getX().copy(); - flatY = regression.getY().copy(); - - regression.newXSampleData(new DenseLocalOnHeapMatrix(x)); - regression.newYSampleData(new DenseLocalOnHeapVector(y)); - - Assert.assertEquals(flatX, regression.getX()); - Assert.assertEquals(flatY, regression.getY()); - } - - /** */ - @Test(expected = NullArgumentException.class) - public void testNewSampleNullData() { - double[] data = null; - createRegression().newSampleData(data, 2, 3, new DenseLocalOnHeapMatrix()); - } - - /** */ - @Test(expected = MathIllegalArgumentException.class) - public void testNewSampleInvalidData() { - double[] data = new double[] {1, 2, 3, 4}; - createRegression().newSampleData(data, 2, 3, new DenseLocalOnHeapMatrix()); - } - - /** */ - @Test(expected = MathIllegalArgumentException.class) - public void testNewSampleInsufficientData() { - double[] data = new double[] {1, 2, 3, 4}; - createRegression().newSampleData(data, 1, 3, new DenseLocalOnHeapMatrix()); - } - - /** */ - @Test(expected = NullArgumentException.class) - public void testXSampleDataNull() { - createRegression().newXSampleData(null); - } - - /** */ - @Test(expected = NullArgumentException.class) - public void testYSampleDataNull() { - createRegression().newYSampleData(null); - } - -}
