http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/ArtificialRegressionDatasets.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/ArtificialRegressionDatasets.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/ArtificialRegressionDatasets.java deleted file mode 100644 index ed6bf36..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/ArtificialRegressionDatasets.java +++ /dev/null @@ -1,404 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions.linear; - -/** - * Artificial regression datasets to be used in regression trainers tests. These datasets were generated by scikit-learn - * tools, {@code sklearn.datasets.make_regression} procedure. - */ -public class ArtificialRegressionDatasets { - /** - * Artificial dataset with 10 observations described by 1 feature. - */ - public static final TestDataset regression10x1 = new TestDataset(new double[][] { - {1.97657990214, 0.197725444973}, - {-5.0835948878, -0.279921224228}, - {-5.09032600779, -0.352291245969}, - {9.67660993007, 0.755464872441}, - {4.95927629958, 0.451981771462}, - {29.2635107429, 2.2277440173}, - {-18.3122588459, -1.25363275369}, - {-3.61729307199, -0.273362913982}, - {-7.19042139249, -0.473846634967}, - {3.68008403347, 0.353883097536} - }, new double[] {13.554054703}, -0.808655936776); - - /** - * Artificial dataset with 10 observations described by 5 features. - */ - public static final TestDataset regression10x5 = new TestDataset(new double[][] { - {118.635647237, 0.687593385888, -1.18956185502, -0.305420702986, 1.98794097418, -0.776629036361}, - {-18.2808432286, -0.165921853684, -0.156162539573, 1.56284391134, -0.198876782109, -0.0921618505605}, - {22.6110523992, 0.0268106268606, 0.702141470035, -0.41503615392, -1.09726502337, 1.30830482813}, - {209.820435262, 0.379809113402, -0.192097238579, -1.27460497119, 2.48052002019, -0.574430888865}, - {-253.750024054, -1.48044570917, -0.331747484523, 0.387993627712, 0.372583756237, -2.27404065923}, - {-24.6467766166, -0.66991474156, 0.269042238935, -0.271412703096, -0.561166818525, 1.37067541854}, - {-311.903650717, 0.268274438122, -1.10491275353, -1.06738703543, -2.24387799735, -0.207431467989}, - {74.2055323536, -0.329489531894, -0.493350762533, -0.644851462227, 0.661220945573, 1.65950140864}, - {57.0312289904, -1.07266578457, 0.80375035572, -0.45207210139, 1.69314420969, -1.10526080856}, - {12.149399645, 1.46504629281, -1.05843246079, 0.266225365277, -0.0113100353869, -0.983495425471} - }, new double[] {99.8393653561, 82.4948224094, 20.2087724072, 97.3306384162, 55.7502297387}, 3.98444039189); - - /** - * Artificial dataset with 100 observations described by 5 features. - */ - public static final TestDataset regression100x5 = new TestDataset(new double[][] { - {-44.2310642946, -0.0331360137605, -0.5290800706, -0.634340342338, -0.428433927151, 0.830582347183}, - {76.2539139721, -0.216200869652, 0.513212019048, -0.693404511747, 0.132995973133, 1.28470259833}, - {293.369799914, 2.90735870802, 0.457740818846, -0.490470696097, -0.442343455187, 0.584038258781}, - {124.258807314, 1.64158129148, 0.0616936820145, 1.24082841519, -1.20126518593, -0.542298907742}, - {13.6610807249, -1.10834821778, 0.545508208111, 1.81361288715, -0.786543112444, 0.250772626496}, - {101.924582305, -0.433526394969, 0.257594734335, 1.22333193911, 0.76626554927, -0.0400734567005}, - {25.5963186303, -0.202003301507, 0.717101151637, -0.486881225605, 1.15215024807, -0.921615554612}, - {75.7959681263, -0.604173187402, 0.0364386836472, 1.67544714536, 0.394743148877, 0.0237966550759}, - {-97.539357166, -0.774517689169, -0.0966902473883, -0.152250704254, -0.325472625458, 0.0720711851256}, - {0.394748999236, -0.559303402754, -0.0493339259273, -1.10840277768, -0.0800969523557, 1.80939282066}, - {-62.0138166431, 0.062614716778, -0.844143618016, 0.55269949861, -2.32580899335, 1.58020577369}, - {584.427692931, 2.13184767906, 1.22222461994, 1.71894070494, 2.69512281718, 0.294123497874}, - {-59.8323709765, 1.00006112818, -1.54481230765, -0.781282316493, 0.0255925284853, -0.0821173744608}, - {101.565711925, -0.38699836725, 1.06934591441, -0.260429311097, 1.02628949564, 0.0431473245174}, - {-141.592607814, 0.993279116267, -0.371768203378, -0.851483217286, -1.96241293548, -0.612279404296}, - {34.8038723379, -0.0182719243972, 0.306367604506, -0.650526589206, 1.30693112283, -0.587465952557}, - {-16.9554534069, -0.703006786668, -0.770718401931, 0.748423272307, 0.502544067819, 0.346625621533}, - {-76.2896177709, -0.16440174812, -1.77431555198, 0.195326723837, 2.01240994405, -1.19559207119}, - {-3.23827624818, -0.674138419631, -1.62238580284, 2.02235607862, 0.679194838679, 0.150203732584}, - {-21.962456854, -0.766271014206, 0.958599712131, -0.313045794728, 0.232655576106, -0.360950549871}, - {349.583669646, 1.75976166947, 1.47271612346, 0.0346005603489, 0.474907228495, 0.61379496381}, - {-418.397356757, -1.83395936566, -0.911702678716, -0.532478094882, -2.03835348133, -0.423005552518}, - {55.0298153952, -0.0301384716096, -0.0137929430966, -0.348583692759, 0.986486580719, 0.154436524434}, - {127.150063206, 1.92682560465, -0.434844790414, 0.1082898967, -0.00723338222402, -0.513199251824}, - {89.6172507626, 1.02463790902, 0.744369837717, 1.250323683, -1.58252612128, -0.588242778808}, - {92.5124829355, -0.403298547743, 0.0422774545428, -0.175000467434, 1.61110066857, 0.422330077287}, - {-303.040366788, 0.611569308879, -1.21926246291, -2.49250330276, -0.789166929605, -1.30166501196}, - {-17.4020602839, 1.72337202371, -1.83540537288, 0.731588761841, -0.338642535062, -1.11053518125}, - {114.918701324, 0.437385758628, 0.975885170381, 0.439444038872, 1.51666514156, -1.93095020264}, - {-8.43548064928, -0.799507968686, -0.00842968328782, -0.154994093964, 1.09169753491, -0.0114818657732}, - {109.209286025, 2.56472965015, -2.07047248035, -0.46764001177, 0.845267147375, -0.236767841427}, - {61.5259982971, -0.379391870148, -0.131017762354, -0.220275015864, 1.82097825699, -0.0568354876403}, - {-71.3872099588, 0.642138455414, -1.00242489879, 0.536780074488, 0.350977275771, -1.8204862883}, - {-21.2768078629, -0.454268998895, 0.0992324274219, 0.0363496803224, 0.281940751723, -0.198435570828}, - {-8.07838891387, -0.331642089041, -0.494067341253, 0.386035842816, -0.738221128298, 1.18236299649}, - {30.4818041751, 0.099206096537, 0.150688905006, 0.332932621949, 0.194845631964, -0.446717875795}, - {237.209150991, 1.12560447042, 0.448488431264, -0.724623711259, 0.401868257097, 1.67129001163}, - {185.172816475, 0.36594142556, -0.0796476435741, 0.473836257, 1.30890722633, 0.592415068693}, - {19.8830237044, 1.52497319332, 0.466906090264, -0.716635613964, -1.19532276745, -0.697663531684}, - {209.396793626, 0.368478789658, 0.699162303982, 1.96702434462, -0.815379139879, 0.863369634396}, - {-215.100514168, -1.83902416164, -1.14966820385, -1.01044860587, 1.76881340629, -0.32165916241}, - {-33.4687353426, -0.0451102002703, 0.642212950033, 0.580822065219, -1.02341504063, -0.781229325942}, - {150.251474823, 0.220170650298, 0.224858901011, 0.541299425328, 1.15151550963, 0.0329044069571}, - {92.2160506097, 1.86450932451, -0.991150940533, -1.49137866968, 1.02113774105, 0.0544762857136}, - {41.2138467595, -0.778892265105, 0.714957464344, 1.79833618993, -0.335322825621, -0.397548301803}, - {13.151262759, 0.301745607362, 0.129778280739, 0.260094818273, -0.10587841585, -0.599330307629}, - {-367.864703951, -1.68695981263, -0.611957677512, -0.0362971579679, -1.2169760515, -1.43224375134}, - {-57.218869838, 0.428806849751, 0.654302177028, -1.31651788496, 0.363857431276, -1.49953703016}, - {53.0877462955, -0.411907760185, -0.192634094071, -0.275879375023, 0.603562526571, 1.16508196734}, - {-8.11860742896, 1.00263982158, -0.157031169267, -1.11795623393, 0.35711440521, -0.851124640982}, - {-49.1878248403, -0.0253797866589, -0.574767070714, 0.200339045636, -0.0107042446803, -0.351288977927}, - {-73.8835407053, -2.07980276724, 1.12235566491, -0.917150593536, 0.741384768556, 0.56229424235}, - {143.163604045, 0.33627769945, 1.07948757447, 0.894869929963, 1.18688316974, -1.54722487849}, - {92.7045830908, 0.944091525689, 0.693296229491, 0.700097596814, -1.23666276942, -0.203890113084}, - {79.1878852355, -0.221973023853, -0.566066329011, 1.57683748648, 0.52854717911, 0.147924782476}, - {30.6547392801, -1.03466213359, 0.606784904328, -0.298096511956, 0.83332987683, 0.636339018254}, - {-329.128386019, -1.41363866598, -1.34966434823, -0.989010564149, 0.46889477248, -1.20493210784}, - {121.190205512, 0.0393914245697, 1.98392444232, -0.65310705226, -0.385899987099, 0.444982471471}, - {-97.0333075649, 0.264325871992, -0.43074811924, -1.14737761316, -0.453134140655, -0.038507405311}, - {158.273624516, 0.302255432981, -0.292046617818, 1.0704087606, 0.815965268115, 0.470631083546}, - {8.24795061818, -1.15155524496, 1.29538707184, -0.4650881541, 0.805123486308, -0.134706887329}, - {87.1140049059, -0.103540823781, -0.192259440773, 1.79648860085, -1.07525447993, 1.06985127941}, - {-25.1300772481, -0.97140742052, 0.033393948794, -0.698311192672, 0.74417168942, 0.752776770225}, - {-285.477057638, -0.480612406803, -1.46081500036, -1.92518386336, -0.426454066275, -0.0539099489597}, - {-65.1269988498, -1.22733468764, 0.121538452336, 0.752958777557, -0.40643211762, 0.257674949803}, - {-17.1813504942, 0.823753836891, 0.445142465255, 0.185644700144, -1.99733367514, -0.247899323048}, - {-46.7543447303, 0.183482778928, -0.934858705943, -1.21961947396, 0.460921844744, 0.571388077177}, - {-1.7536190499, -0.107517908181, 0.0334282610968, -0.556676121428, -0.485957577159, 0.943570398164}, - {-42.8460452689, 0.944999215632, 0.00530052154909, -0.348526283976, -1.724125354, -0.122649339813}, - {62.6291497267, 0.249619894002, 1.3139125969, -1.5644227783, 0.117605482783, 0.304844650662}, - {97.4552176343, 1.59332799639, -1.17868305562, 1.02998378902, -0.31959491258, -0.183038322076}, - {-6.19358885758, 0.437951016253, 0.373339269494, -0.204072768495, 0.477969349931, -1.52176449389}, - {34.0350630099, 0.839319087287, -0.610157662489, 1.73881448393, -1.89200107709, 0.204946415522}, - {54.9790822536, -0.191792583114, 0.989791127554, -0.502154080064, 0.469939512389, -0.102304071079}, - {58.8272402843, 0.0769623906454, 0.501297284297, -0.410054999243, 0.595712387781, -0.0968329050729}, - {95.3620983209, 0.0661481959314, 0.0935137309086, 1.11823292347, -0.612960777903, 0.767865072757}, - {62.4278196648, 0.78350610065, -1.09977017652, 0.526824784479, 1.41310104196, -0.887902707319}, - {57.6298676729, 0.60084172954, -0.785932027202, 0.0271301584637, -0.134109499719, 0.877256170191}, - {5.14112905382, -0.738359365006, 1.40242539359, -0.852833010305, -0.68365080837, 0.88561193696}, - {11.6057244034, -0.958911227571, 1.15715937023, 1.20108425431, 0.882980929338, -1.77404120156}, - {-265.758185272, -1.2092434823, -0.0550151798639, 0.00703735243613, -1.01767244359, -1.40616581707}, - {180.625928828, -0.139091127126, 0.243250756129, 2.17509702585, -0.541735827898, 1.2109459934}, - {-183.604103216, -0.324555097769, -1.71317286749, 1.03645005723, 0.497569347608, -1.96688185911}, - {9.93237328848, 0.825483591345, 0.910287997312, -1.64938108528, 0.98964075968, -1.65748940528}, - {-88.6846949813, -0.0759295112746, -0.593311990101, -0.578711915019, 0.256298822361, -0.429322890198}, - {175.367391479, 0.9361754906, -0.0172852897292, 1.04078658833, 0.919566407184, -0.554923019093}, - {-175.538247146, -1.43498590417, 0.37233438556, -0.897205352198, -0.339309952316, -0.0321624527843}, - {-126.331680318, 0.160446617623, 0.816642363249, -1.39863371652, 0.199747744327, -2.13493607457}, - {116.677107593, 1.19300905847, -0.404409346893, 0.646338976096, -0.534204093869, 0.36692724765}, - {-181.675962893, -1.57613169533, -0.41549571451, -0.956673746013, 0.35723782515, 0.318317395128}, - {-55.1457877823, 0.63723030991, -0.324480386466, 0.296028333894, -1.68117515658, -0.131945601375}, - {25.2534791013, 0.594818219911, -0.0247380403547, -0.101492246071, -0.0745619242015, -0.370837128867}, - {63.6006283756, -1.53493473818, 0.946464097439, 0.637741397831, 0.938866921166, 0.54405291856}, - {-69.6245547661, 0.328482934094, -0.776881060846, -0.285133098443, -1.06107824512, 0.49952182341}, - {233.425957233, 3.10582399189, -0.0854710508706, 0.455873479133, -0.0974589364949, -1.18914783551}, - {-86.5564290626, -0.819839276484, 0.584745927593, -0.544737106102, -1.21927675581, 0.758502626434}, - {425.357285631, 1.70712253847, 1.19892647853, 1.60619661301, 0.36832665241, 0.880791322709}, - {111.797225426, 0.558940594145, -0.746492420236, 1.90172101792, 0.853590062366, -0.867970723941}, - {-253.616801014, -0.426513440051, 0.0388582291888, -1.18576061365, -2.70895868242, 0.26982210287}, - {-394.801501024, -1.65087241498, 0.735525201393, -2.02413077052, -0.96492749037, -1.89014065613} - }, new double[] {93.3843533037, 72.3610889215, 57.5295295915, 63.7287541653, 65.2263084024}, 6.85683020686); - - /** - * Artificial dataset with 100 observations described by 10 features. - */ - public static final TestDataset regression100x10 = new TestDataset(new double[][] { - {69.5794204114, -0.684238565877, 0.175665643732, 0.882115894035, 0.612844187624, - -0.685301720572, -0.8266500007, -0.0383407025118, 1.7105205222, 0.457436379836, -0.291563926494}, - {80.1390102826, -1.80708821811, 0.811271788195, 0.30248512861, 0.910658009566, - -1.61869762501, -0.148325085362, -0.0714164596509, 0.671646742271, 2.15160094956, -0.0495754979721}, - {-156.975447515, 0.170702943934, -0.973403372054, -0.093974528453, 1.54577255871, - -0.0969022857972, -1.10639617368, 1.51752480948, -2.86016865032, 1.24063030602, -0.521785751026}, - {-158.134931891, 0.0890071395055, -0.0811824442353, -0.737354274843, -1.7575255492, - 0.265777246641, 0.0745347238144, -0.457603542683, -1.37034043839, 1.86011799875, 0.651214189491}, - {-131.465820263, 0.0767565260375, 0.651724194978, 0.142113799753, 0.244367469855, - -0.334395162837, -0.069092305876, -0.691806779713, -1.28386786177, -1.43647491141, 0.00721053414234}, - {-125.468890054, 0.43361925912, -0.800231440065, -0.576001094593, 0.0783664516431, - -1.33613252233, -0.968385062126, -1.22077801286, 0.193456109638, -3.09372314386, 0.817979620215}, - {-44.1113403874, -0.595796803171, 1.29482131972, -0.784513985654, 0.364702038003, - -3.2452492093, -0.451605560847, 0.988546607514, 0.492096628873, -0.343018842342, -0.519231306954}, - {61.2269707872, -0.0289059337716, -1.00409238976, 0.329908621635, 1.41965097539, - 0.0395065997587, -0.477939549336, 0.842336765911, -0.808790019648, 1.70241718768, -0.117194118865}, - {301.434286126, 0.430005308515, 1.01290089725, -0.228221561554, 0.463405921629, - -0.602413489517, 1.13832440088, 0.930949226185, -0.196440161506, 1.46304624346, 1.23831509056}, - {-270.454814681, -1.43805412632, -0.256309572507, -0.358047601174, 0.265151660237, - 1.07087986377, -1.93784654681, -0.854440691754, 0.665691996289, -1.87508012738, -0.387092423365}, - {-97.6198688184, -1.67658167161, -0.170246709551, -2.26863722189, 0.280289356338, - -0.690038347855, -1.69282684019, 0.978606053022, 1.28237852256, -1.2941998486, 0.766405365374}, - {-29.5630902399, -1.75615633921, 0.633927486329, -1.24117311555, -0.15884687004, - 0.31296863712, -1.29513272039, 0.344090683606, 1.19598425093, -1.96195019104, 1.81415061059}, - {-130.896377427, 0.577719366939, -0.087267771748, -0.060088767013, 0.469803880788, - -1.03078212088, -1.41547398887, 1.38980586981, -0.37118000595, -1.81689513712, -0.3099432567}, - {79.6300698059, 1.23408625633, 1.06464588017, 1.23403332691, -1.10993859098, - 0.874825200577, 0.589337796957, -1.10266185141, 0.842960469618, -0.89231962021, 0.284074900504}, - {-154.712112815, -1.64474237898, -0.328581696933, 0.38834343178, 0.02682160335, - -0.251167527796, -0.199330632103, -0.0405837345525, -0.908200250794, -1.3283756975, 0.540894408264}, - {233.447381562, 0.395156450609, 0.156412599781, 0.126453148554, 2.40829068933, - 1.01623530754, -0.0856520211145, -0.874970377099, 0.280617145254, -0.307070438514, 0.4599616054}, - {209.012380432, -0.848646647675, 0.558383548084, -0.259628264419, 1.1624126549, - -0.0755949979572, -0.373930759448, 0.985903312667, 0.435839508011, -0.760916312668, 1.89847574116}, - {-39.8987262091, 0.176656582642, 0.508538223618, 0.995038391204, -2.08809409812, - 0.743926580134, 0.246007971514, -0.458288599906, -0.579976479473, 0.0591577146017, 1.64321662761}, - {222.078510236, -0.24031989218, -0.168104260522, -0.727838425954, 0.557181757624, - -0.164906646307, 2.01559331734, 0.897263594222, 0.0921535309562, 0.351910490325, -0.018228500121}, - {-250.916272061, -2.71504637339, 0.498966191294, -3.16410707344, -0.842488891776, - 1.27425275951, 0.0141733666756, 0.695942743199, 0.0917995810179, -0.501447196978, -0.355738068451}, - {134.07259088, 0.0845637591619, 0.237410106679, -0.291458113729, 1.39418566986, - -1.18813057956, -0.683117067763, -0.518910379335, 1.35998426879, -1.28404562245, 0.489131754943}, - {104.988440209, 0.00770925058526, 0.47113239214, -0.606231247854, 0.310679840217, - 0.146297599928, 0.732013998647, -0.284544010865, 0.402622530153, -0.0217367745613, 0.0742970687987}, - {155.558071031, 1.11171654653, 0.726629222799, -0.195820863177, 0.801333855535, - 0.744034755544, 1.11377275513, -0.75673532139, -0.114117607244, -0.158966474923, -0.29701120385}, - {90.7600194013, -0.104364079622, -0.0165109945217, 0.933002972987, -1.80652594466, - -1.34760892883, -0.304511906801, 0.0584734540581, 1.5332169392, 0.478835797824, 1.71534051065}, - {-313.910553214, 0.149908925551, 0.232806828559, -0.0708920471592, -0.0649553559745, - 0.377753357707, -0.957292311668, 0.545360522582, -1.37905464371, -0.940702110994, -1.53620430047}, - {-80.9380113754, 0.135586606896, 0.95759558815, -1.36879020479, 0.735413996144, - 0.637984100201, -1.79563152885, 1.55025691631, 0.634702068786, -0.203690334141, -0.83954824721}, - {-244.336816695, -0.179127343947, -2.12396005014, -0.431179356484, -0.860562153749, - -1.10270688639, -0.986886012982, -0.945091656162, -0.445428453767, 1.32269756209, -0.223712672168}, - {123.069612745, 0.703857129626, 0.291605144784, 1.40233051946, 0.278603787802, - -0.693567967466, -0.15587953395, 2.10213915684, 0.130663329174, -0.393184478882, 0.0874812844555}, - {-148.274944223, 1.66294967732, 0.0830002694123, 0.32492930502, 1.11864359687, - -0.381901627785, -1.06367037132, -0.392583620174, -1.16283326187, 0.104931461025, -1.64719611405}, - {-82.0018788235, 0.497118817453, 0.731125358012, -0.00976413646786, -0.0178930713492, - -0.814978582886, 0.0602834712523, -0.661940479055, -0.957902899386, -1.34489251111, 0.22166518707}, - {-35.742996986, 0.0661349516701, -0.204314495629, 1.17101314753, -2.53846825562, - -0.560282479298, -0.393442894828, 0.988953809491, -0.911281277704, 0.86862242698, 2.59576940486}, - {-109.588885664, -0.0793151346628, -0.408962434518, -0.598817776528, 0.0277205469561, - 0.116291018958, 0.0280416838086, -0.72544170676, -0.669302814774, 0.0751898759816, -0.311002356179}, - {57.8285173441, 0.53753903532, 0.676340503752, -2.10608342721, 0.477714987751, - 0.465695114442, 0.245966562421, -1.05230350808, -0.309794163113, -1.12067331828, 1.07841453304}, - {204.660622582, -0.717565166685, 0.295179660279, -0.377579912697, 1.88425526905, - 0.251875238436, -0.900214103232, -1.02877401105, 0.291693915093, 1.24889067987, 1.78506220081}, - {350.949109103, 2.82276814452, -0.429358342127, 1.12140362367, 1.18120725208, - -1.63913834939, 1.61441562446, -0.364003766916, -0.258752942225, -0.808124680189, 0.556463488303}, - {170.960252153, 0.147245922081, 0.3257117575, 0.211749283649, -0.0150701808404, - -0.888523132148, 0.777862088798, 0.296729270892, -0.332927550718, 0.888968144245, 1.20913118467}, - {112.192270383, 0.129846138824, -0.934371449036, -0.595825303214, 1.74749214629, - -0.0500069421443, -0.161976298602, -2.54100791613, 1.99632530735, -0.0691582773758, -0.863939367415}, - {-56.7847711121, 0.0950532853751, -0.467349228201, -0.26457152362, -0.422134692317, - -0.0734763062127, 0.90128235602, -1.68470856275, -0.0699692697335, -0.463335845504, -0.301754321169}, - {-37.9223252258, -1.40835827778, 0.566142056244, -3.22393318933, 0.228823495106, - -1.8480727782, 0.129468321643, -1.77392686536, 0.0112549619662, 0.146433267822, 1.29379901303}, - {-59.7303066136, 0.835675535576, -0.552173157548, 1.90730898966, -0.520145317195, - 1.55174485912, -1.37531768692, -0.408165743742, 0.0939675842223, 0.318004128812, 0.324378038446}, - {-0.916090786983, 0.425763794043, -0.295541268984, -0.066619586336, 2.03494974978, - -0.197109278058, -0.823307883209, 0.895531446352, -0.276435938737, -1.54580056755, -0.820051830246}, - {-20.3601082842, 0.56420556369, 0.741234589387, -0.565853617392, -0.311399905686, - 2.24066463251, -0.071704904286, -1.22796531596, 0.186020404046, -0.786874824874, 0.23140277151}, - {-22.9342855182, -0.0682789648279, -1.30680909143, 0.0486490588348, 0.890275695028, - -0.257961411112, -0.381531755985, 1.56251482581, -2.11808219232, 0.741828675202, 0.696388901165}, - {-157.251026807, -2.3120966502, 0.183734662375, 1.02192264962, 0.591272941061, - -0.0132855098339, -1.02016546348, 1.19642432892, 0.867653154846, -1.37600041722, -1.08542822792}, - {-68.6110752055, -1.2429968179, -0.950064269349, -0.332379873336, 0.25793632341, - 0.145780713577, -0.512109283074, -0.477887632032, 0.448960776324, -0.190215737958, 0.219578347563}, - {-56.1204152481, -0.811729480846, -0.647410362207, 0.934547463984, -0.390943346216, - -0.409981308474, 0.0923465893049, 1.9281242912, -0.624713581674, -0.0599353282306, -0.0188591746808}, - {348.530651658, 2.51721790231, 0.7560998114, -2.69620396681, 0.5174276585, - 0.403570816695, 0.901648571306, 0.269313230294, 1.07811463589, 0.986649559679, 0.514710327657}, - {-105.719065924, 0.679016972998, 0.341319363316, -0.515209647377, 0.800000866847, - -0.795474442628, -0.866849274801, -1.32927961486, 0.17679343917, -1.93744422464, -0.476447619273}, - {-197.389429553, -1.98585668879, -0.962610549884, -2.48860863254, -0.545990524642, - -0.13005685654, -1.23413782366, 1.17443427507, 1.4785554038, -0.193717671824, -0.466403609229}, - {-23.9625285402, -0.392164367603, 1.07583388583, -0.412686712477, -0.89339030785, - -0.774862334739, -0.186491999529, -0.300162444329, 0.177377235999, 0.134038296039, 0.957945226616}, - {-91.145725943, -0.154640540119, 0.732911957939, -0.206326119636, -0.569816760116, - 0.249393336416, -1.02762332953, 0.25096708081, 0.386927162941, -0.346382299592, 0.243099162109}, - {-80.7295722208, -1.72670707303, 0.138139045677, 0.0648055728598, 0.186182854422, - 1.07226527747, -1.26133459043, 0.213883744163, 1.47115466163, -1.54791582859, 0.170924664865}, - {-317.060323531, -0.349785690206, -0.740759426066, -0.407970845617, -0.689282767277, - -1.25608665316, -0.772546119412, -2.02925712813, 0.132949072522, -0.191465137244, -1.29079690284}, - {-252.491508279, -1.24643122869, 1.55335609203, 0.356613424877, 0.817434495353, - -1.74503747683, -0.818046363088, -1.58284235058, 0.357919389759, -1.18942962791, -1.91728745247}, - {-66.8121363157, -0.584246455697, -0.104254351782, 1.17911687508, -0.29288167882, - 0.891836132692, 0.232853863255, 0.423294355343, -0.669493690103, -1.15783890498, 0.188213983735}, - {140.681464689, 1.33156046873, -1.8847915949, -0.666528837988, -0.513356191443, - 0.281290031669, -1.07815005006, 1.22384196227, 1.39093631269, 0.527644817197, 1.21595221509}, - {-174.22326767, 0.475428766034, 0.856847216768, -0.734282773151, -0.923514989791, - 0.917510828772, 0.674878068543, 0.0644776431114, -0.607796192908, 0.867740011912, -1.97799769281}, - {74.3899799579, 0.00915743526294, 0.553578683413, 1.66930486354, 0.15562803404, - 1.8455840688, -0.371704942927, 1.11228894843, -0.37464389118, -0.48789151589, 0.79553866342}, - {70.1167175897, 0.154877045187, 1.47803572976, -0.0355743163524, -2.47914644675, - 0.672384381837, 1.63160379529, 1.81874583854, 1.22797339421, -0.0131258061634, -0.390265963676}, - {-11.0364788877, 0.173049156249, -1.78140521797, -1.29982707214, -0.48025663179, - -0.469112922302, -1.98718063269, 0.585086542043, 0.264611327837, 1.48855512579, 2.00672263496}, - {-112.711292736, -1.59239636827, -0.600613018822, -0.0209667499746, -1.81872893331, - -0.739893084955, 0.140261888569, -0.498107678308, 2.53664045504, -0.536385019089, -0.608755809378}, - {-198.064468217, 0.737175509877, -2.01835515547, -2.18045950065, 0.428584922529, - -1.01848835019, -0.470645361539, -0.00703630153547, -2.2341302754, 1.51483167022, -0.410184418418}, - {70.2747963991, 1.49474111532, -0.19517712503, 0.7392852909, -0.326060871666, - -0.566710349675, 0.14053094122, -0.562830341306, 0.22931613446, -0.0344439061448, 0.175150510551}, - {207.909021337, 0.839887009159, 0.268826583246, -0.313047158862, 1.12009996015, - 0.214209976971, -0.396147338251, 2.16039704403, 0.699141312749, 0.756192350992, -0.145368196901}, - {169.428609429, -1.13702350819, 1.23964530597, -0.864443556622, -0.885630795949, - -0.523872327352, 0.467159824748, 0.476596383923, 0.4343735578, 1.4075417896, 2.22939328991}, - {-176.909833405, 0.0875512760866, -0.455542269288, 0.539742307764, -0.762003092788, - 0.41829123457, -0.818116139644, -2.01761645956, 0.557395073218, 1.5823271814, -1.0168826293}, - {-27.734298611, -0.841257541979, 0.348961259301, 1.36935991472, -0.0694528057586, - -1.27303784913, 0.152155656569, 1.9279466651, 0.9589415766, -1.76634370106, -1.08831026428}, - {-55.8416853588, 0.927711536927, 0.157856746063, -0.295628714893, 0.0296602829783, - 1.75198587897, -0.38285446366, -0.253287154535, -1.64032395229, -0.842089054965, 1.00493779183}, - {56.0899797005, 0.326117761734, -1.93514762146, 1.0229172721, 0.125568968732, - 2.37760000658, -0.498532972011, -0.733375842271, -0.757445726993, -0.49515057432, 2.01559891524}, - {-176.220234909, 1.571129843, -0.867707605929, -0.709690799512, -1.51535538937, - 1.27424225477, -0.109513704468, -1.46822183, 0.281077088939, -1.97084024232, -0.322309524179}, - {37.7155152941, 0.363383774219, -0.0240881298641, -1.60692745228, -1.26961656439, - -0.41299134216, 1.2890099968, -1.34101694629, -0.455387485256, -0.14055003482, 1.5407059956}, - {-102.163416997, -2.05927378316, -0.470182865756, -0.875528863204, 0.0361720859253, - -1.03713912263, 0.417362606334, 0.707587625276, -0.0591627772581, -2.58905252006, 0.516573345216}, - {-206.47095321, 0.270030584651, 1.85544202116, -0.144189208964, -0.696400687327, - 0.0226388634283, -0.490952489106, -1.69209527849, 0.00973614309272, -0.484105876992, -0.991474668217}, - {201.50637416, 0.513659215697, -0.335630132208, -0.140006500483, 0.149679720127, - -1.89526167503, -0.0614973894156, 0.0813221153552, 0.630952530848, 2.40201011339, 0.997708264073}, - {-72.0667371571, 0.0841570292899, -0.216125859013, -1.77155215764, 2.15081767322, - 0.00953341785443, -1.0826077946, -0.791135571106, -0.989393577892, -0.791485083644, -0.063560999686}, - {-162.903837815, -0.273764637097, 0.282387854873, -1.39881596931, 0.554941097854, - -0.88790718926, -0.693189960902, 0.398762630571, -1.61878562893, -0.345976341096, 0.138298909959}, - {-34.3291926715, -0.499883755911, -0.847296893019, -0.323673126437, 0.531205373462, - -0.0204345595983, 0.284954510306, 0.565031773028, -0.272049818708, -0.130369799738, -0.617572026201}, - {76.1272883187, -0.908810282403, -1.04139421904, 0.890678872055, 1.32990256154, - -0.0150445428835, 0.593918101047, 0.356897732999, 0.824651162423, -1.54544256217, -0.795703905296}, - {171.833705285, -0.0425219657568, -0.884042952325, 1.91202504537, 0.381908223898, - -0.205693527739, 1.53656598237, 0.534880398015, 0.291950716831, -1.1258051056, -0.0612803476297}, - {-235.445792009, 0.261252102941, -0.170931758001, 1.67878144235, 0.0278283741792, - -1.23194408479, -0.190931886594, 1.0000157972, -2.18792142659, -0.230654984288, -1.36626493512}, - {348.968834231, 1.35713154434, 0.950377770072, 0.0700577471848, 0.96907140156, - 2.00890422081, 0.0896405239806, 0.614309607351, 1.07723409067, 2.58506968136, 0.202889806148}, - {-61.0128039201, 0.465438505031, -1.31448530533, 0.374781933416, -0.0118298606041, - -0.477338357738, -0.587656108109, 1.66449545077, 0.435836048385, -0.287027953004, -1.06613472784}, - {-50.687090469, 0.382331825989, -0.597140322197, 1.1276065465, -1.35593777887, - 1.14949964423, -0.858742432885, -0.563211485633, -0.57167161928, 0.0294891749132, 1.9571639493}, - {-186.653649045, -0.00981380006029, 1.0371088941, -1.25319048981, -0.694043021068, - 1.7280802541, -0.191210409232, -0.866039238001, -0.0791927416078, -0.232228656558, -0.93723545053}, - {34.5395591744, 0.680943971029, -0.075875481801, -0.144408300848, -0.869070791528, - 0.496870904214, 1.0940401388, -0.510489750436, -0.47562728601, 0.951406841944, 0.12983846382}, - {-23.7618645627, 0.527032820313, -0.58295129357, -0.3894567306, -0.0547905472556, - -1.86103603537, 0.0506988360667, 1.02778539291, -0.0613720063422, 0.411280841442, -0.665810811374}, - {116.007776415, 0.441750249008, 0.549342185228, 0.731558201455, -0.903624700864, - -2.13208328824, 0.381223328983, 0.283479210749, 1.17705098922, -2.38800904207, 1.32108350152}, - {-148.479593311, -0.814604260049, -0.821204361946, -1.08768677334, -0.0659445766599, - 0.583741297405, 0.669345853296, -0.0935352010726, -0.254906787938, -0.394599725657, -1.26305927257}, - {244.865845084, 0.776784257443, 0.267205388558, 2.37746488031, -0.379275360853, - -0.157454754411, -0.359580726073, 0.886887721861, 1.53707627973, 0.634390546684, 0.984864824122}, - {-81.9954096721, 0.594841146008, -1.22273253129, 0.532466794358, 1.69864239257, - -0.12293671327, -2.06645974171, 0.611808231703, -1.32291985291, 0.722066660478, -0.0021343848511}, - {-245.715046329, -1.77850303496, -0.176518810079, 1.20463434525, -0.597826204963, - -1.45842350123, -0.765730251727, -2.17764204443, 0.12996635702, -0.705509516482, 0.170639846082}, - {123.011946043, -0.909707162714, 0.92357208515, 0.373251929121, 1.24629576577, - 0.0662688299998, -0.372240547929, -0.739353735168, 0.323495756066, 0.954154005738, 0.69606859977}, - {-70.4564963177, 0.650682297051, 0.378131376232, 1.37860253614, -0.924042783872, - 0.802851073842, -0.450299927542, 0.235646185302, -0.148779896161, 1.01308126122, -0.48206889502}, - {21.5288687935, 0.290876355386, 0.0765702960599, 0.905225489744, 0.252841861521, - 1.26729272819, 0.315397441908, -2.00317261368, -0.250990653758, 0.425615332405, 0.0875320802483}, - {231.370169905, 0.535138021352, -1.07151617232, 0.824383756287, 1.84428896701, - -0.890892034494, 0.0480296332924, -0.59251208055, 0.267564961845, -0.230698441998, 0.857077278291}, - {38.8318274023, 2.63547217711, -0.585553060394, 0.430550920323, -0.532619160993, - 1.25335488136, -1.65265278435, 0.0433880112291, -0.166143379872, 0.534066441314, 1.18929937797}, - {116.362219013, -0.275949982433, 0.468069787645, -0.879814121059, 0.862799331322, - 1.18464846725, 0.747084253268, 1.39202500691, -1.23374181275, 0.0949815110503, 0.696546907194}, - {260.540154731, 1.13798788241, -0.0991903174656, 0.1241636043, -0.201415073037, - 1.57683389508, 1.81535629587, 1.07873616646, -0.355800782882, 2.18333193195, 0.0711071144615}, - {-165.835194521, -2.76613178307, 0.805314338858, 0.81526046683, -0.710489036197, - -1.20189542317, -0.692110074722, -0.117239516622, 1.0431459458, -0.111898596299, -0.0775811519297}, - {-341.189958588, 0.668555635008, -1.0940034941, -0.497881262778, -0.603682823779, - -0.396875163796, -0.849144848521, 0.403936807183, -1.82076277475, -0.137500972546, -1.22769896568} - }, new double[] {45.8685095528, 11.9400336005, 16.3984976652, 79.9069814034, 5.65486853464, - 83.6427296424, 27.4571268153, 73.5881193584, 27.1465364511, 79.4095449062}, -5.14077007134); - - /** */ - public static class TestDataset { - - /** */ - private final double[][] data; - - /** */ - private final double[] expWeights; - - /** */ - private final double expIntercept; - - /** */ - TestDataset(double[][] data, double[] expWeights, double expIntercept) { - this.data = data; - this.expWeights = expWeights; - this.expIntercept = expIntercept; - } - - /** */ - public double[][] getData() { - return data; - } - - /** */ - public double[] getExpWeights() { - return expWeights; - } - - /** */ - public double getExpIntercept() { - return expIntercept; - } - } -} \ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionQRTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionQRTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionQRTrainerTest.java deleted file mode 100644 index 0c09d75..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionQRTrainerTest.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions.linear; - -import org.apache.ignite.ml.math.impls.matrix.SparseBlockDistributedMatrix; -import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector; - -/** - * Tests for {@link LinearRegressionQRTrainer} on {@link SparseBlockDistributedMatrix}. - */ -public class BlockDistributedLinearRegressionQRTrainerTest extends GridAwareAbstractLinearRegressionTrainerTest { - /** */ - public BlockDistributedLinearRegressionQRTrainerTest() { - super( - new LinearRegressionQRTrainer(), - SparseBlockDistributedMatrix::new, - SparseBlockDistributedVector::new, - 1e-6 - ); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionQRTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionQRTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionQRTrainerTest.java deleted file mode 100644 index 2a506d9..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionQRTrainerTest.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions.linear; - -import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; -import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector; - -/** - * Tests for {@link LinearRegressionQRTrainer} on {@link SparseDistributedMatrix}. - */ -public class DistributedLinearRegressionQRTrainerTest extends GridAwareAbstractLinearRegressionTrainerTest { - /** */ - public DistributedLinearRegressionQRTrainerTest() { - super( - new LinearRegressionQRTrainer(), - SparseDistributedMatrix::new, - SparseDistributedVector::new, - 1e-6 - ); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GenericLinearRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GenericLinearRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GenericLinearRegressionTrainerTest.java deleted file mode 100644 index a55623c..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GenericLinearRegressionTrainerTest.java +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions.linear; - -import java.util.Scanner; -import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.Trainer; -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.junit.Test; - -/** - * Base class for all linear regression trainers. - */ -public class GenericLinearRegressionTrainerTest { - /** */ - private final Trainer<LinearRegressionModel, Matrix> trainer; - - /** */ - private final IgniteFunction<double[][], Matrix> matrixCreator; - - /** */ - private final IgniteFunction<double[], Vector> vectorCreator; - - /** */ - private final double precision; - - /** */ - public GenericLinearRegressionTrainerTest( - Trainer<LinearRegressionModel, Matrix> trainer, - IgniteFunction<double[][], Matrix> matrixCreator, - IgniteFunction<double[], Vector> vectorCreator, - double precision) { - this.trainer = trainer; - this.matrixCreator = matrixCreator; - this.vectorCreator = vectorCreator; - this.precision = precision; - } - - /** - * Test trainer on regression model y = 2 * x. - */ - @Test - public void testTrainWithoutIntercept() { - Matrix data = matrixCreator.apply(new double[][] { - {2.0, 1.0}, - {4.0, 2.0} - }); - - LinearRegressionModel mdl = trainer.train(data); - - TestUtils.assertEquals(4, mdl.apply(vectorCreator.apply(new double[] {2})), precision); - TestUtils.assertEquals(6, mdl.apply(vectorCreator.apply(new double[] {3})), precision); - TestUtils.assertEquals(8, mdl.apply(vectorCreator.apply(new double[] {4})), precision); - } - - /** - * Test trainer on regression model y = -1 * x + 1. - */ - @Test - public void testTrainWithIntercept() { - Matrix data = matrixCreator.apply(new double[][] { - {1.0, 0.0}, - {0.0, 1.0} - }); - - LinearRegressionModel mdl = trainer.train(data); - - TestUtils.assertEquals(0.5, mdl.apply(vectorCreator.apply(new double[] {0.5})), precision); - TestUtils.assertEquals(2, mdl.apply(vectorCreator.apply(new double[] {-1})), precision); - TestUtils.assertEquals(-1, mdl.apply(vectorCreator.apply(new double[] {2})), precision); - } - - /** - * Test trainer on diabetes dataset. - */ - @Test - public void testTrainOnDiabetesDataset() { - Matrix data = loadDataset("datasets/regression/diabetes.csv", 442, 10); - - LinearRegressionModel mdl = trainer.train(data); - - Vector expWeights = vectorCreator.apply(new double[] { - -10.01219782, -239.81908937, 519.83978679, 324.39042769, -792.18416163, - 476.74583782, 101.04457032, 177.06417623, 751.27932109, 67.62538639 - }); - - double expIntercept = 152.13348416; - - TestUtils.assertEquals("Wrong weights", expWeights, mdl.getWeights(), precision); - TestUtils.assertEquals("Wrong intercept", expIntercept, mdl.getIntercept(), precision); - } - - /** - * Test trainer on boston dataset. - */ - @Test - public void testTrainOnBostonDataset() { - Matrix data = loadDataset("datasets/regression/boston.csv", 506, 13); - - LinearRegressionModel mdl = trainer.train(data); - - Vector expWeights = vectorCreator.apply(new double[] { - -1.07170557e-01, 4.63952195e-02, 2.08602395e-02, 2.68856140e+00, -1.77957587e+01, 3.80475246e+00, - 7.51061703e-04, -1.47575880e+00, 3.05655038e-01, -1.23293463e-02, -9.53463555e-01, 9.39251272e-03, - -5.25466633e-01 - }); - - double expIntercept = 36.4911032804; - - TestUtils.assertEquals("Wrong weights", expWeights, mdl.getWeights(), precision); - TestUtils.assertEquals("Wrong intercept", expIntercept, mdl.getIntercept(), precision); - } - - /** - * Tests trainer on artificial dataset with 10 observations described by 1 feature. - */ - @Test - public void testTrainOnArtificialDataset10x1() { - ArtificialRegressionDatasets.TestDataset dataset = ArtificialRegressionDatasets.regression10x1; - - LinearRegressionModel mdl = trainer.train(matrixCreator.apply(dataset.getData())); - - TestUtils.assertEquals("Wrong weights", dataset.getExpWeights(), mdl.getWeights(), precision); - TestUtils.assertEquals("Wrong intercept", dataset.getExpIntercept(), mdl.getIntercept(), precision); - } - - /** - * Tests trainer on artificial dataset with 10 observations described by 5 features. - */ - @Test - public void testTrainOnArtificialDataset10x5() { - ArtificialRegressionDatasets.TestDataset dataset = ArtificialRegressionDatasets.regression10x5; - - LinearRegressionModel mdl = trainer.train(matrixCreator.apply(dataset.getData())); - - TestUtils.assertEquals("Wrong weights", dataset.getExpWeights(), mdl.getWeights(), precision); - TestUtils.assertEquals("Wrong intercept", dataset.getExpIntercept(), mdl.getIntercept(), precision); - } - - /** - * Tests trainer on artificial dataset with 100 observations described by 5 features. - */ - @Test - public void testTrainOnArtificialDataset100x5() { - ArtificialRegressionDatasets.TestDataset dataset = ArtificialRegressionDatasets.regression100x5; - - LinearRegressionModel mdl = trainer.train(matrixCreator.apply(dataset.getData())); - - TestUtils.assertEquals("Wrong weights", dataset.getExpWeights(), mdl.getWeights(), precision); - TestUtils.assertEquals("Wrong intercept", dataset.getExpIntercept(), mdl.getIntercept(), precision); - } - - /** - * Tests trainer on artificial dataset with 100 observations described by 10 features. - */ - @Test - public void testTrainOnArtificialDataset100x10() { - ArtificialRegressionDatasets.TestDataset dataset = ArtificialRegressionDatasets.regression100x10; - - LinearRegressionModel mdl = trainer.train(matrixCreator.apply(dataset.getData())); - - TestUtils.assertEquals("Wrong weights", dataset.getExpWeights(), mdl.getWeights(), precision); - TestUtils.assertEquals("Wrong intercept", dataset.getExpIntercept(), mdl.getIntercept(), precision); - } - - /** - * Loads dataset file and returns corresponding matrix. - * - * @param fileName Dataset file name - * @param nobs Number of observations - * @param nvars Number of features - * @return Data matrix - */ - private Matrix loadDataset(String fileName, int nobs, int nvars) { - double[][] matrix = new double[nobs][nvars + 1]; - Scanner scanner = new Scanner(this.getClass().getClassLoader().getResourceAsStream(fileName)); - int i = 0; - while (scanner.hasNextLine()) { - String row = scanner.nextLine(); - int j = 0; - for (String feature : row.split(",")) { - matrix[i][j] = Double.parseDouble(feature); - j++; - } - i++; - } - return matrixCreator.apply(matrix); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GridAwareAbstractLinearRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GridAwareAbstractLinearRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GridAwareAbstractLinearRegressionTrainerTest.java deleted file mode 100644 index 9b75bd4..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GridAwareAbstractLinearRegressionTrainerTest.java +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions.linear; - -import org.apache.ignite.Ignite; -import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.ml.Trainer; -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; -import org.junit.Test; - -/** - * Grid aware abstract linear regression trainer test. - */ -public abstract class GridAwareAbstractLinearRegressionTrainerTest extends GridCommonAbstractTest { - /** Number of nodes in grid */ - private static final int NODE_COUNT = 3; - - /** - * Delegate actually performs tests. - */ - private final GenericLinearRegressionTrainerTest delegate; - - /** */ - private Ignite ignite; - - /** */ - public GridAwareAbstractLinearRegressionTrainerTest( - Trainer<LinearRegressionModel, Matrix> trainer, - IgniteFunction<double[][], Matrix> matrixCreator, - IgniteFunction<double[], Vector> vectorCreator, - double precision) { - delegate = new GenericLinearRegressionTrainerTest(trainer, matrixCreator, vectorCreator, precision); - } - - /** {@inheritDoc} */ - @Override protected void beforeTestsStarted() throws Exception { - for (int i = 1; i <= NODE_COUNT; i++) - startGrid(i); - } - - /** {@inheritDoc} */ - @Override protected void afterTestsStopped() { - stopAllGrids(); - } - - /** - * {@inheritDoc} - */ - @Override protected void beforeTest() throws Exception { - /* Grid instance. */ - ignite = grid(NODE_COUNT); - ignite.configuration().setPeerClassLoadingEnabled(true); - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - } - - /** - * Test trainer on regression model y = 2 * x. - */ - @Test - public void testTrainWithoutIntercept() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - delegate.testTrainWithoutIntercept(); - } - - /** - * Test trainer on regression model y = -1 * x + 1. - */ - @Test - public void testTrainWithIntercept() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - delegate.testTrainWithIntercept(); - } - - /** - * Tests trainer on artificial dataset with 10 observations described by 1 feature. - */ - @Test - public void testTrainOnArtificialDataset10x1() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - delegate.testTrainOnArtificialDataset10x1(); - } - - /** - * Tests trainer on artificial dataset with 10 observations described by 5 features. - */ - @Test - public void testTrainOnArtificialDataset10x5() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - delegate.testTrainOnArtificialDataset10x5(); - } - - /** - * Tests trainer on artificial dataset with 100 observations described by 5 features. - */ - @Test - public void testTrainOnArtificialDataset100x5() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - delegate.testTrainOnArtificialDataset100x5(); - } - - /** - * Tests trainer on artificial dataset with 100 observations described by 10 features. - */ - @Test - public void testTrainOnArtificialDataset100x10() { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - delegate.testTrainOnArtificialDataset100x10(); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java index fa8fac4..c62cca5 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java @@ -19,7 +19,7 @@ package org.apache.ignite.ml.regressions.linear; import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; -import org.apache.ignite.ml.trainers.group.UpdatesStrategy; +import org.apache.ignite.ml.nn.UpdatesStrategy; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionQRTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionQRTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionQRTrainerTest.java deleted file mode 100644 index f37d71d..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LocalLinearRegressionQRTrainerTest.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions.linear; - -import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; - -/** - * Tests for {@link LinearRegressionQRTrainer} on {@link DenseLocalOnHeapMatrix}. - */ -public class LocalLinearRegressionQRTrainerTest extends GenericLinearRegressionTrainerTest { - /** */ - public LocalLinearRegressionQRTrainerTest() { - super( - new LinearRegressionQRTrainer(), - DenseLocalOnHeapMatrix::new, - DenseLocalOnHeapVector::new, - 1e-6 - ); - } -} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/DistributedWorkersChainTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/DistributedWorkersChainTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/DistributedWorkersChainTest.java deleted file mode 100644 index 7ad59d1..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/DistributedWorkersChainTest.java +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group; - -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.UUID; -import java.util.stream.Stream; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.Ignition; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.functions.IgniteSupplier; -import org.apache.ignite.ml.trainers.group.chain.Chains; -import org.apache.ignite.ml.trainers.group.chain.ComputationsChain; -import org.apache.ignite.ml.trainers.group.chain.EntryAndContext; -import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; -import org.junit.Assert; - -/** */ -public class DistributedWorkersChainTest extends GridCommonAbstractTest { - /** Count of nodes. */ - private static final int NODE_COUNT = 3; - - /** Grid instance. */ - protected Ignite ignite; - - /** - * Default constructor. - */ - public DistributedWorkersChainTest() { - super(false); - } - - /** - * {@inheritDoc} - */ - @Override protected void beforeTest() throws Exception { - ignite = grid(NODE_COUNT); - TestGroupTrainingCache.getOrCreate(ignite).removeAll(); - TestGroupTrainingSecondCache.getOrCreate(ignite).removeAll(); - } - - /** {@inheritDoc} */ - @Override protected void beforeTestsStarted() throws Exception { - for (int i = 1; i <= NODE_COUNT; i++) - startGrid(i); - } - - /** {@inheritDoc} */ - @Override protected void afterTestsStopped() throws Exception { - stopAllGrids(); - } - - /** */ - public void testId() { - ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> chain = Chains.create(); - - UUID trainingUUID = UUID.randomUUID(); - Integer res = chain.process(1, new GroupTrainingContext<>(new TestLocalContext(0, trainingUUID), TestGroupTrainingCache.getOrCreate(ignite), ignite)); - - Assert.assertEquals(1L, (long)res); - } - - /** */ - public void testSimpleLocal() { - ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> chain = Chains.create(); - - IgniteCache<GroupTrainerCacheKey<Double>, Integer> cache = TestGroupTrainingCache.getOrCreate(ignite); - int init = 1; - int initLocCtxData = 0; - UUID trainingUUID = UUID.randomUUID(); - TestLocalContext locCtx = new TestLocalContext(initLocCtxData, trainingUUID); - - Integer res = chain. - thenLocally((prev, lc) -> prev + 1). - process(init, new GroupTrainingContext<>(locCtx, cache, ignite)); - - Assert.assertEquals(init + 1, (long)res); - Assert.assertEquals(initLocCtxData, locCtx.data()); - } - - /** */ - public void testChainLocal() { - ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> chain = Chains.create(); - - IgniteCache<GroupTrainerCacheKey<Double>, Integer> cache = TestGroupTrainingCache.getOrCreate(ignite); - int init = 1; - int initLocCtxData = 0; - UUID trainingUUID = UUID.randomUUID(); - TestLocalContext locCtx = new TestLocalContext(initLocCtxData, trainingUUID); - - Integer res = chain. - thenLocally((prev, lc) -> prev + 1). - thenLocally((prev, lc) -> prev * 5). - process(init, new GroupTrainingContext<>(locCtx, cache, ignite)); - - Assert.assertEquals((init + 1) * 5, (long)res); - Assert.assertEquals(initLocCtxData, locCtx.data()); - } - - /** */ - public void testChangeLocalContext() { - ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> chain = Chains.create(); - IgniteCache<GroupTrainerCacheKey<Double>, Integer> cache = TestGroupTrainingCache.getOrCreate(ignite); - int init = 1; - int newData = 10; - UUID trainingUUID = UUID.randomUUID(); - TestLocalContext locCtx = new TestLocalContext(0, trainingUUID); - - Integer res = chain. - thenLocally((prev, lc) -> { lc.setData(newData); return prev;}). - process(init, new GroupTrainingContext<>(locCtx, cache, ignite)); - - Assert.assertEquals(newData, locCtx.data()); - Assert.assertEquals(init, res.intValue()); - } - - /** */ - public void testDistributed() { - ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> chain = Chains.create(); - IgniteCache<GroupTrainerCacheKey<Double>, Integer> cache = TestGroupTrainingCache.getOrCreate(ignite); - int init = 1; - UUID trainingUUID = UUID.randomUUID(); - TestLocalContext locCtx = new TestLocalContext(0, trainingUUID); - - Map<GroupTrainerCacheKey<Double>, Integer> m = new HashMap<>(); - m.put(new GroupTrainerCacheKey<>(0L, 1.0, trainingUUID), 1); - m.put(new GroupTrainerCacheKey<>(1L, 2.0, trainingUUID), 2); - m.put(new GroupTrainerCacheKey<>(2L, 3.0, trainingUUID), 3); - m.put(new GroupTrainerCacheKey<>(3L, 4.0, trainingUUID), 4); - - Stream<GroupTrainerCacheKey<Double>> keys = m.keySet().stream(); - - cache.putAll(m); - - IgniteBiFunction<Integer, TestLocalContext, IgniteSupplier<Stream<GroupTrainerCacheKey<Double>>>> function = (o, l) -> () -> keys; - IgniteFunction<List<Integer>, Integer> max = ints -> ints.stream().mapToInt(x -> x).max().orElse(Integer.MIN_VALUE); - - Integer res = chain. - thenDistributedForEntries((integer, context) -> () -> null, this::readAndIncrement, function, max). - process(init, new GroupTrainingContext<>(locCtx, cache, ignite)); - - int localMax = m.values().stream().max(Comparator.comparingInt(i -> i)).orElse(Integer.MIN_VALUE); - - assertEquals((long)localMax, (long)res); - - for (GroupTrainerCacheKey<Double> key : m.keySet()) - m.compute(key, (k, v) -> v + 1); - - assertMapEqualsCache(m, cache); - } - - /** */ - private ResultAndUpdates<Integer> readAndIncrement(EntryAndContext<Double, Integer, Void> ec) { - Integer val = ec.entry().getValue(); - - ResultAndUpdates<Integer> res = ResultAndUpdates.of(val); - res.updateCache(TestGroupTrainingCache.getOrCreate(Ignition.localIgnite()), ec.entry().getKey(), val + 1); - - return res; - } - - /** */ - private <K, V> void assertMapEqualsCache(Map<K, V> m, IgniteCache<K, V> cache) { - assertEquals(m.size(), cache.size()); - - for (K k : m.keySet()) - assertEquals(m.get(k), cache.get(k)); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/GroupTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/GroupTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/GroupTrainerTest.java deleted file mode 100644 index 5bb9a47..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/GroupTrainerTest.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group; - -import java.util.HashMap; -import java.util.Map; -import org.apache.ignite.Ignite; -import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; - -/** - * Test of {@link GroupTrainer}. - */ -public class GroupTrainerTest extends GridCommonAbstractTest { - /** Count of nodes. */ - private static final int NODE_COUNT = 3; - - /** Grid instance. */ - private Ignite ignite; - - /** - * Default constructor. - */ - public GroupTrainerTest() { - super(false); - } - - /** - * {@inheritDoc} - */ - @Override protected void beforeTest() throws Exception { - ignite = grid(NODE_COUNT); - TestGroupTrainingCache.getOrCreate(ignite).removeAll(); - TestGroupTrainingSecondCache.getOrCreate(ignite).removeAll(); - } - - /** {@inheritDoc} */ - @Override protected void beforeTestsStarted() throws Exception { - for (int i = 1; i <= NODE_COUNT; i++) - startGrid(i); - } - - /** {@inheritDoc} */ - @Override protected void afterTestsStopped() throws Exception { - stopAllGrids(); - } - - /** */ - public void testGroupTrainer() { - TestGroupTrainer trainer = new TestGroupTrainer(ignite); - - int limit = 5; - int eachNumCnt = 3; - int iterCnt = 2; - - ConstModel<Integer> mdl = trainer.train(new SimpleGroupTrainerInput(limit, eachNumCnt, iterCnt)); - int locRes = computeLocally(limit, eachNumCnt, iterCnt); - assertEquals(locRes, (int)mdl.apply(10)); - } - - /** */ - private int computeLocally(int limit, int eachNumCnt, int iterCnt) { - Map<GroupTrainerCacheKey<Double>, Integer> m = new HashMap<>(); - - for (int i = 0; i < limit; i++) { - for (int j = 0; j < eachNumCnt; j++) - m.put(new GroupTrainerCacheKey<>(i, (double)j, null), i); - } - - for (int i = 0; i < iterCnt; i++) - for (GroupTrainerCacheKey<Double> key : m.keySet()) - m.compute(key, (key1, integer) -> integer * integer); - - return m.values().stream().filter(x -> x % 2 == 0).mapToInt(i -> i).sum(); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/SimpleGroupTrainerInput.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/SimpleGroupTrainerInput.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/SimpleGroupTrainerInput.java deleted file mode 100644 index db1adc7..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/SimpleGroupTrainerInput.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group; - -import java.util.UUID; -import java.util.stream.IntStream; -import java.util.stream.Stream; -import org.apache.ignite.ml.math.functions.IgniteSupplier; - -public class SimpleGroupTrainerInput implements GroupTrainerInput<Double> { - /** */ - private int limit; - - /** */ - private int eachNumberCount; - - /** */ - private int iterCnt; - - /** */ - public SimpleGroupTrainerInput(int limit, int eachNumCnt, int iterCnt) { - this.limit = limit; - this.eachNumberCount = eachNumCnt; - this.iterCnt = iterCnt; - } - - /** {@inheritDoc} */ - @Override public IgniteSupplier<Stream<GroupTrainerCacheKey<Double>>> initialKeys(UUID trainingUUID) { - int lim = limit; - UUID uuid = trainingUUID; - return () -> IntStream.range(0, lim).mapToObj(i -> new GroupTrainerCacheKey<>(i, 0.0, uuid)); - } - - /** */ - public int limit() { - return limit; - } - - /** */ - public int iterCnt() { - return iterCnt; - } - - /** */ - public int eachNumberCount() { - return eachNumberCount; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java deleted file mode 100644 index 0a49fe0..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group; - -import java.util.List; -import java.util.UUID; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.Ignition; -import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.functions.IgniteSupplier; -import org.apache.ignite.ml.trainers.group.chain.Chains; -import org.apache.ignite.ml.trainers.group.chain.ComputationsChain; -import org.apache.ignite.ml.trainers.group.chain.EntryAndContext; - -/** - * Test group trainer. - */ -public class TestGroupTrainer extends GroupTrainer<TestGroupTrainerLocalContext, Double, Integer, Integer, Integer, - Double, ConstModel<Integer>, SimpleGroupTrainerInput, Void> { - /** - * Construct instance of this class with given parameters. - * - * @param ignite Ignite instance. - */ - public TestGroupTrainer(Ignite ignite) { - super(TestGroupTrainingCache.getOrCreate(ignite), ignite); - } - - /** {@inheritDoc} */ - @Override protected TestGroupTrainerLocalContext initialLocalContext(SimpleGroupTrainerInput data, - UUID trainingUUID) { - return new TestGroupTrainerLocalContext(data.iterCnt(), data.eachNumberCount(), data.limit(), trainingUUID); - } - - /** {@inheritDoc} */ - @Override protected IgniteFunction<GroupTrainerCacheKey<Double>, ResultAndUpdates<Integer>> distributedInitializer( - SimpleGroupTrainerInput data) { - return key -> { - long i = key.nodeLocalEntityIndex(); - UUID trainingUUID = key.trainingUUID(); - IgniteCache<GroupTrainerCacheKey<Double>, Integer> cache - = TestGroupTrainingCache.getOrCreate(Ignition.localIgnite()); - - long sum = i * data.eachNumberCount(); - - ResultAndUpdates<Integer> res = ResultAndUpdates.of((int)sum); - - for (int j = 0; j < data.eachNumberCount(); j++) - res.updateCache(cache, new GroupTrainerCacheKey<>(i, (double)j, trainingUUID), (int)i); - - return res; - }; - } - - /** {@inheritDoc} */ - @Override protected IgniteFunction<List<Integer>, Integer> reduceDistributedInitData() { - return id -> id.stream().mapToInt(x -> x).sum(); - } - - /** {@inheritDoc} */ - @Override protected Double locallyProcessInitData(Integer data, TestGroupTrainerLocalContext locCtx) { - return data.doubleValue(); - } - - /** {@inheritDoc} */ - @Override protected ComputationsChain<TestGroupTrainerLocalContext, - Double, Integer, Double, Double> trainingLoopStep() { - // TODO:IGNITE-7405 here we should explicitly create variable because we cannot infer context type, think about it. - ComputationsChain<TestGroupTrainerLocalContext, Double, Integer, Double, Double> chain = Chains. - create(new TestTrainingLoopStep()); - return chain. - thenLocally((aDouble, context) -> { - context.incCnt(); - return aDouble; - }); - } - - /** {@inheritDoc} */ - @Override protected boolean shouldContinue(Double data, TestGroupTrainerLocalContext locCtx) { - return locCtx.cnt() < locCtx.maxCnt(); - } - - /** {@inheritDoc} */ - @Override protected IgniteSupplier<Void> extractContextForFinalResultCreation(Double data, - TestGroupTrainerLocalContext locCtx) { - // No context is needed. - return () -> null; - } - - /** {@inheritDoc} */ - @Override protected IgniteSupplier<Stream<GroupTrainerCacheKey<Double>>> finalResultKeys(Double data, - TestGroupTrainerLocalContext locCtx) { - int limit = locCtx.limit(); - int cnt = locCtx.eachNumberCnt(); - UUID uuid = locCtx.trainingUUID(); - - return () -> TestGroupTrainingCache.allKeys(limit, cnt, uuid); - } - - /** {@inheritDoc} */ - @Override protected IgniteFunction<EntryAndContext<Double, Integer, Void>, - ResultAndUpdates<Integer>> finalResultsExtractor() { - return entryAndCtx -> { - Integer val = entryAndCtx.entry().getValue(); - return ResultAndUpdates.of(val % 2 == 0 ? val : 0); - }; - } - - /** {@inheritDoc} */ - @Override protected IgniteFunction<List<Integer>, Integer> finalResultsReducer() { - return id -> id.stream().mapToInt(x -> x).sum(); - } - - /** {@inheritDoc} */ - @Override protected ConstModel<Integer> mapFinalResult(Integer res, TestGroupTrainerLocalContext locCtx) { - return new ConstModel<>(res); - } - - /** {@inheritDoc} */ - @Override protected void cleanup(TestGroupTrainerLocalContext locCtx) { - Stream<GroupTrainerCacheKey<Double>> toRemote = TestGroupTrainingCache.allKeys(locCtx.limit(), - locCtx.eachNumberCnt(), locCtx.trainingUUID()); - - TestGroupTrainingCache.getOrCreate(ignite).removeAll(toRemote.collect(Collectors.toSet())); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainerLocalContext.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainerLocalContext.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainerLocalContext.java deleted file mode 100644 index e1a533b..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainerLocalContext.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group; - -import java.util.UUID; -import org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID; - -/** */ -public class TestGroupTrainerLocalContext implements HasTrainingUUID { - /** */ - private int cnt = 0; - - /** */ - private int maxCnt; - - /** */ - private int eachNumberCnt; - - /** */ - private int limit; - - /** */ - private UUID trainingUUID; - - /** */ - public TestGroupTrainerLocalContext(int maxCnt, int eachNumberCnt, int limit, UUID trainingUUID) { - this.maxCnt = maxCnt; - this.eachNumberCnt = eachNumberCnt; - this.limit = limit; - this.trainingUUID = trainingUUID; - this.cnt = 0; - } - - /** */ - public int cnt() { - return cnt; - } - - /** */ - public void setCnt(int cnt) { - this.cnt = cnt; - } - - /** */ - public TestGroupTrainerLocalContext incCnt() { - this.cnt++; - - return this; - } - - /** */ - public int maxCnt() { - return maxCnt; - } - - /** */ - public int eachNumberCnt() { - return eachNumberCnt; - } - - /** */ - public int limit() { - return limit; - } - - /** {@inheritDoc} */ - @Override public UUID trainingUUID() { - return trainingUUID; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingCache.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingCache.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingCache.java deleted file mode 100644 index afee674..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingCache.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group; - -import java.util.Arrays; -import java.util.UUID; -import java.util.stream.Stream; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.cache.CacheAtomicityMode; -import org.apache.ignite.cache.CacheMode; -import org.apache.ignite.cache.CacheWriteSynchronizationMode; -import org.apache.ignite.configuration.CacheConfiguration; - -/** */ -public class TestGroupTrainingCache { - /** */ - public static String CACHE_NAME = "TEST_GROUP_TRAINING_CACHE"; - - /** */ - public static IgniteCache<GroupTrainerCacheKey<Double>, Integer> getOrCreate(Ignite ignite) { - CacheConfiguration<GroupTrainerCacheKey<Double>, Integer> cfg = new CacheConfiguration<>(); - - // Write to primary. - cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC); - - // Atomic transactions only. - cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC); - - // No copying of values. - cfg.setCopyOnRead(false); - - // Cache is partitioned. - cfg.setCacheMode(CacheMode.PARTITIONED); - - cfg.setBackups(0); - - cfg.setOnheapCacheEnabled(true); - - cfg.setName(CACHE_NAME); - - return ignite.getOrCreateCache(cfg); - } - - /** */ - public static Stream<GroupTrainerCacheKey<Double>> allKeys(int limit, int eachNumberCnt, UUID trainingUUID) { - GroupTrainerCacheKey<Double>[] a =new GroupTrainerCacheKey[limit * eachNumberCnt]; - - for (int num = 0; num < limit; num++) - for (int i = 0; i < eachNumberCnt; i++) - a[num * eachNumberCnt + i] = new GroupTrainerCacheKey<>(num, (double)i, trainingUUID); - - return Arrays.stream(a); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingSecondCache.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingSecondCache.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingSecondCache.java deleted file mode 100644 index e16ed7c..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainingSecondCache.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group; - -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.cache.CacheAtomicityMode; -import org.apache.ignite.cache.CacheMode; -import org.apache.ignite.cache.CacheWriteSynchronizationMode; -import org.apache.ignite.configuration.CacheConfiguration; - -/** */ -public class TestGroupTrainingSecondCache { - /** */ - public static String CACHE_NAME = "TEST_GROUP_TRAINING_SECOND_CACHE"; - - /** */ - public static IgniteCache<GroupTrainerCacheKey<Character>, Integer> getOrCreate(Ignite ignite) { - CacheConfiguration<GroupTrainerCacheKey<Character>, Integer> cfg = new CacheConfiguration<>(); - - // Write to primary. - cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC); - - // Atomic transactions only. - cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC); - - // No copying of values. - cfg.setCopyOnRead(false); - - // Cache is partitioned. - cfg.setCacheMode(CacheMode.PARTITIONED); - - cfg.setBackups(0); - - cfg.setOnheapCacheEnabled(true); - - cfg.setName(CACHE_NAME); - - return ignite.getOrCreateCache(cfg); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/47cfdc27/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestLocalContext.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestLocalContext.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestLocalContext.java deleted file mode 100644 index 3f0237f..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestLocalContext.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.trainers.group; - -import java.util.UUID; -import org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID; - -/** */ -public class TestLocalContext implements HasTrainingUUID { - /** */ - private final UUID trainingUUID; - - /** */ - private int data; - - /** */ - public TestLocalContext(int data, UUID trainingUUID) { - this.data = data; - this.trainingUUID = trainingUUID; - } - - /** */ - public int data() { - return data; - } - - /** */ - public void setData(int data) { - this.data = data; - } - - /** {@inheritDoc} */ - @Override public UUID trainingUUID() { - return trainingUUID; - } -}