NO-JIRA Trevors updates
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/545648f6 Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/545648f6 Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/545648f6 Branch: refs/heads/branch-0.14.0 Commit: 545648f6a8f69139757f0328e924c6b16a839441 Parents: 49ad8cb Author: Trevor a.k.a @rawkintrevo <[email protected]> Authored: Sat Sep 8 18:34:43 2018 -0500 Committer: Trevor a.k.a @rawkintrevo <[email protected]> Committed: Sat Sep 8 18:34:43 2018 -0500 ---------------------------------------------------------------------- .../apache/mahout/collections/Arithmetic.java | 489 ++++ .../apache/mahout/collections/Constants.java | 75 + .../org/apache/mahout/common/RandomUtils.java | 100 + .../org/apache/mahout/common/RandomWrapper.java | 105 + .../org/apache/mahout/math/AbstractMatrix.java | 834 +++++++ .../org/apache/mahout/math/AbstractVector.java | 684 ++++++ .../java/org/apache/mahout/math/Algebra.java | 73 + .../java/org/apache/mahout/math/Arrays.java | 662 +++++ .../org/apache/mahout/math/BinarySearch.java | 403 +++ .../mahout/math/CardinalityException.java | 30 + .../java/org/apache/mahout/math/Centroid.java | 89 + .../mahout/math/CholeskyDecomposition.java | 227 ++ .../org/apache/mahout/math/ConstantVector.java | 177 ++ .../apache/mahout/math/DelegatingVector.java | 336 +++ .../org/apache/mahout/math/DenseMatrix.java | 193 ++ .../mahout/math/DenseSymmetricMatrix.java | 62 + .../org/apache/mahout/math/DenseVector.java | 442 ++++ .../org/apache/mahout/math/DiagonalMatrix.java | 378 +++ .../org/apache/mahout/math/FileBasedMatrix.java | 185 ++ .../math/FileBasedSparseBinaryMatrix.java | 535 ++++ .../mahout/math/FunctionalMatrixView.java | 99 + .../org/apache/mahout/math/IndexException.java | 30 + .../apache/mahout/math/LengthCachingVector.java | 35 + .../java/org/apache/mahout/math/Matrices.java | 167 ++ .../java/org/apache/mahout/math/Matrix.java | 413 ++++ 
.../org/apache/mahout/math/MatrixSlice.java | 36 + .../org/apache/mahout/math/MatrixTimesOps.java | 35 + .../apache/mahout/math/MatrixVectorView.java | 292 +++ .../java/org/apache/mahout/math/MatrixView.java | 160 ++ .../java/org/apache/mahout/math/MurmurHash.java | 158 ++ .../org/apache/mahout/math/MurmurHash3.java | 84 + .../org/apache/mahout/math/NamedVector.java | 328 +++ .../apache/mahout/math/OldQRDecomposition.java | 234 ++ .../mahout/math/OrderedIntDoubleMapping.java | 265 ++ .../mahout/math/OrthonormalityVerifier.java | 46 + .../apache/mahout/math/PermutedVectorView.java | 250 ++ .../apache/mahout/math/PersistentObject.java | 58 + .../org/apache/mahout/math/PivotedMatrix.java | 288 +++ .../main/java/org/apache/mahout/math/QR.java | 27 + .../org/apache/mahout/math/QRDecomposition.java | 181 ++ .../mahout/math/RandomAccessSparseVector.java | 303 +++ .../apache/mahout/math/RandomTrinaryMatrix.java | 146 ++ .../math/SequentialAccessSparseVector.java | 379 +++ .../mahout/math/SingularValueDecomposition.java | 669 +++++ .../java/org/apache/mahout/math/Sorting.java | 2297 ++++++++++++++++++ .../apache/mahout/math/SparseColumnMatrix.java | 220 ++ .../org/apache/mahout/math/SparseMatrix.java | 245 ++ .../org/apache/mahout/math/SparseRowMatrix.java | 289 +++ .../java/org/apache/mahout/math/Swapper.java | 35 + .../mahout/math/TransposedMatrixView.java | 147 ++ .../org/apache/mahout/math/UpperTriangular.java | 160 ++ .../java/org/apache/mahout/math/Vector.java | 434 ++++ .../mahout/math/VectorBinaryAggregate.java | 481 ++++ .../apache/mahout/math/VectorBinaryAssign.java | 667 +++++ .../org/apache/mahout/math/VectorIterable.java | 56 + .../java/org/apache/mahout/math/VectorView.java | 238 ++ .../org/apache/mahout/math/WeightedVector.java | 87 + .../mahout/math/WeightedVectorComparator.java | 54 + .../math/als/AlternatingLeastSquaresSolver.java | 116 + ...itFeedbackAlternatingLeastSquaresSolver.java | 171 ++ .../math/decomposer/AsyncEigenVerifier.java | 80 + 
.../mahout/math/decomposer/EigenStatus.java | 50 + .../math/decomposer/SimpleEigenVerifier.java | 41 + .../math/decomposer/SingularVectorVerifier.java | 25 + .../math/decomposer/hebbian/EigenUpdater.java | 25 + .../math/decomposer/hebbian/HebbianSolver.java | 342 +++ .../math/decomposer/hebbian/HebbianUpdater.java | 71 + .../math/decomposer/hebbian/TrainingState.java | 143 ++ .../math/decomposer/lanczos/LanczosSolver.java | 213 ++ .../math/decomposer/lanczos/LanczosState.java | 107 + .../org/apache/mahout/math/flavor/BackEnum.java | 26 + .../apache/mahout/math/flavor/MatrixFlavor.java | 82 + .../math/flavor/TraversingStructureEnum.java | 48 + .../math/function/DoubleDoubleFunction.java | 98 + .../mahout/math/function/DoubleFunction.java | 48 + .../mahout/math/function/FloatFunction.java | 36 + .../apache/mahout/math/function/Functions.java | 1730 +++++++++++++ .../mahout/math/function/IntFunction.java | 41 + .../math/function/IntIntDoubleFunction.java | 43 + .../mahout/math/function/IntIntFunction.java | 25 + .../org/apache/mahout/math/function/Mult.java | 71 + .../math/function/ObjectObjectProcedure.java | 40 + .../mahout/math/function/ObjectProcedure.java | 47 + .../apache/mahout/math/function/PlusMult.java | 123 + .../math/function/SquareRootFunction.java | 26 + .../mahout/math/function/TimesFunction.java | 77 + .../mahout/math/function/VectorFunction.java | 27 + .../mahout/math/function/package-info.java | 4 + .../apache/mahout/math/jet/math/Arithmetic.java | 328 +++ .../apache/mahout/math/jet/math/Constants.java | 49 + .../apache/mahout/math/jet/math/Polynomial.java | 98 + .../mahout/math/jet/math/package-info.java | 5 + .../random/AbstractContinousDistribution.java | 51 + .../random/AbstractDiscreteDistribution.java | 27 + .../math/jet/random/AbstractDistribution.java | 87 + .../mahout/math/jet/random/Exponential.java | 81 + .../apache/mahout/math/jet/random/Gamma.java | 302 +++ .../math/jet/random/NegativeBinomial.java | 106 + 
.../apache/mahout/math/jet/random/Normal.java | 110 + .../apache/mahout/math/jet/random/Poisson.java | 296 +++ .../apache/mahout/math/jet/random/Uniform.java | 164 ++ .../math/jet/random/engine/MersenneTwister.java | 275 +++ .../math/jet/random/engine/RandomEngine.java | 169 ++ .../math/jet/random/engine/package-info.java | 7 + .../math/jet/random/sampling/RandomSampler.java | 503 ++++ .../org/apache/mahout/math/jet/stat/Gamma.java | 681 ++++++ .../mahout/math/jet/stat/Probability.java | 203 ++ .../mahout/math/jet/stat/package-info.java | 5 + .../apache/mahout/math/list/AbstractList.java | 247 ++ .../mahout/math/list/AbstractObjectList.java | 80 + .../mahout/math/list/ObjectArrayList.java | 419 ++++ .../mahout/math/list/SimpleLongArrayList.java | 102 + .../apache/mahout/math/list/package-info.java | 144 ++ .../apache/mahout/math/map/HashFunctions.java | 115 + .../org/apache/mahout/math/map/OpenHashMap.java | 652 +++++ .../org/apache/mahout/math/map/PrimeFinder.java | 145 ++ .../mahout/math/map/QuickOpenIntIntHashMap.java | 215 ++ .../apache/mahout/math/map/package-info.java | 250 ++ .../org/apache/mahout/math/package-info.java | 4 + .../math/random/AbstractSamplerFunction.java | 39 + .../mahout/math/random/ChineseRestaurant.java | 111 + .../apache/mahout/math/random/Empirical.java | 124 + .../apache/mahout/math/random/IndianBuffet.java | 157 ++ .../org/apache/mahout/math/random/Missing.java | 59 + .../apache/mahout/math/random/MultiNormal.java | 118 + .../apache/mahout/math/random/Multinomial.java | 202 ++ .../org/apache/mahout/math/random/Normal.java | 40 + .../mahout/math/random/PoissonSampler.java | 67 + .../org/apache/mahout/math/random/Sampler.java | 25 + .../mahout/math/random/WeightedThing.java | 71 + .../org/apache/mahout/math/set/AbstractSet.java | 188 ++ .../org/apache/mahout/math/set/HashUtils.java | 56 + .../org/apache/mahout/math/set/OpenHashSet.java | 548 +++++ .../math/solver/ConjugateGradientSolver.java | 213 ++ 
.../mahout/math/solver/EigenDecomposition.java | 892 +++++++ .../mahout/math/solver/JacobiConditioner.java | 47 + .../org/apache/mahout/math/solver/LSMR.java | 565 +++++ .../mahout/math/solver/Preconditioner.java | 36 + .../mahout/math/ssvd/SequentialBigSvd.java | 69 + .../apache/mahout/math/stats/LogLikelihood.java | 220 ++ .../math/stats/OnlineExponentialAverage.java | 62 + .../mahout/math/stats/OnlineSummarizer.java | 93 + .../apache/mahout/math/QRDecompositionTest.java | 280 +++ .../math/TestSingularValueDecomposition.java | 327 +++ .../als/AlternatingLeastSquaresSolverTest.java | 151 ++ .../mahout/math/decomposer/SolverTest.java | 177 ++ .../decomposer/hebbian/TestHebbianSolver.java | 207 ++ .../decomposer/lanczos/TestLanczosSolver.java | 97 + .../apache/mahout/math/jet/stat/GammaTest.java | 138 ++ .../mahout/math/jet/stat/ProbabilityTest.java | 196 ++ .../math/random/ChineseRestaurantTest.java | 158 ++ .../mahout/math/randomized/RandomBlasting.java | 355 +++ .../mahout/math/ssvd/SequentialBigSvdTest.java | 86 + .../mahout/math/stats/OnlineSummarizerTest.java | 108 + 154 files changed, 32350 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/collections/Arithmetic.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/collections/Arithmetic.java b/core/src/main/java/org/apache/mahout/collections/Arithmetic.java new file mode 100644 index 0000000..18e3200 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/collections/Arithmetic.java @@ -0,0 +1,489 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* +Copyright 1999 CERN - European Organization for Nuclear Research. +Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose +is hereby granted without fee, provided that the above copyright notice appear in all copies and +that both that copyright notice and this permission notice appear in supporting documentation. +CERN makes no representations about the suitability of this software for any purpose. +It is provided "as is" without expressed or implied warranty. +*/ +package org.apache.mahout.collections; + +/** + * Arithmetic functions. + */ +public final class Arithmetic extends Constants { + // for method STIRLING_CORRECTION(...) 
+ private static final double[] STIRLING_CORRECTION = { + 0.0, + 8.106146679532726e-02, 4.134069595540929e-02, + 2.767792568499834e-02, 2.079067210376509e-02, + 1.664469118982119e-02, 1.387612882307075e-02, + 1.189670994589177e-02, 1.041126526197209e-02, + 9.255462182712733e-03, 8.330563433362871e-03, + 7.573675487951841e-03, 6.942840107209530e-03, + 6.408994188004207e-03, 5.951370112758848e-03, + 5.554733551962801e-03, 5.207655919609640e-03, + 4.901395948434738e-03, 4.629153749334029e-03, + 4.385560249232324e-03, 4.166319691996922e-03, + 3.967954218640860e-03, 3.787618068444430e-03, + 3.622960224683090e-03, 3.472021382978770e-03, + 3.333155636728090e-03, 3.204970228055040e-03, + 3.086278682608780e-03, 2.976063983550410e-03, + 2.873449362352470e-03, 2.777674929752690e-03, + }; + + // for method logFactorial(...) + // log(k!) for k = 0, ..., 29 + private static final double[] LOG_FACTORIALS = { + 0.00000000000000000, 0.00000000000000000, 0.69314718055994531, + 1.79175946922805500, 3.17805383034794562, 4.78749174278204599, + 6.57925121201010100, 8.52516136106541430, 10.60460290274525023, + 12.80182748008146961, 15.10441257307551530, 17.50230784587388584, + 19.98721449566188615, 22.55216385312342289, 25.19122118273868150, + 27.89927138384089157, 30.67186010608067280, 33.50507345013688888, + 36.39544520803305358, 39.33988418719949404, 42.33561646075348503, + 45.38013889847690803, 48.47118135183522388, 51.60667556776437357, + 54.78472939811231919, 58.00360522298051994, 61.26170176100200198, + 64.55753862700633106, 67.88974313718153498, 71.25703896716800901 + }; + + // k! for k = 0, ..., 20 + private static final long[] LONG_FACTORIALS = { + 1L, + 1L, + 2L, + 6L, + 24L, + 120L, + 720L, + 5040L, + 40320L, + 362880L, + 3628800L, + 39916800L, + 479001600L, + 6227020800L, + 87178291200L, + 1307674368000L, + 20922789888000L, + 355687428096000L, + 6402373705728000L, + 121645100408832000L, + 2432902008176640000L + }; + + // k! 
for k = 21, ..., 170 + private static final double[] DOUBLE_FACTORIALS = { + 5.109094217170944E19, + 1.1240007277776077E21, + 2.585201673888498E22, + 6.204484017332394E23, + 1.5511210043330984E25, + 4.032914611266057E26, + 1.0888869450418352E28, + 3.048883446117138E29, + 8.841761993739701E30, + 2.652528598121911E32, + 8.222838654177924E33, + 2.6313083693369355E35, + 8.68331761881189E36, + 2.952327990396041E38, + 1.0333147966386144E40, + 3.719933267899013E41, + 1.3763753091226346E43, + 5.23022617466601E44, + 2.0397882081197447E46, + 8.15915283247898E47, + 3.34525266131638E49, + 1.4050061177528801E51, + 6.041526306337384E52, + 2.6582715747884495E54, + 1.196222208654802E56, + 5.502622159812089E57, + 2.5862324151116827E59, + 1.2413915592536068E61, + 6.082818640342679E62, + 3.0414093201713376E64, + 1.5511187532873816E66, + 8.06581751709439E67, + 4.274883284060024E69, + 2.308436973392413E71, + 1.2696403353658264E73, + 7.109985878048632E74, + 4.052691950487723E76, + 2.350561331282879E78, + 1.386831185456898E80, + 8.32098711274139E81, + 5.075802138772246E83, + 3.146997326038794E85, + 1.9826083154044396E87, + 1.2688693218588414E89, + 8.247650592082472E90, + 5.443449390774432E92, + 3.6471110918188705E94, + 2.48003554243683E96, + 1.7112245242814127E98, + 1.1978571669969892E100, + 8.504785885678624E101, + 6.123445837688612E103, + 4.470115461512686E105, + 3.307885441519387E107, + 2.4809140811395404E109, + 1.8854947016660506E111, + 1.451830920282859E113, + 1.1324281178206295E115, + 8.94618213078298E116, + 7.15694570462638E118, + 5.797126020747369E120, + 4.7536433370128435E122, + 3.94552396972066E124, + 3.314240134565354E126, + 2.8171041143805494E128, + 2.4227095383672744E130, + 2.107757298379527E132, + 1.854826422573984E134, + 1.6507955160908465E136, + 1.4857159644817605E138, + 1.3520015276784033E140, + 1.2438414054641305E142, + 1.156772507081641E144, + 1.0873661566567426E146, + 1.0329978488239061E148, + 9.916779348709491E149, + 9.619275968248216E151, + 9.426890448883248E153, + 
9.332621544394415E155, + 9.332621544394418E157, + 9.42594775983836E159, + 9.614466715035125E161, + 9.902900716486178E163, + 1.0299016745145631E166, + 1.0813967582402912E168, + 1.1462805637347086E170, + 1.2265202031961373E172, + 1.324641819451829E174, + 1.4438595832024942E176, + 1.5882455415227423E178, + 1.7629525510902457E180, + 1.974506857221075E182, + 2.2311927486598138E184, + 2.543559733472186E186, + 2.925093693493014E188, + 3.393108684451899E190, + 3.96993716080872E192, + 4.6845258497542896E194, + 5.574585761207606E196, + 6.689502913449135E198, + 8.094298525273444E200, + 9.875044200833601E202, + 1.2146304367025332E205, + 1.506141741511141E207, + 1.882677176888926E209, + 2.3721732428800483E211, + 3.0126600184576624E213, + 3.856204823625808E215, + 4.974504222477287E217, + 6.466855489220473E219, + 8.471580690878813E221, + 1.1182486511960037E224, + 1.4872707060906847E226, + 1.99294274616152E228, + 2.690472707318049E230, + 3.6590428819525483E232, + 5.0128887482749884E234, + 6.917786472619482E236, + 9.615723196941089E238, + 1.3462012475717523E241, + 1.8981437590761713E243, + 2.6953641378881633E245, + 3.8543707171800694E247, + 5.550293832739308E249, + 8.047926057471989E251, + 1.1749972043909107E254, + 1.72724589045464E256, + 2.5563239178728637E258, + 3.8089226376305687E260, + 5.7133839564458575E262, + 8.627209774233244E264, + 1.3113358856834527E267, + 2.0063439050956838E269, + 3.0897696138473515E271, + 4.789142901463393E273, + 7.471062926282892E275, + 1.1729568794264134E278, + 1.8532718694937346E280, + 2.946702272495036E282, + 4.714723635992061E284, + 7.590705053947223E286, + 1.2296942187394494E289, + 2.0044015765453032E291, + 3.287218585534299E293, + 5.423910666131583E295, + 9.003691705778434E297, + 1.5036165148649983E300, + 2.5260757449731988E302, + 4.2690680090047056E304, + 7.257415615308004E306 + }; + + /** Makes this class non instantiable, but still let's others inherit from it. 
*/ + Arithmetic() { + } + + /** + * Efficiently returns the binomial coefficient, often also referred to as + * "n over k" or "n choose k". The binomial coefficient is defined as + * <tt>(n * n-1 * ... * n-k+1 ) / ( 1 * 2 * ... * k )</tt>. + * <ul> <li><tt>k<0</tt>: <tt>0</tt>.</li> + * <li><tt>k==0</tt>: <tt>1</tt>.</li> + * <li><tt>k==1</tt>: <tt>n</tt>.</li> + * <li>else: <tt>(n * n-1 * ... * n-k+1 ) / ( 1 * 2 * ... * k)</tt>.</li> + * </ul> + * + * @param n + * @param k + * @return the binomial coefficient. + */ + public static double binomial(double n, long k) { + if (k < 0) { + return 0; + } + if (k == 0) { + return 1; + } + if (k == 1) { + return n; + } + + // binomial(n,k) = (n * n-1 * ... * n-k+1 ) / ( 1 * 2 * ... * k ) + double a = n - k + 1; + double b = 1; + double binomial = 1; + for (long i = k; i-- > 0;) { + binomial *= (a++) / (b++); + } + return binomial; + } + + /** + * Efficiently returns the binomial coefficient, often also referred to as "n over k" or "n choose k". The binomial + * coefficient is defined as <ul> <li><tt>k<0</tt>: <tt>0</tt>. <li><tt>k==0 || k==n</tt>: <tt>1</tt>. <li><tt>k==1 || k==n-1</tt>: + * <tt>n</tt>. <li>else: <tt>(n * n-1 * ... * n-k+1 ) / ( 1 * 2 * ... * k )</tt>. </ul> + * + * @return the binomial coefficient. + */ + public static double binomial(long n, long k) { + if (k < 0) { + return 0; + } + if (k == 0 || k == n) { + return 1; + } + if (k == 1 || k == n - 1) { + return n; + } + + // try quick version and see whether we get numeric overflows. + // factorial(..) is O(1); requires no loop; only a table lookup. + if (n > k) { + int max = LONG_FACTORIALS.length + DOUBLE_FACTORIALS.length; + if (n < max) { // if (n! < inf && k! < inf) + double n_fac = factorial((int) n); + double k_fac = factorial((int) k); + double n_minus_k_fac = factorial((int) (n - k)); + double nk = n_minus_k_fac * k_fac; + if (nk != Double.POSITIVE_INFINITY) { // no numeric overflow? 
+ // now this is completely safe and accurate + return n_fac / nk; + } + } + if (k > n / 2) { + k = n - k; + } // quicker + } + + // binomial(n,k) = (n * n-1 * ... * n-k+1 ) / ( 1 * 2 * ... * k ) + long a = n - k + 1; + long b = 1; + double binomial = 1; + for (long i = k; i-- > 0;) { + binomial *= (double) a++ / (b++); + } + return binomial; + } + + /** + * Returns the smallest <code>long >= value</code>. + * <dl><dt>Examples: {@code 1.0 -> 1, 1.2 -> 2, 1.9 -> 2}. This + * method is safer than using (long) Math.ceil(value), because of possible rounding error.</dt></dl> + */ + public static long ceil(double value) { + return Math.round(Math.ceil(value)); + } + + /** + * Evaluates the series of Chebyshev polynomials Ti at argument x/2. The series is given by + * <pre> + * N-1 + * - ' + * y = > coef[i] T (x/2) + * - i + * i=0 + * </pre> + * Coefficients are stored in reverse order, i.e. the zero order term is last in the array. Note N is the number of + * coefficients, not the order. <p> If coefficients are for the interval a to b, x must have been transformed to x -> + * 2(2x - b - a)/(b-a) before entering the routine. This maps x from (a, b) to (-1, 1), over which the Chebyshev + * polynomials are defined. <p> If the coefficients are for the inverted interval, in which (a, b) is mapped to (1/b, + * 1/a), the transformation required is {@code x -> 2(2ab/x - b - a)/(b-a)}. If b is infinity, this becomes {@code x -> 4a/x - 1}. + * <p> SPEED: <p> Taking advantage of the recurrence properties of the Chebyshev polynomials, the routine requires one + * more addition per loop than evaluating a nested polynomial of the same degree. + * + * @param x argument to the polynomial. + * @param coef the coefficients of the polynomial. + * @param N the number of coefficients. 
+ */ + public static double chbevl(double x, double[] coef, int N) { + + int p = 0; + + double b0 = coef[p++]; + double b1 = 0.0; + int i = N - 1; + + double b2; + do { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + coef[p++]; + } while (--i > 0); + + return 0.5 * (b0 - b2); + } + + /** + * Instantly returns the factorial <tt>k!</tt>. + * + * @param k must hold <tt>k >= 0</tt>. + */ + private static double factorial(int k) { + if (k < 0) { + throw new IllegalArgumentException(); + } + + int length1 = LONG_FACTORIALS.length; + if (k < length1) { + return LONG_FACTORIALS[k]; + } + + int length2 = DOUBLE_FACTORIALS.length; + if (k < length1 + length2) { + return DOUBLE_FACTORIALS[k - length1]; + } else { + return Double.POSITIVE_INFINITY; + } + } + + /** + * Returns the largest <code>long <= value</code>. + * <dl><dt>Examples: {@code 1.0 -> 1, 1.2 -> 1, 1.9 -> 1 <dt> 2.0 -> 2, 2.2 -> 2, 2.9 -> 2}</dt></dl> + * This method is safer than using (long) Math.floor(value), because of possible rounding error. + */ + public static long floor(double value) { + return Math.round(Math.floor(value)); + } + + /** Returns <tt>log<sub>base</sub>value</tt>. */ + public static double log(double base, double value) { + return Math.log(value) / Math.log(base); + } + + /** Returns <tt>log<sub>10</sub>value</tt>. */ + public static double log10(double value) { + // 1.0 / Math.log(10) == 0.43429448190325176 + return Math.log(value) * 0.43429448190325176; + } + + /** Returns <tt>log<sub>2</sub>value</tt>. */ + public static double log2(double value) { + // 1.0 / Math.log(2) == 1.4426950408889634 + return Math.log(value) * 1.4426950408889634; + } + + /** + * Returns <tt>log(k!)</tt>. Tries to avoid overflows. For <tt>k<30</tt> simply looks up a table in O(1). For + * <tt>k>=30</tt> uses Stirling's approximation. + * + * @param k must hold <tt>k >= 0</tt>. 
+ */ + public static double logFactorial(int k) { + if (k >= 30) { + + double r = 1.0 / k; + double rr = r * r; + double C7 = -5.95238095238095238e-04; + double C5 = 7.93650793650793651e-04; + double C3 = -2.77777777777777778e-03; + double C1 = 8.33333333333333333e-02; + double C0 = 9.18938533204672742e-01; + return (k + 0.5) * Math.log(k) - k + C0 + r * (C1 + rr * (C3 + rr * (C5 + rr * C7))); + } else { + return LOG_FACTORIALS[k]; + } + } + + /** + * Instantly returns the factorial <tt>k!</tt>. + * + * @param k must hold {@code k >= 0 && k < 21} + */ + public static long longFactorial(int k) { + if (k < 0) { + throw new IllegalArgumentException("Negative k"); + } + + if (k < LONG_FACTORIALS.length) { + return LONG_FACTORIALS[k]; + } + throw new IllegalArgumentException("Overflow"); + } + + /** + * Returns the StirlingCorrection. <p> Correction term of the Stirling approximation for <tt>log(k!)</tt> (series in + * 1/k, or table values for small k) with int parameter k. </p> <tt> log k! = (k + 1/2)log(k + 1) - (k + 1) + + * (1/2)log(2Pi) + STIRLING_CORRECTION(k + 1) log k! 
= (k + 1/2)log(k) - k + (1/2)log(2Pi) + + * STIRLING_CORRECTION(k) </tt> + */ + public static double stirlingCorrection(int k) { + + if (k > 30) { + double r = 1.0 / k; + double rr = r * r; + double C7 = -5.95238095238095238e-04; // -1/1680 + double C5 = 7.93650793650793651e-04; // +1/1260 + double C3 = -2.77777777777777778e-03; // -1/360 + double C1 = 8.33333333333333333e-02; // +1/12 + return r * (C1 + rr * (C3 + rr * (C5 + rr * C7))); + } else { + return STIRLING_CORRECTION[k]; + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/collections/Constants.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/collections/Constants.java b/core/src/main/java/org/apache/mahout/collections/Constants.java new file mode 100644 index 0000000..007bd3f --- /dev/null +++ b/core/src/main/java/org/apache/mahout/collections/Constants.java @@ -0,0 +1,75 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* +Copyright 1999 CERN - European Organization for Nuclear Research. 
+Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose +is hereby granted without fee, provided that the above copyright notice appear in all copies and +that both that copyright notice and this permission notice appear in supporting documentation. +CERN makes no representations about the suitability of this software for any purpose. +It is provided "as is" without expressed or implied warranty. +*/ +package org.apache.mahout.collections; + +/** + * Defines some useful constants. + */ +public class Constants { + /* + * machine constants + */ + protected static final double MACHEP = 1.11022302462515654042E-16; + protected static final double MAXLOG = 7.09782712893383996732E2; + protected static final double MINLOG = -7.451332191019412076235E2; + protected static final double MAXGAM = 171.624376956302725; + protected static final double SQTPI = 2.50662827463100050242E0; + protected static final double SQRTH = 7.07106781186547524401E-1; + protected static final double LOGPI = 1.14472988584940017414; + + protected static final double BIG = 4.503599627370496e15; + protected static final double BIGINV = 2.22044604925031308085e-16; + + + /* + * MACHEP = 1.38777878078144567553E-17 2**-56 + * MAXLOG = 8.8029691931113054295988E1 log(2**127) + * MINLOG = -8.872283911167299960540E1 log(2**-128) + * MAXNUM = 1.701411834604692317316873e38 2**127 + * + * For IEEE arithmetic (IBMPC): + * MACHEP = 1.11022302462515654042E-16 2**-53 + * MAXLOG = 7.09782712893383996843E2 log(2**1024) + * MINLOG = -7.08396418532264106224E2 log(2**-1022) + * MAXNUM = 1.7976931348623158E308 2**1024 + * + * The global symbols for mathematical constants are + * PI = 3.14159265358979323846 pi + * PIO2 = 1.57079632679489661923 pi/2 + * PIO4 = 7.85398163397448309616E-1 pi/4 + * SQRT2 = 1.41421356237309504880 sqrt(2) + * SQRTH = 7.07106781186547524401E-1 sqrt(2)/2 + * LOG2E = 1.4426950408889634073599 1/log(2) + * SQ2OPI = 7.9788456080286535587989E-1 sqrt( 2/pi 
) + * LOGE2 = 6.93147180559945309417E-1 log(2) + * LOGSQ2 = 3.46573590279972654709E-1 log(2)/2 + * THPIO4 = 2.35619449019234492885 3*pi/4 + * TWOOPI = 6.36619772367581343075535E-1 2/pi + */ + protected Constants() {} +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/common/RandomUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/common/RandomUtils.java b/core/src/main/java/org/apache/mahout/common/RandomUtils.java new file mode 100644 index 0000000..ba71292 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/common/RandomUtils.java @@ -0,0 +1,100 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common; + +import java.util.Collections; +import java.util.Map; +import java.util.Random; +import java.util.WeakHashMap; + +import com.google.common.primitives.Longs; +import org.apache.commons.math3.primes.Primes; + +/** + * <p> + * The source of random stuff for the whole project. This lets us make all randomness in the project + * predictable, if desired, for when we run unit tests, which should be repeatable. 
+ * </p> + */ +public final class RandomUtils { + + /** The largest prime less than 2<sup>31</sup>-1 that is the smaller of a twin prime pair. */ + public static final int MAX_INT_SMALLER_TWIN_PRIME = 2147482949; + + private static final Map<RandomWrapper,Boolean> INSTANCES = + Collections.synchronizedMap(new WeakHashMap<RandomWrapper,Boolean>()); + + private static boolean testSeed = false; + + private RandomUtils() { } + + public static void useTestSeed() { + testSeed = true; + synchronized (INSTANCES) { + for (RandomWrapper rng : INSTANCES.keySet()) { + rng.resetToTestSeed(); + } + } + } + + public static RandomWrapper getRandom() { + RandomWrapper random = new RandomWrapper(); + if (testSeed) { + random.resetToTestSeed(); + } + INSTANCES.put(random, Boolean.TRUE); + return random; + } + + public static Random getRandom(long seed) { + RandomWrapper random = new RandomWrapper(seed); + INSTANCES.put(random, Boolean.TRUE); + return random; + } + + /** @return what {@link Double#hashCode()} would return for the same value */ + public static int hashDouble(double value) { + return Longs.hashCode(Double.doubleToLongBits(value)); + } + + /** @return what {@link Float#hashCode()} would return for the same value */ + public static int hashFloat(float value) { + return Float.floatToIntBits(value); + } + + /** + * <p> + * Finds next-largest "twin primes": numbers p and p+2 such that both are prime. Finds the smallest such p + * such that the smaller twin, p, is greater than or equal to n. Returns p+2, the larger of the two twins. 
+ * </p> + */ + public static int nextTwinPrime(int n) { + if (n > MAX_INT_SMALLER_TWIN_PRIME) { + throw new IllegalArgumentException(); + } + if (n <= 3) { + return 5; + } + int next = Primes.nextPrime(n); + while (!Primes.isPrime(next + 2)) { + next = Primes.nextPrime(next + 4); + } + return next + 2; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/common/RandomWrapper.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/common/RandomWrapper.java b/core/src/main/java/org/apache/mahout/common/RandomWrapper.java new file mode 100644 index 0000000..802291b --- /dev/null +++ b/core/src/main/java/org/apache/mahout/common/RandomWrapper.java @@ -0,0 +1,105 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common; + +import org.apache.commons.math3.random.MersenneTwister; +import org.apache.commons.math3.random.RandomGenerator; + +import java.util.Random; + +public final class RandomWrapper extends Random { + + private static final long STANDARD_SEED = 0xCAFEDEADBEEFBABEL; + + private final RandomGenerator random; + + RandomWrapper() { + random = new MersenneTwister(); + random.setSeed(System.currentTimeMillis() + System.identityHashCode(random)); + } + + RandomWrapper(long seed) { + random = new MersenneTwister(seed); + } + + @Override + public void setSeed(long seed) { + // Since this will be called by the java.util.Random() constructor before we construct + // the delegate... and because we don't actually care about the result of this for our + // purpose: + if (random != null) { + random.setSeed(seed); + } + } + + void resetToTestSeed() { + setSeed(STANDARD_SEED); + } + + public RandomGenerator getRandomGenerator() { + return random; + } + + @Override + protected int next(int bits) { + // Ugh, can't delegate this method -- it's protected + // Callers can't use it and other methods are delegated, so shouldn't matter + throw new UnsupportedOperationException(); + } + + @Override + public void nextBytes(byte[] bytes) { + random.nextBytes(bytes); + } + + @Override + public int nextInt() { + return random.nextInt(); + } + + @Override + public int nextInt(int n) { + return random.nextInt(n); + } + + @Override + public long nextLong() { + return random.nextLong(); + } + + @Override + public boolean nextBoolean() { + return random.nextBoolean(); + } + + @Override + public float nextFloat() { + return random.nextFloat(); + } + + @Override + public double nextDouble() { + return random.nextDouble(); + } + + @Override + public double nextGaussian() { + return random.nextGaussian(); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/AbstractMatrix.java 
---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/AbstractMatrix.java b/core/src/main/java/org/apache/mahout/math/AbstractMatrix.java new file mode 100644 index 0000000..eaaa397 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/AbstractMatrix.java @@ -0,0 +1,834 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math; + +import com.google.common.collect.AbstractIterator; +import com.google.common.collect.Maps; +import org.apache.mahout.math.flavor.MatrixFlavor; +import org.apache.mahout.math.function.DoubleDoubleFunction; +import org.apache.mahout.math.function.DoubleFunction; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.function.PlusMult; +import org.apache.mahout.math.function.VectorFunction; + +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; + +/** + * A few universal implementations of convenience functions for a JVM-backed matrix. 
+ */ +public abstract class AbstractMatrix implements Matrix { + + protected Map<String, Integer> columnLabelBindings; + protected Map<String, Integer> rowLabelBindings; + protected int rows; + protected int columns; + + protected AbstractMatrix(int rows, int columns) { + this.rows = rows; + this.columns = columns; + } + + @Override + public int columnSize() { + return columns; + } + + @Override + public int rowSize() { + return rows; + } + + @Override + public Iterator<MatrixSlice> iterator() { + return iterateAll(); + } + + @Override + public Iterator<MatrixSlice> iterateAll() { + return new AbstractIterator<MatrixSlice>() { + private int row; + + @Override + protected MatrixSlice computeNext() { + if (row >= numRows()) { + return endOfData(); + } + int i = row++; + return new MatrixSlice(viewRow(i), i); + } + }; + } + + @Override + public Iterator<MatrixSlice> iterateNonEmpty() { + return iterator(); + } + + /** + * Abstracted out for the iterator + * + * @return numRows() for row-based iterator, numColumns() for column-based. 
+ */ + @Override + public int numSlices() { + return numRows(); + } + + @Override + public double get(String rowLabel, String columnLabel) { + if (columnLabelBindings == null || rowLabelBindings == null) { + throw new IllegalStateException("Unbound label"); + } + Integer row = rowLabelBindings.get(rowLabel); + Integer col = columnLabelBindings.get(columnLabel); + if (row == null || col == null) { + throw new IllegalStateException("Unbound label"); + } + + return get(row, col); + } + + @Override + public Map<String, Integer> getColumnLabelBindings() { + return columnLabelBindings; + } + + @Override + public Map<String, Integer> getRowLabelBindings() { + return rowLabelBindings; + } + + @Override + public void set(String rowLabel, double[] rowData) { + if (columnLabelBindings == null) { + throw new IllegalStateException("Unbound label"); + } + Integer row = rowLabelBindings.get(rowLabel); + if (row == null) { + throw new IllegalStateException("Unbound label"); + } + set(row, rowData); + } + + @Override + public void set(String rowLabel, int row, double[] rowData) { + if (rowLabelBindings == null) { + rowLabelBindings = new HashMap<>(); + } + rowLabelBindings.put(rowLabel, row); + set(row, rowData); + } + + @Override + public void set(String rowLabel, String columnLabel, double value) { + if (columnLabelBindings == null || rowLabelBindings == null) { + throw new IllegalStateException("Unbound label"); + } + Integer row = rowLabelBindings.get(rowLabel); + Integer col = columnLabelBindings.get(columnLabel); + if (row == null || col == null) { + throw new IllegalStateException("Unbound label"); + } + set(row, col, value); + } + + @Override + public void set(String rowLabel, String columnLabel, int row, int column, double value) { + if (rowLabelBindings == null) { + rowLabelBindings = new HashMap<>(); + } + rowLabelBindings.put(rowLabel, row); + if (columnLabelBindings == null) { + columnLabelBindings = new HashMap<>(); + } + columnLabelBindings.put(columnLabel, column); 
+ + set(row, column, value); + } + + @Override + public void setColumnLabelBindings(Map<String, Integer> bindings) { + columnLabelBindings = bindings; + } + + @Override + public void setRowLabelBindings(Map<String, Integer> bindings) { + rowLabelBindings = bindings; + } + + // index into int[2] for column value + public static final int COL = 1; + + // index into int[2] for row value + public static final int ROW = 0; + + @Override + public int numRows() { + return rowSize(); + } + + @Override + public int numCols() { + return columnSize(); + } + + @Override + public String asFormatString() { + return toString(); + } + + @Override + public Matrix assign(double value) { + int rows = rowSize(); + int columns = columnSize(); + for (int row = 0; row < rows; row++) { + for (int col = 0; col < columns; col++) { + setQuick(row, col, value); + } + } + return this; + } + + @Override + public Matrix assign(double[][] values) { + int rows = rowSize(); + if (rows != values.length) { + throw new CardinalityException(rows, values.length); + } + int columns = columnSize(); + for (int row = 0; row < rows; row++) { + if (columns == values[row].length) { + for (int col = 0; col < columns; col++) { + setQuick(row, col, values[row][col]); + } + } else { + throw new CardinalityException(columns, values[row].length); + } + } + return this; + } + + @Override + public Matrix assign(Matrix other, DoubleDoubleFunction function) { + int rows = rowSize(); + if (rows != other.rowSize()) { + throw new CardinalityException(rows, other.rowSize()); + } + int columns = columnSize(); + if (columns != other.columnSize()) { + throw new CardinalityException(columns, other.columnSize()); + } + for (int row = 0; row < rows; row++) { + for (int col = 0; col < columns; col++) { + setQuick(row, col, function.apply(getQuick(row, col), other.getQuick( + row, col))); + } + } + return this; + } + + @Override + public Matrix assign(Matrix other) { + int rows = rowSize(); + if (rows != other.rowSize()) { + throw 
new CardinalityException(rows, other.rowSize()); + } + int columns = columnSize(); + if (columns != other.columnSize()) { + throw new CardinalityException(columns, other.columnSize()); + } + for (int row = 0; row < rows; row++) { + for (int col = 0; col < columns; col++) { + setQuick(row, col, other.getQuick(row, col)); + } + } + return this; + } + + @Override + public Matrix assign(DoubleFunction function) { + int rows = rowSize(); + int columns = columnSize(); + for (int row = 0; row < rows; row++) { + for (int col = 0; col < columns; col++) { + setQuick(row, col, function.apply(getQuick(row, col))); + } + } + return this; + } + + /** + * Collects the results of a function applied to each row of a matrix. + * + * @param f The function to be applied to each row. + * @return The vector of results. + */ + @Override + public Vector aggregateRows(VectorFunction f) { + Vector r = new DenseVector(numRows()); + int n = numRows(); + for (int row = 0; row < n; row++) { + r.set(row, f.apply(viewRow(row))); + } + return r; + } + + /** + * Returns a view of a row. Changes to the view will affect the original. + * + * @param row Which row to return. + * @return A vector that references the desired row. + */ + @Override + public Vector viewRow(int row) { + return new MatrixVectorView(this, row, 0, 0, 1); + } + + + /** + * Returns a view of a row. Changes to the view will affect the original. + * + * @param column Which column to return. + * @return A vector that references the desired column. + */ + @Override + public Vector viewColumn(int column) { + return new MatrixVectorView(this, 0, column, 1, 0); + } + + /** + * Provides a view of the diagonal of a matrix. + */ + @Override + public Vector viewDiagonal() { + return new MatrixVectorView(this, 0, 0, 1, 1); + } + + /** + * Collects the results of a function applied to each element of a matrix and then aggregated. + * + * @param combiner A function that combines the results of the mapper. 
+ * @param mapper A function to apply to each element. + * @return The result. + */ + @Override + public double aggregate(final DoubleDoubleFunction combiner, final DoubleFunction mapper) { + return aggregateRows(new VectorFunction() { + @Override + public double apply(Vector v) { + return v.aggregate(combiner, mapper); + } + }).aggregate(combiner, Functions.IDENTITY); + } + + /** + * Collects the results of a function applied to each column of a matrix. + * + * @param f The function to be applied to each column. + * @return The vector of results. + */ + @Override + public Vector aggregateColumns(VectorFunction f) { + Vector r = new DenseVector(numCols()); + for (int col = 0; col < numCols(); col++) { + r.set(col, f.apply(viewColumn(col))); + } + return r; + } + + @Override + public double determinant() { + int rows = rowSize(); + int columns = columnSize(); + if (rows != columns) { + throw new CardinalityException(rows, columns); + } + + if (rows == 2) { + return getQuick(0, 0) * getQuick(1, 1) - getQuick(0, 1) * getQuick(1, 0); + } else { + // TODO: this really should just be one line: + // TODO: new CholeskyDecomposition(this).getL().viewDiagonal().aggregate(Functions.TIMES) + int sign = 1; + double ret = 0; + + for (int i = 0; i < columns; i++) { + Matrix minor = new DenseMatrix(rows - 1, columns - 1); + for (int j = 1; j < rows; j++) { + boolean flag = false; /* column offset flag */ + for (int k = 0; k < columns; k++) { + if (k == i) { + flag = true; + continue; + } + minor.set(j - 1, flag ? 
k - 1 : k, getQuick(j, k)); + } + } + ret += getQuick(0, i) * sign * minor.determinant(); + sign *= -1; + + } + + return ret; + } + + } + + @SuppressWarnings("CloneDoesntDeclareCloneNotSupportedException") + @Override + public Matrix clone() { + AbstractMatrix clone; + try { + clone = (AbstractMatrix) super.clone(); + } catch (CloneNotSupportedException cnse) { + throw new IllegalStateException(cnse); // can't happen + } + if (rowLabelBindings != null) { + clone.rowLabelBindings = Maps.newHashMap(rowLabelBindings); + } + if (columnLabelBindings != null) { + clone.columnLabelBindings = Maps.newHashMap(columnLabelBindings); + } + return clone; + } + + @Override + public Matrix divide(double x) { + Matrix result = like(); + for (int row = 0; row < rowSize(); row++) { + for (int col = 0; col < columnSize(); col++) { + result.setQuick(row, col, getQuick(row, col) / x); + } + } + return result; + } + + @Override + public double get(int row, int column) { + if (row < 0 || row >= rowSize()) { + throw new IndexException(row, rowSize()); + } + if (column < 0 || column >= columnSize()) { + throw new IndexException(column, columnSize()); + } + return getQuick(row, column); + } + + @Override + public Matrix minus(Matrix other) { + int rows = rowSize(); + if (rows != other.rowSize()) { + throw new CardinalityException(rows, other.rowSize()); + } + int columns = columnSize(); + if (columns != other.columnSize()) { + throw new CardinalityException(columns, other.columnSize()); + } + Matrix result = like(); + for (int row = 0; row < rows; row++) { + for (int col = 0; col < columns; col++) { + result.setQuick(row, col, getQuick(row, col) + - other.getQuick(row, col)); + } + } + return result; + } + + @Override + public Matrix plus(double x) { + Matrix result = like(); + int rows = rowSize(); + int columns = columnSize(); + for (int row = 0; row < rows; row++) { + for (int col = 0; col < columns; col++) { + result.setQuick(row, col, getQuick(row, col) + x); + } + } + return result; + 
} + + @Override + public Matrix plus(Matrix other) { + int rows = rowSize(); + if (rows != other.rowSize()) { + throw new CardinalityException(rows, other.rowSize()); + } + int columns = columnSize(); + if (columns != other.columnSize()) { + throw new CardinalityException(columns, other.columnSize()); + } + Matrix result = like(); + for (int row = 0; row < rows; row++) { + for (int col = 0; col < columns; col++) { + result.setQuick(row, col, getQuick(row, col) + + other.getQuick(row, col)); + } + } + return result; + } + + @Override + public void set(int row, int column, double value) { + if (row < 0 || row >= rowSize()) { + throw new IndexException(row, rowSize()); + } + if (column < 0 || column >= columnSize()) { + throw new IndexException(column, columnSize()); + } + setQuick(row, column, value); + } + + @Override + public void set(int row, double[] data) { + int columns = columnSize(); + if (columns < data.length) { + throw new CardinalityException(columns, data.length); + } + int rows = rowSize(); + if (row < 0 || row >= rows) { + throw new IndexException(row, rowSize()); + } + for (int i = 0; i < columns; i++) { + setQuick(row, i, data[i]); + } + } + + @Override + public Matrix times(double x) { + Matrix result = like(); + int rows = rowSize(); + int columns = columnSize(); + for (int row = 0; row < rows; row++) { + for (int col = 0; col < columns; col++) { + result.setQuick(row, col, getQuick(row, col) * x); + } + } + return result; + } + + @Override + public Matrix times(Matrix other) { + int columns = columnSize(); + if (columns != other.rowSize()) { + throw new CardinalityException(columns, other.rowSize()); + } + int rows = rowSize(); + int otherColumns = other.columnSize(); + Matrix result = like(rows, otherColumns); + for (int row = 0; row < rows; row++) { + for (int col = 0; col < otherColumns; col++) { + double sum = 0.0; + for (int k = 0; k < columns; k++) { + sum += getQuick(row, k) * other.getQuick(k, col); + } + result.setQuick(row, col, sum); + 
} + } + return result; + } + + @Override + public Vector times(Vector v) { + int columns = columnSize(); + if (columns != v.size()) { + throw new CardinalityException(columns, v.size()); + } + int rows = rowSize(); + Vector w = new DenseVector(rows); + for (int row = 0; row < rows; row++) { + w.setQuick(row, v.dot(viewRow(row))); + } + return w; + } + + @Override + public Vector timesSquared(Vector v) { + int columns = columnSize(); + if (columns != v.size()) { + throw new CardinalityException(columns, v.size()); + } + int rows = rowSize(); + Vector w = new DenseVector(columns); + for (int i = 0; i < rows; i++) { + Vector xi = viewRow(i); + double d = xi.dot(v); + if (d != 0.0) { + w.assign(xi, new PlusMult(d)); + } + + } + return w; + } + + @Override + public Matrix transpose() { + int rows = rowSize(); + int columns = columnSize(); + Matrix result = like(columns, rows); + for (int row = 0; row < rows; row++) { + for (int col = 0; col < columns; col++) { + result.setQuick(col, row, getQuick(row, col)); + } + } + return result; + } + + @Override + public Matrix viewPart(int rowOffset, int rowsRequested, int columnOffset, int columnsRequested) { + return viewPart(new int[]{rowOffset, columnOffset}, new int[]{rowsRequested, columnsRequested}); + } + + @Override + public Matrix viewPart(int[] offset, int[] size) { + + if (offset[ROW] < 0) { + throw new IndexException(offset[ROW], 0); + } + if (offset[ROW] + size[ROW] > rowSize()) { + throw new IndexException(offset[ROW] + size[ROW], rowSize()); + } + if (offset[COL] < 0) { + throw new IndexException(offset[COL], 0); + } + if (offset[COL] + size[COL] > columnSize()) { + throw new IndexException(offset[COL] + size[COL], columnSize()); + } + + return new MatrixView(this, offset, size); + } + + + @Override + public double zSum() { + double result = 0; + for (int row = 0; row < rowSize(); row++) { + for (int col = 0; col < columnSize(); col++) { + result += getQuick(row, col); + } + } + return result; + } + + @Override + 
public int[] getNumNondefaultElements() { + return new int[]{rowSize(), columnSize()}; + } + + protected static class TransposeViewVector extends AbstractVector { + + private final Matrix matrix; + private final int transposeOffset; + private final int numCols; + private final boolean rowToColumn; + + protected TransposeViewVector(Matrix m, int offset) { + this(m, offset, true); + } + + protected TransposeViewVector(Matrix m, int offset, boolean rowToColumn) { + super(rowToColumn ? m.numRows() : m.numCols()); + matrix = m; + this.transposeOffset = offset; + this.rowToColumn = rowToColumn; + numCols = rowToColumn ? m.numCols() : m.numRows(); + } + + @SuppressWarnings("CloneDoesntCallSuperClone") + @Override + public Vector clone() { + Vector v = new DenseVector(size()); + v.assign(this, Functions.PLUS); + return v; + } + + @Override + public boolean isDense() { + return true; + } + + @Override + public boolean isSequentialAccess() { + return true; + } + + @Override + protected Matrix matrixLike(int rows, int columns) { + return matrix.like(rows, columns); + } + + @Override + public Iterator<Element> iterator() { + return new AbstractIterator<Element>() { + private int i; + + @Override + protected Element computeNext() { + if (i >= size()) { + return endOfData(); + } + return getElement(i++); + } + }; + } + + /** + * Currently delegates to {@link #iterator()}. + * TODO: This could be optimized to at least skip empty rows if there are many of them. + * + * @return an iterator (currently dense). + */ + @Override + public Iterator<Element> iterateNonZero() { + return iterator(); + } + + @Override + public Element getElement(final int i) { + return new Element() { + @Override + public double get() { + return getQuick(i); + } + + @Override + public int index() { + return i; + } + + @Override + public void set(double value) { + setQuick(i, value); + } + }; + } + + /** + * Used internally by assign() to update multiple indices and values at once. 
+ * Only really useful for sparse vectors (especially SequentialAccessSparseVector). + * <p> + * If someone ever adds a new type of sparse vectors, this method must merge (index, value) pairs into the vector. + * + * @param updates a mapping of indices to values to merge in the vector. + */ + @Override + public void mergeUpdates(OrderedIntDoubleMapping updates) { + throw new UnsupportedOperationException("Cannot mutate TransposeViewVector"); + } + + @Override + public double getQuick(int index) { + Vector v = rowToColumn ? matrix.viewColumn(index) : matrix.viewRow(index); + return v == null ? 0.0 : v.getQuick(transposeOffset); + } + + @Override + public void setQuick(int index, double value) { + Vector v = rowToColumn ? matrix.viewColumn(index) : matrix.viewRow(index); + if (v == null) { + v = newVector(numCols); + if (rowToColumn) { + matrix.assignColumn(index, v); + } else { + matrix.assignRow(index, v); + } + } + v.setQuick(transposeOffset, value); + } + + protected Vector newVector(int cardinality) { + return new DenseVector(cardinality); + } + + @Override + public Vector like() { + return new DenseVector(size()); + } + + public Vector like(int cardinality) { + return new DenseVector(cardinality); + } + + /** + * TODO: currently I don't know of an efficient way to getVector this value correctly. + * + * @return the number of nonzero entries + */ + @Override + public int getNumNondefaultElements() { + return size(); + } + + @Override + public double getLookupCost() { + return (rowToColumn ? matrix.viewColumn(0) : matrix.viewRow(0)).getLookupCost(); + } + + @Override + public double getIteratorAdvanceCost() { + return (rowToColumn ? matrix.viewColumn(0) : matrix.viewRow(0)).getIteratorAdvanceCost(); + } + + @Override + public boolean isAddConstantTime() { + return (rowToColumn ? 
matrix.viewColumn(0) : matrix.viewRow(0)).isAddConstantTime(); + } + } + + @Override + public String toString() { + int row = 0; + int maxRowsToDisplay = 10; + int maxColsToDisplay = 20; + int colsToDisplay = maxColsToDisplay; + + if(maxColsToDisplay > columnSize()){ + colsToDisplay = columnSize(); + } + + + StringBuilder s = new StringBuilder("{\n"); + Iterator<MatrixSlice> it = iterator(); + while ((it.hasNext()) && (row < maxRowsToDisplay)) { + MatrixSlice next = it.next(); + s.append(" ").append(next.index()) + .append(" =>\t") + .append(new VectorView(next.vector(), 0, colsToDisplay)) + .append('\n'); + row ++; + } + String returnString = s.toString(); + if (maxColsToDisplay <= columnSize()) { + returnString = returnString.replace("}", " ... } "); + } + if(maxRowsToDisplay <= rowSize()) + return returnString + ("... }"); + else{ + return returnString + ("}"); + } + } + + @Override + public MatrixFlavor getFlavor() { + throw new UnsupportedOperationException("Flavor support not implemented for this matrix."); + } + + ////////////// Matrix flavor trait /////////////////// + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/AbstractVector.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/AbstractVector.java b/core/src/main/java/org/apache/mahout/math/AbstractVector.java new file mode 100644 index 0000000..27eddbc --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/AbstractVector.java @@ -0,0 +1,684 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math; + +import java.util.Iterator; + +import com.google.common.base.Preconditions; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.function.DoubleDoubleFunction; +import org.apache.mahout.math.function.DoubleFunction; +import org.apache.mahout.math.function.Functions; + +/** Implementations of generic capabilities like sum of elements and dot products */ +public abstract class AbstractVector implements Vector, LengthCachingVector { + + private int size; + protected double lengthSquared = -1.0; + + protected AbstractVector(int size) { + this.size = size; + } + + @Override + public Iterable<Element> all() { + return new Iterable<Element>() { + @Override + public Iterator<Element> iterator() { + return AbstractVector.this.iterator(); + } + }; + } + + @Override + public Iterable<Element> nonZeroes() { + return new Iterable<Element>() { + @Override + public Iterator<Element> iterator() { + return iterateNonZero(); + } + }; + } + + /** + * Iterates over all elements <p> + * NOTE: Implementations may choose to reuse the Element returned for performance + * reasons, so if you need a copy of it, you should call {@link #getElement(int)} for the given index + * + * @return An {@link Iterator} over all elements + */ + protected abstract Iterator<Element> iterator(); + + /** + * Iterates over all non-zero elements. 
<p> + * NOTE: Implementations may choose to reuse the Element returned for + * performance reasons, so if you need a copy of it, you should call {@link #getElement(int)} for the given index + * + * @return An {@link Iterator} over all non-zero elements + */ + protected abstract Iterator<Element> iterateNonZero(); + /** + * Aggregates a vector by applying a mapping function fm(x) to every component and aggregating + * the results with an aggregating function fa(x, y). + * + * @param aggregator used to combine the current value of the aggregation with the result of map.apply(nextValue) + * @param map a function to apply to each element of the vector in turn before passing to the aggregator + * @return the result of the aggregation + */ + @Override + public double aggregate(DoubleDoubleFunction aggregator, DoubleFunction map) { + if (size == 0) { + return 0; + } + + // If the aggregator is associative and commutative and it's likeLeftMult (fa(0, y) = 0), and there is + // at least one zero in the vector (size > getNumNondefaultElements) and applying fm(0) = 0, the result + // gets cascaded through the aggregation and the final result will be 0. + if (aggregator.isAssociativeAndCommutative() && aggregator.isLikeLeftMult() + && size > getNumNondefaultElements() && !map.isDensifying()) { + return 0; + } + + double result; + if (isSequentialAccess() || aggregator.isAssociativeAndCommutative()) { + Iterator<Element> iterator; + // If fm(0) = 0 and fa(x, 0) = x, we can skip all zero values. 
+ if (!map.isDensifying() && aggregator.isLikeRightPlus()) { + iterator = iterateNonZero(); + if (!iterator.hasNext()) { + return 0; + } + } else { + iterator = iterator(); + } + Element element = iterator.next(); + result = map.apply(element.get()); + while (iterator.hasNext()) { + element = iterator.next(); + result = aggregator.apply(result, map.apply(element.get())); + } + } else { + result = map.apply(getQuick(0)); + for (int i = 1; i < size; i++) { + result = aggregator.apply(result, map.apply(getQuick(i))); + } + } + + return result; + } + + @Override + public double aggregate(Vector other, DoubleDoubleFunction aggregator, DoubleDoubleFunction combiner) { + Preconditions.checkArgument(size == other.size(), "Vector sizes differ"); + if (size == 0) { + return 0; + } + return VectorBinaryAggregate.aggregateBest(this, other, aggregator, combiner); + } + + /** + * Subclasses must override to return an appropriately sparse or dense result + * + * @param rows the row cardinality + * @param columns the column cardinality + * @return a Matrix + */ + protected abstract Matrix matrixLike(int rows, int columns); + + @Override + public Vector viewPart(int offset, int length) { + if (offset < 0) { + throw new IndexException(offset, size); + } + if (offset + length > size) { + throw new IndexException(offset + length, size); + } + return new VectorView(this, offset, length); + } + + @SuppressWarnings("CloneDoesntDeclareCloneNotSupportedException") + @Override + public Vector clone() { + try { + AbstractVector r = (AbstractVector) super.clone(); + r.size = size; + r.lengthSquared = lengthSquared; + return r; + } catch (CloneNotSupportedException e) { + throw new IllegalStateException("Can't happen"); + } + } + + @Override + public Vector divide(double x) { + if (x == 1.0) { + return clone(); + } + Vector result = createOptimizedCopy(); + for (Element element : result.nonZeroes()) { + element.set(element.get() / x); + } + return result; + } + + @Override + public double 
dot(Vector x) { + if (size != x.size()) { + throw new CardinalityException(size, x.size()); + } + if (this == x) { + return getLengthSquared(); + } + return aggregate(x, Functions.PLUS, Functions.MULT); + } + + protected double dotSelf() { + return aggregate(Functions.PLUS, Functions.pow(2)); + } + + @Override + public double get(int index) { + if (index < 0 || index >= size) { + throw new IndexException(index, size); + } + return getQuick(index); + } + + @Override + public Element getElement(int index) { + return new LocalElement(index); + } + + @Override + public Vector normalize() { + return divide(Math.sqrt(getLengthSquared())); + } + + @Override + public Vector normalize(double power) { + return divide(norm(power)); + } + + @Override + public Vector logNormalize() { + return logNormalize(2.0, Math.sqrt(getLengthSquared())); + } + + @Override + public Vector logNormalize(double power) { + return logNormalize(power, norm(power)); + } + + public Vector logNormalize(double power, double normLength) { + // we can special case certain powers + if (Double.isInfinite(power) || power <= 1.0) { + throw new IllegalArgumentException("Power must be > 1 and < infinity"); + } else { + double denominator = normLength * Math.log(power); + Vector result = createOptimizedCopy(); + for (Element element : result.nonZeroes()) { + element.set(Math.log1p(element.get()) / denominator); + } + return result; + } + } + + @Override + public double norm(double power) { + if (power < 0.0) { + throw new IllegalArgumentException("Power must be >= 0"); + } + // We can special case certain powers. 
+ if (Double.isInfinite(power)) { + return aggregate(Functions.MAX, Functions.ABS); + } else if (power == 2.0) { + return Math.sqrt(getLengthSquared()); + } else if (power == 1.0) { + double result = 0.0; + Iterator<Element> iterator = this.iterateNonZero(); + while (iterator.hasNext()) { + result += Math.abs(iterator.next().get()); + } + return result; + // TODO: this should ideally be used, but it's slower. + // return aggregate(Functions.PLUS, Functions.ABS); + } else if (power == 0.0) { + return getNumNonZeroElements(); + } else { + return Math.pow(aggregate(Functions.PLUS, Functions.pow(power)), 1.0 / power); + } + } + + @Override + public double getLengthSquared() { + if (lengthSquared >= 0.0) { + return lengthSquared; + } + return lengthSquared = dotSelf(); + } + + @Override + public void invalidateCachedLength() { + lengthSquared = -1; + } + + @Override + public double getDistanceSquared(Vector that) { + if (size != that.size()) { + throw new CardinalityException(size, that.size()); + } + double thisLength = getLengthSquared(); + double thatLength = that.getLengthSquared(); + double dot = dot(that); + double distanceEstimate = thisLength + thatLength - 2 * dot; + if (distanceEstimate > 1.0e-3 * (thisLength + thatLength)) { + // The vectors are far enough from each other that the formula is accurate. 
+ return Math.max(distanceEstimate, 0); + } else { + return aggregate(that, Functions.PLUS, Functions.MINUS_SQUARED); + } + } + + @Override + public double maxValue() { + if (size == 0) { + return Double.NEGATIVE_INFINITY; + } + return aggregate(Functions.MAX, Functions.IDENTITY); + } + + @Override + public int maxValueIndex() { + int result = -1; + double max = Double.NEGATIVE_INFINITY; + int nonZeroElements = 0; + Iterator<Element> iter = this.iterateNonZero(); + while (iter.hasNext()) { + nonZeroElements++; + Element element = iter.next(); + double tmp = element.get(); + if (tmp > max) { + max = tmp; + result = element.index(); + } + } + // if the maxElement is negative and the vector is sparse then any + // unfilled element(0.0) could be the maxValue hence we need to + // find one of those elements + if (nonZeroElements < size && max < 0.0) { + for (Element element : all()) { + if (element.get() == 0.0) { + return element.index(); + } + } + } + return result; + } + + @Override + public double minValue() { + if (size == 0) { + return Double.POSITIVE_INFINITY; + } + return aggregate(Functions.MIN, Functions.IDENTITY); + } + + @Override + public int minValueIndex() { + int result = -1; + double min = Double.POSITIVE_INFINITY; + int nonZeroElements = 0; + Iterator<Element> iter = this.iterateNonZero(); + while (iter.hasNext()) { + nonZeroElements++; + Element element = iter.next(); + double tmp = element.get(); + if (tmp < min) { + min = tmp; + result = element.index(); + } + } + // if the maxElement is positive and the vector is sparse then any + // unfilled element(0.0) could be the maxValue hence we need to + // find one of those elements + if (nonZeroElements < size && min > 0.0) { + for (Element element : all()) { + if (element.get() == 0.0) { + return element.index(); + } + } + } + return result; + } + + @Override + public Vector plus(double x) { + Vector result = createOptimizedCopy(); + if (x == 0.0) { + return result; + } + return 
result.assign(Functions.plus(x)); + } + + @Override + public Vector plus(Vector that) { + if (size != that.size()) { + throw new CardinalityException(size, that.size()); + } + return createOptimizedCopy().assign(that, Functions.PLUS); + } + + @Override + public Vector minus(Vector that) { + if (size != that.size()) { + throw new CardinalityException(size, that.size()); + } + return createOptimizedCopy().assign(that, Functions.MINUS); + } + + @Override + public void set(int index, double value) { + if (index < 0 || index >= size) { + throw new IndexException(index, size); + } + setQuick(index, value); + } + + @Override + public void incrementQuick(int index, double increment) { + setQuick(index, getQuick(index) + increment); + } + + @Override + public Vector times(double x) { + if (x == 0.0) { + return like(); + } + return createOptimizedCopy().assign(Functions.mult(x)); + } + + /** + * Copy the current vector in the most optimum fashion. Used by immutable methods like plus(), minus(). + * Use this instead of vector.like().assign(vector). Sub-class can choose to override this method. + * + * @return a copy of the current vector. 
+ */ + protected Vector createOptimizedCopy() { + return createOptimizedCopy(this); + } + + private static Vector createOptimizedCopy(Vector vector) { + Vector result; + if (vector.isDense()) { + result = vector.like().assign(vector, Functions.SECOND_LEFT_ZERO); + } else { + result = vector.clone(); + } + return result; + } + + @Override + public Vector times(Vector that) { + if (size != that.size()) { + throw new CardinalityException(size, that.size()); + } + + if (this.getNumNondefaultElements() <= that.getNumNondefaultElements()) { + return createOptimizedCopy(this).assign(that, Functions.MULT); + } else { + return createOptimizedCopy(that).assign(this, Functions.MULT); + } + } + + @Override + public double zSum() { + return aggregate(Functions.PLUS, Functions.IDENTITY); + } + + @Override + public int getNumNonZeroElements() { + int count = 0; + Iterator<Element> it = iterateNonZero(); + while (it.hasNext()) { + if (it.next().get() != 0.0) { + count++; + } + } + return count; + } + + @Override + public Vector assign(double value) { + Iterator<Element> it; + if (value == 0.0) { + // Make all the non-zero values 0. + it = iterateNonZero(); + while (it.hasNext()) { + it.next().set(value); + } + } else { + if (isSequentialAccess() && !isAddConstantTime()) { + // Update all the non-zero values and queue the updates for the zero vaues. + // The vector will become dense. 
+ it = iterator(); + OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping(); + while (it.hasNext()) { + Element element = it.next(); + if (element.get() == 0.0) { + updates.set(element.index(), value); + } else { + element.set(value); + } + } + mergeUpdates(updates); + } else { + for (int i = 0; i < size; ++i) { + setQuick(i, value); + } + } + } + invalidateCachedLength(); + return this; + } + + @Override + public Vector assign(double[] values) { + if (size != values.length) { + throw new CardinalityException(size, values.length); + } + if (isSequentialAccess() && !isAddConstantTime()) { + OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping(); + Iterator<Element> it = iterator(); + while (it.hasNext()) { + Element element = it.next(); + int index = element.index(); + if (element.get() == 0.0) { + updates.set(index, values[index]); + } else { + element.set(values[index]); + } + } + mergeUpdates(updates); + } else { + for (int i = 0; i < size; ++i) { + setQuick(i, values[i]); + } + } + invalidateCachedLength(); + return this; + } + + @Override + public Vector assign(Vector other) { + return assign(other, Functions.SECOND); + } + + @Override + public Vector assign(DoubleDoubleFunction f, double y) { + Iterator<Element> iterator = f.apply(0, y) == 0 ? iterateNonZero() : iterator(); + while (iterator.hasNext()) { + Element element = iterator.next(); + element.set(f.apply(element.get(), y)); + } + invalidateCachedLength(); + return this; + } + + @Override + public Vector assign(DoubleFunction f) { + Iterator<Element> iterator = !f.isDensifying() ? 
iterateNonZero() : iterator(); + while (iterator.hasNext()) { + Element element = iterator.next(); + element.set(f.apply(element.get())); + } + invalidateCachedLength(); + return this; + } + + @Override + public Vector assign(Vector other, DoubleDoubleFunction function) { + if (size != other.size()) { + throw new CardinalityException(size, other.size()); + } + VectorBinaryAssign.assignBest(this, other, function); + invalidateCachedLength(); + return this; + } + + @Override + public Matrix cross(Vector other) { + Matrix result = matrixLike(size, other.size()); + Iterator<Vector.Element> it = iterateNonZero(); + while (it.hasNext()) { + Vector.Element e = it.next(); + int row = e.index(); + result.assignRow(row, other.times(getQuick(row))); + } + return result; + } + + @Override + public final int size() { + return size; + } + + @Override + public String asFormatString() { + return toString(); + } + + @Override + public int hashCode() { + int result = size; + Iterator<Element> iter = iterateNonZero(); + while (iter.hasNext()) { + Element ele = iter.next(); + result += ele.index() * RandomUtils.hashDouble(ele.get()); + } + return result; + } + + /** + * Determines whether this {@link Vector} represents the same logical vector as another + * object. Two {@link Vector}s are equal (regardless of implementation) if the value at + * each index is the same, and the cardinalities are the same. 
+ */ + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Vector)) { + return false; + } + Vector that = (Vector) o; + return size == that.size() && aggregate(that, Functions.PLUS, Functions.MINUS_ABS) == 0.0; + } + + @Override + public String toString() { + return toString(null); + } + + public String toString(String[] dictionary) { + StringBuilder result = new StringBuilder(); + result.append('{'); + for (int index = 0; index < size; index++) { + double value = getQuick(index); + if (value != 0.0) { + result.append(dictionary != null && dictionary.length > index ? dictionary[index] : index); + result.append(':'); + result.append(value); + result.append(','); + } + } + if (result.length() > 1) { + result.setCharAt(result.length() - 1, '}'); + } else { + result.append('}'); + } + return result.toString(); + } + + /** + * toString() implementation for sparse vectors via {@link #nonZeroes()} method + * @return String representation of the vector + */ + public String sparseVectorToString() { + Iterator<Element> it = iterateNonZero(); + if (!it.hasNext()) { + return "{}"; + } + else { + StringBuilder result = new StringBuilder(); + result.append('{'); + while (it.hasNext()) { + Vector.Element e = it.next(); + result.append(e.index()); + result.append(':'); + result.append(e.get()); + result.append(','); + } + result.setCharAt(result.length() - 1, '}'); + return result.toString(); + } + } + + protected final class LocalElement implements Element { + int index; + + LocalElement(int index) { + this.index = index; + } + + @Override + public double get() { + return getQuick(index); + } + + @Override + public int index() { + return index; + } + + @Override + public void set(double value) { + setQuick(index, value); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/Algebra.java ---------------------------------------------------------------------- diff --git 
a/core/src/main/java/org/apache/mahout/math/Algebra.java b/core/src/main/java/org/apache/mahout/math/Algebra.java new file mode 100644 index 0000000..3049057 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/Algebra.java @@ -0,0 +1,73 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math;

/**
 * Static helper routines over {@link Matrix} and {@link Vector}. The class cannot be
 * instantiated.
 */
public final class Algebra {

  private Algebra() {
  }

  /**
   * Computes the matrix-vector product {@code m * v}.
   *
   * @param m the matrix
   * @param v the vector; its size must equal the number of columns of {@code m}
   * @return a new {@link DenseVector} with one entry per row of {@code m}, entry {@code i} being
   *         {@code m.viewRow(i).dot(v)}
   * @throws CardinalityException if {@code v.size()} differs from {@code m.numCols()}
   */
  public static Vector mult(Matrix m, Vector v) {
    // Each result entry is a row of m dotted with v, so v has to match the column count
    // (not the row count) for the product to be defined on non-square matrices.
    if (m.numCols() != v.size()) {
      throw new CardinalityException(m.numCols(), v.size());
    }
    // Use a Dense Vector for the moment,
    int rows = m.numRows();
    Vector result = new DenseVector(rows);

    for (int i = 0; i < rows; i++) {
      result.set(i, m.viewRow(i).dot(v));
    }

    return result;
  }

  /**
   * Returns sqrt(a^2 + b^2) without under/overflow.
   *
   * <p>The larger magnitude is factored out so that only the ratio of the two arguments is
   * squared.
   *
   * @param a first value
   * @param b second value
   * @return the computed length; 0.0 when the first branch is not taken and {@code b} is zero
   */
  public static double hypot(double a, double b) {
    double absA = Math.abs(a);
    double absB = Math.abs(b);
    if (absA > absB) {
      double ratio = b / a;
      return absA * Math.sqrt(1 + ratio * ratio);
    }
    if (b != 0) {
      double ratio = a / b;
      return absB * Math.sqrt(1 + ratio * ratio);
    }
    return 0.0;
  }

  /**
   * Compute Maximum Absolute Row Sum Norm of input Matrix m
   * http://mathworld.wolfram.com/MaximumAbsoluteRowSumNorm.html
   *
   * @param m the matrix
   * @return the largest sum of absolute values found in any row; 0.0 for a matrix without rows
   */
  public static double getNorm(Matrix m) {
    double max = 0.0;
    for (int i = 0; i < m.numRows(); i++) {
      // Accumulate in double: an int accumulator would truncate fractional magnitudes
      // (a row such as [0.5, 0.4] would sum to 0) and could overflow on large entries.
      double sum = 0.0;
      Vector row = m.viewRow(i);
      for (int j = 0; j < row.size(); j++) {
        sum += Math.abs(row.getQuick(j));
      }
      if (sum > max) {
        max = sum;
      }
    }
    return max;
  }

}
